Skip to content

Commit 26a17ac

Browse files
committed
Support portals
1 parent e4a1ec2 commit 26a17ac

File tree

6 files changed

+225
-10
lines changed

6 files changed

+225
-10
lines changed

tokio-postgres/src/bind.rs

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
use crate::client::InnerClient;
2+
use crate::codec::FrontendMessage;
3+
use crate::connection::RequestMessages;
4+
use crate::types::ToSql;
5+
use crate::{query, Error, Portal, Statement};
6+
use postgres_protocol::message::backend::Message;
7+
use postgres_protocol::message::frontend;
8+
use std::sync::atomic::{AtomicUsize, Ordering};
9+
use std::sync::Arc;
10+
11+
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
12+
13+
pub async fn bind(
14+
client: Arc<InnerClient>,
15+
statement: Statement,
16+
bind: Result<PendingBind, Error>,
17+
) -> Result<Portal, Error> {
18+
let bind = bind?;
19+
20+
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(bind.buf)))?;
21+
22+
match responses.next().await? {
23+
Message::BindComplete => {}
24+
_ => return Err(Error::unexpected_message()),
25+
}
26+
27+
Ok(Portal::new(&client, bind.name, statement))
28+
}
29+
30+
pub struct PendingBind {
31+
buf: Vec<u8>,
32+
name: String,
33+
}
34+
35+
pub fn encode<'a, I>(statement: &Statement, params: I) -> Result<PendingBind, Error>
36+
where
37+
I: IntoIterator<Item = &'a dyn ToSql>,
38+
I::IntoIter: ExactSizeIterator,
39+
{
40+
let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
41+
let mut buf = query::encode_bind(statement, params, &name)?;
42+
frontend::sync(&mut buf);
43+
44+
Ok(PendingBind { buf, name })
45+
}

tokio-postgres/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ pub use crate::config::Config;
113113
pub use crate::connection::Connection;
114114
use crate::error::DbError;
115115
pub use crate::error::Error;
116+
pub use crate::portal::Portal;
116117
pub use crate::row::{Row, SimpleQueryRow};
117118
#[cfg(feature = "runtime")]
118119
pub use crate::socket::Socket;
@@ -122,6 +123,7 @@ pub use crate::tls::NoTls;
122123
pub use crate::transaction::Transaction;
123124
pub use statement::{Column, Statement};
124125

126+
mod bind;
125127
#[cfg(feature = "runtime")]
126128
mod cancel_query;
127129
mod cancel_query_raw;
@@ -139,6 +141,7 @@ mod copy_in;
139141
mod copy_out;
140142
pub mod error;
141143
mod maybe_tls_stream;
144+
mod portal;
142145
mod prepare;
143146
mod query;
144147
pub mod row;

tokio-postgres/src/portal.rs

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
use crate::client::InnerClient;
2+
use crate::codec::FrontendMessage;
3+
use crate::connection::RequestMessages;
4+
use crate::Statement;
5+
use postgres_protocol::message::frontend;
6+
use std::sync::{Arc, Weak};
7+
8+
struct Inner {
9+
client: Weak<InnerClient>,
10+
name: String,
11+
statement: Statement,
12+
}
13+
14+
impl Drop for Inner {
15+
fn drop(&mut self) {
16+
if let Some(client) = self.client.upgrade() {
17+
let mut buf = vec![];
18+
frontend::close(b'P', &self.name, &mut buf).expect("portal name not valid");
19+
frontend::sync(&mut buf);
20+
let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
21+
}
22+
}
23+
}
24+
25+
/// A portal.
26+
///
27+
/// Portals can only be used with the connection that created them, and only exist for the duration of the transaction
28+
/// in which they were created.
29+
#[derive(Clone)]
30+
pub struct Portal(Arc<Inner>);
31+
32+
impl Portal {
33+
pub(crate) fn new(client: &Arc<InnerClient>, name: String, statement: Statement) -> Portal {
34+
Portal(Arc::new(Inner {
35+
client: Arc::downgrade(client),
36+
name,
37+
statement,
38+
}))
39+
}
40+
41+
pub(crate) fn name(&self) -> &str {
42+
&self.0.name
43+
}
44+
45+
pub(crate) fn statement(&self) -> &Statement {
46+
&self.0.statement
47+
}
48+
}

tokio-postgres/src/query.rs

+39-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::client::{InnerClient, Responses};
22
use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
44
use crate::types::{IsNull, ToSql};
5-
use crate::{Error, Row, Statement};
5+
use crate::{Error, Portal, Row, Statement};
66
use futures::{ready, Stream, TryFutureExt};
77
use postgres_protocol::message::backend::Message;
88
use postgres_protocol::message::frontend;
@@ -23,6 +23,27 @@ pub fn query(
2323
.try_flatten_stream()
2424
}
2525

26+
pub fn query_portal(
27+
client: Arc<InnerClient>,
28+
portal: Portal,
29+
max_rows: i32,
30+
) -> impl Stream<Item = Result<Row, Error>> {
31+
let start = async move {
32+
let mut buf = vec![];
33+
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
34+
frontend::sync(&mut buf);
35+
36+
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
37+
38+
Ok(Query {
39+
statement: portal.statement().clone(),
40+
responses,
41+
})
42+
};
43+
44+
start.try_flatten_stream()
45+
}
46+
2647
pub async fn execute(client: Arc<InnerClient>, buf: Result<Vec<u8>, Error>) -> Result<u64, Error> {
2748
let mut responses = start(client, buf).await?;
2849

@@ -59,6 +80,18 @@ async fn start(client: Arc<InnerClient>, buf: Result<Vec<u8>, Error>) -> Result<
5980
}
6081

6182
pub fn encode<'a, I>(statement: &Statement, params: I) -> Result<Vec<u8>, Error>
83+
where
84+
I: IntoIterator<Item = &'a dyn ToSql>,
85+
I::IntoIter: ExactSizeIterator,
86+
{
87+
let mut buf = encode_bind(statement, params, "")?;
88+
frontend::execute("", 0, &mut buf).map_err(Error::encode)?;
89+
frontend::sync(&mut buf);
90+
91+
Ok(buf)
92+
}
93+
94+
pub fn encode_bind<'a, I>(statement: &Statement, params: I, portal: &str) -> Result<Vec<u8>, Error>
6295
where
6396
I: IntoIterator<Item = &'a dyn ToSql>,
6497
I::IntoIter: ExactSizeIterator,
@@ -76,7 +109,7 @@ where
76109

77110
let mut error_idx = 0;
78111
let r = frontend::bind(
79-
"",
112+
portal,
80113
statement.name(),
81114
Some(1),
82115
params.zip(statement.params()).enumerate(),
@@ -92,15 +125,10 @@ where
92125
&mut buf,
93126
);
94127
match r {
95-
Ok(()) => {}
128+
Ok(()) => Ok(buf),
96129
Err(frontend::BindError::Conversion(e)) => return Err(Error::to_sql(e, error_idx)),
97130
Err(frontend::BindError::Serialization(e)) => return Err(Error::encode(e)),
98131
}
99-
100-
frontend::execute("", 0, &mut buf).map_err(Error::encode)?;
101-
frontend::sync(&mut buf);
102-
103-
Ok(buf)
104132
}
105133

106134
struct Query {
@@ -116,7 +144,9 @@ impl Stream for Query {
116144
Message::DataRow(body) => {
117145
Poll::Ready(Some(Ok(Row::new(self.statement.clone(), body)?)))
118146
}
119-
Message::EmptyQueryResponse | Message::CommandComplete(_) => Poll::Ready(None),
147+
Message::EmptyQueryResponse
148+
| Message::CommandComplete(_)
149+
| Message::PortalSuspended => Poll::Ready(None),
120150
Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
121151
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
122152
}

tokio-postgres/src/transaction.rs

+47-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::tls::TlsConnect;
66
use crate::types::{ToSql, Type};
77
#[cfg(feature = "runtime")]
88
use crate::Socket;
9-
use crate::{query, Client, Error, Row, SimpleQueryMessage, Statement};
9+
use crate::{bind, query, Client, Error, Portal, Row, SimpleQueryMessage, Statement};
1010
use bytes::{Bytes, IntoBuf};
1111
use futures::{Stream, TryStream};
1212
use postgres_protocol::message::frontend;
@@ -122,6 +122,52 @@ impl<'a> Transaction<'a> {
122122
query::execute(self.client.inner(), buf)
123123
}
124124

125+
/// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried.
126+
///
127+
/// Portals only last for the duration of the transaction in which they are created, and can only be used on the
128+
/// connection that created them.
129+
///
130+
/// # Panics
131+
///
132+
/// Panics if the number of parameters provided does not match the number expected.
133+
pub fn bind(
134+
&mut self,
135+
statement: &Statement,
136+
params: &[&dyn ToSql],
137+
) -> impl Future<Output = Result<Portal, Error>> {
138+
// https://github.com/rust-lang/rust/issues/63032
139+
let buf = bind::encode(statement, params.iter().cloned());
140+
bind::bind(self.client.inner(), statement.clone(), buf)
141+
}
142+
143+
/// Like [`bind`], but takes an iterator of parameters rather than a slice.
144+
///
145+
/// [`bind`]: #method.bind
146+
pub fn bind_iter<'b, I>(
147+
&mut self,
148+
statement: &Statement,
149+
params: I,
150+
) -> impl Future<Output = Result<Portal, Error>>
151+
where
152+
I: IntoIterator<Item = &'b dyn ToSql>,
153+
I::IntoIter: ExactSizeIterator,
154+
{
155+
let buf = bind::encode(statement, params);
156+
bind::bind(self.client.inner(), statement.clone(), buf)
157+
}
158+
159+
/// Continues execution of a portal, returning a stream of the resulting rows.
160+
///
161+
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
162+
/// `query_portal`. If the requested number is negative or 0, all rows will be returned.
163+
pub fn query_portal(
164+
&mut self,
165+
portal: &Portal,
166+
max_rows: i32,
167+
) -> impl Stream<Item = Result<Row, Error>> {
168+
query::query_portal(self.client.inner(), portal.clone(), max_rows)
169+
}
170+
125171
/// Like `Client::copy_in`.
126172
pub fn copy_in<S>(
127173
&mut self,

tokio-postgres/tests/test/main.rs

+43
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,49 @@ async fn notifications() {
569569
assert_eq!(notifications[1].payload(), "world");
570570
}
571571

572+
#[tokio::test]
573+
async fn query_portal() {
574+
let mut client = connect("user=postgres").await;
575+
576+
client
577+
.batch_execute(
578+
"CREATE TEMPORARY TABLE foo (
579+
id SERIAL,
580+
name TEXT
581+
);
582+
583+
INSERT INTO foo (name) VALUES ('alice'), ('bob'), ('charlie');",
584+
)
585+
.await
586+
.unwrap();
587+
588+
let stmt = client
589+
.prepare("SELECT id, name FROM foo ORDER BY id")
590+
.await
591+
.unwrap();
592+
593+
let mut transaction = client.transaction().await.unwrap();
594+
595+
let portal = transaction.bind(&stmt, &[]).await.unwrap();
596+
let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
597+
let f2 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
598+
let f3 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
599+
600+
let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap();
601+
602+
assert_eq!(r1.len(), 2);
603+
assert_eq!(r1[0].get::<_, i32>(0), 1);
604+
assert_eq!(r1[0].get::<_, &str>(1), "alice");
605+
assert_eq!(r1[1].get::<_, i32>(0), 2);
606+
assert_eq!(r1[1].get::<_, &str>(1), "bob");
607+
608+
assert_eq!(r2.len(), 1);
609+
assert_eq!(r2[0].get::<_, i32>(0), 3);
610+
assert_eq!(r2[0].get::<_, &str>(1), "charlie");
611+
612+
assert_eq!(r3.len(), 0);
613+
}
614+
572615
/*
573616
#[test]
574617
fn query_portal() {

0 commit comments

Comments
 (0)