Skip to content

Commit 77caff9

Browse files
committed
Add query/select
1 parent 90eb58d commit 77caff9

File tree

12 files changed

+500
-36
lines changed

12 files changed

+500
-36
lines changed

tokio-postgres/src/client.rs

+53-8
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
use crate::codec::BackendMessages;
22
use crate::connection::{Request, RequestMessages};
33
use crate::prepare::prepare;
4-
use crate::types::{Oid, Type};
4+
use crate::query::{execute, query, Query};
5+
use crate::types::{Oid, ToSql, Type};
56
use crate::{Error, Statement};
67
use fallible_iterator::FallibleIterator;
78
use futures::channel::mpsc;
8-
use futures::{Stream, StreamExt};
9+
use futures::future;
10+
use futures::{ready, StreamExt};
911
use parking_lot::Mutex;
1012
use postgres_protocol::message::backend::Message;
1113
use std::collections::HashMap;
1214
use std::future::Future;
13-
use std::pin::Pin;
1415
use std::sync::Arc;
1516
use std::task::{Context, Poll};
1617

@@ -20,20 +21,24 @@ pub struct Responses {
2021
}
2122

2223
impl Responses {
23-
pub async fn next(&mut self) -> Result<Message, Error> {
24+
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
2425
loop {
2526
match self.cur.next().map_err(Error::parse)? {
26-
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
27-
Some(message) => return Ok(message),
27+
Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
28+
Some(message) => return Poll::Ready(Ok(message)),
2829
None => {}
2930
}
3031

31-
match self.receiver.next().await {
32+
match ready!(self.receiver.poll_next_unpin(cx)) {
3233
Some(messages) => self.cur = messages,
33-
None => return Err(Error::closed()),
34+
None => return Poll::Ready(Err(Error::closed())),
3435
}
3536
}
3637
}
38+
39+
pub async fn next(&mut self) -> Result<Message, Error> {
40+
future::poll_fn(|cx| self.poll_next(cx)).await
41+
}
3742
}
3843

3944
struct State {
@@ -140,4 +145,44 @@ impl Client {
140145
) -> impl Future<Output = Result<Statement, Error>> + 'a {
141146
prepare(self.inner(), query, parameter_types)
142147
}
148+
149+
pub fn query<'a>(
150+
&mut self,
151+
statement: &'a Statement,
152+
params: &'a [&dyn ToSql],
153+
) -> impl Future<Output = Result<Query, Error>> + 'a {
154+
self.query_iter(statement, params.iter().cloned())
155+
}
156+
157+
pub fn query_iter<'a, I>(
158+
&mut self,
159+
statement: &'a Statement,
160+
params: I,
161+
) -> impl Future<Output = Result<Query, Error>> + 'a
162+
where
163+
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
164+
I::IntoIter: ExactSizeIterator,
165+
{
166+
query(self.inner(), statement, params)
167+
}
168+
169+
pub fn execute<'a>(
170+
&mut self,
171+
statement: &'a Statement,
172+
params: &'a [&dyn ToSql],
173+
) -> impl Future<Output = Result<u64, Error>> + 'a {
174+
self.execute_iter(statement, params.iter().cloned())
175+
}
176+
177+
pub fn execute_iter<'a, I>(
178+
&mut self,
179+
statement: &'a Statement,
180+
params: I,
181+
) -> impl Future<Output = Result<u64, Error>> + 'a
182+
where
183+
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
184+
I::IntoIter: ExactSizeIterator,
185+
{
186+
execute(self.inner(), statement, params)
187+
}
143188
}

tokio-postgres/src/connection.rs

-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ use postgres_protocol::message::backend::Message;
1010
use postgres_protocol::message::frontend;
1111
use std::collections::{HashMap, VecDeque};
1212
use std::future::Future;
13-
use std::io;
1413
use std::pin::Pin;
1514
use std::task::{Context, Poll};
1615
use tokio::codec::Framed;

tokio-postgres/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ pub use crate::config::Config;
117117
pub use crate::connection::Connection;
118118
use crate::error::DbError;
119119
pub use crate::error::Error;
120+
pub use crate::row::{Row, SimpleQueryRow};
120121
#[cfg(feature = "runtime")]
121122
pub use crate::socket::Socket;
122123
#[cfg(feature = "runtime")]
@@ -137,6 +138,8 @@ mod connection;
137138
pub mod error;
138139
mod maybe_tls_stream;
139140
mod prepare;
141+
mod query;
142+
pub mod row;
140143
#[cfg(feature = "runtime")]
141144
mod socket;
142145
mod statement;

tokio-postgres/src/maybe_tls_stream.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,9 @@ where
5858
cx: &mut Context<'_>,
5959
buf: &[u8],
6060
) -> Poll<io::Result<usize>> {
61-
unsafe {
62-
match &mut *self {
63-
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
64-
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
65-
}
61+
match &mut *self {
62+
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
63+
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
6664
}
6765
}
6866

tokio-postgres/src/prepare.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
use crate::client::InnerClient;
22
use crate::codec::FrontendMessage;
3-
use crate::connection::{Request, RequestMessages};
3+
use crate::connection::RequestMessages;
44
use crate::types::{Oid, Type};
55
use crate::{Column, Error, Statement};
66
use fallible_iterator::FallibleIterator;
7-
use futures::StreamExt;
87
use postgres_protocol::message::backend::Message;
98
use postgres_protocol::message::frontend;
109
use std::sync::atomic::{AtomicUsize, Ordering};

tokio-postgres/src/query.rs

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
use crate::client::{InnerClient, Responses};
2+
use crate::codec::FrontendMessage;
3+
use crate::connection::RequestMessages;
4+
use crate::types::{IsNull, ToSql};
5+
use crate::{Error, Row, Statement};
6+
use futures::{ready, Stream};
7+
use postgres_protocol::message::backend::Message;
8+
use postgres_protocol::message::frontend;
9+
use std::future::Future;
10+
use std::pin::Pin;
11+
use std::sync::Arc;
12+
use std::task::{Context, Poll};
13+
14+
pub async fn query<'a, I>(
15+
client: Arc<InnerClient>,
16+
statement: &Statement,
17+
params: I,
18+
) -> Result<Query, Error>
19+
where
20+
I: IntoIterator<Item = &'a dyn ToSql>,
21+
I::IntoIter: ExactSizeIterator,
22+
{
23+
let responses = start(&client, &statement, params).await?;
24+
25+
Ok(Query {
26+
statement: statement.clone(),
27+
responses,
28+
})
29+
}
30+
31+
pub async fn execute<'a, I>(
32+
client: Arc<InnerClient>,
33+
statement: &Statement,
34+
params: I,
35+
) -> Result<u64, Error>
36+
where
37+
I: IntoIterator<Item = &'a dyn ToSql>,
38+
I::IntoIter: ExactSizeIterator,
39+
{
40+
let mut responses = start(&client, &statement, params).await?;
41+
42+
loop {
43+
match responses.next().await? {
44+
Message::DataRow(_) => {}
45+
Message::CommandComplete(body) => {
46+
let rows = body
47+
.tag()
48+
.map_err(Error::parse)?
49+
.rsplit(' ')
50+
.next()
51+
.unwrap()
52+
.parse()
53+
.unwrap_or(0);
54+
return Ok(rows);
55+
}
56+
Message::EmptyQueryResponse => return Ok(0),
57+
_ => return Err(Error::unexpected_message()),
58+
}
59+
}
60+
}
61+
62+
async fn start<'a, I>(
63+
client: &Arc<InnerClient>,
64+
statement: &Statement,
65+
params: I,
66+
) -> Result<Responses, Error>
67+
where
68+
I: IntoIterator<Item = &'a dyn ToSql>,
69+
I::IntoIter: ExactSizeIterator,
70+
{
71+
let params = params.into_iter();
72+
73+
assert!(
74+
statement.params().len() == params.len(),
75+
"expected {} parameters but got {}",
76+
statement.params().len(),
77+
params.len()
78+
);
79+
80+
let mut buf = vec![];
81+
82+
let mut error_idx = 0;
83+
let r = frontend::bind(
84+
"",
85+
statement.name(),
86+
Some(1),
87+
params.zip(statement.params()).enumerate(),
88+
|(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
89+
Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
90+
Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
91+
Err(e) => {
92+
error_idx = idx;
93+
Err(e)
94+
}
95+
},
96+
Some(1),
97+
&mut buf,
98+
);
99+
match r {
100+
Ok(()) => {}
101+
Err(frontend::BindError::Conversion(e)) => return Err(Error::to_sql(e, error_idx)),
102+
Err(frontend::BindError::Serialization(e)) => return Err(Error::encode(e)),
103+
}
104+
105+
frontend::execute("", 0, &mut buf).map_err(Error::encode)?;
106+
frontend::sync(&mut buf);
107+
108+
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
109+
110+
match responses.next().await? {
111+
Message::BindComplete => {}
112+
_ => return Err(Error::unexpected_message()),
113+
}
114+
115+
Ok(responses)
116+
}
117+
118+
pub struct Query {
119+
statement: Statement,
120+
responses: Responses,
121+
}
122+
123+
impl Stream for Query {
124+
type Item = Result<Row, Error>;
125+
126+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127+
match ready!(self.responses.poll_next(cx)?) {
128+
Message::DataRow(body) => {
129+
Poll::Ready(Some(Ok(Row::new(self.statement.clone(), body)?)))
130+
}
131+
Message::EmptyQueryResponse | Message::CommandComplete(_) => Poll::Ready(None),
132+
Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
133+
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
134+
}
135+
}
136+
}

0 commit comments

Comments
 (0)