Skip to content

Commit b8577b4

Browse files
committed
Overhaul query_portal
1 parent 2517100 commit b8577b4

File tree

8 files changed

+57
-106
lines changed

8 files changed

+57
-106
lines changed

postgres-native-tls/src/test.rs

+3-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use futures::{FutureExt};
1+
use futures::FutureExt;
22
use native_tls::{self, Certificate};
33
use tokio::net::TcpStream;
44
use tokio_postgres::tls::TlsConnect;
@@ -21,10 +21,7 @@ where
2121
tokio::spawn(connection);
2222

2323
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
24-
let rows = client
25-
.query(&stmt, &[&1i32])
26-
.await
27-
.unwrap();
24+
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
2825

2926
assert_eq!(rows.len(), 1);
3027
assert_eq!(rows[0].get::<_, i32>(0), 1);
@@ -96,10 +93,7 @@ async fn runtime() {
9693
tokio::spawn(connection);
9794

9895
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
99-
let rows = client
100-
.query(&stmt, &[&1i32])
101-
.await
102-
.unwrap();
96+
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
10397

10498
assert_eq!(rows.len(), 1);
10599
assert_eq!(rows[0].get::<_, i32>(0), 1);

postgres-openssl/src/test.rs

+3-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use futures::{FutureExt};
1+
use futures::FutureExt;
22
use openssl::ssl::{SslConnector, SslMethod};
33
use tokio::net::TcpStream;
44
use tokio_postgres::tls::TlsConnect;
@@ -19,10 +19,7 @@ where
1919
tokio::spawn(connection);
2020

2121
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
22-
let rows = client
23-
.query(&stmt, &[&1i32])
24-
.await
25-
.unwrap();
22+
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
2623

2724
assert_eq!(rows.len(), 1);
2825
assert_eq!(rows[0].get::<_, i32>(0), 1);
@@ -107,10 +104,7 @@ async fn runtime() {
107104
tokio::spawn(connection);
108105

109106
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
110-
let rows = client
111-
.query(&stmt, &[&1i32])
112-
.await
113-
.unwrap();
107+
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
114108

115109
assert_eq!(rows.len(), 1);
116110
assert_eq!(rows[0].get::<_, i32>(0), 1);

postgres/src/transaction.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,17 @@ impl<'a> Transaction<'a> {
9595
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
9696
/// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
9797
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
98-
self.query_portal_iter(portal, max_rows)?.collect()
98+
executor::block_on(self.0.query_portal(portal, max_rows))
9999
}
100100

101-
/// Like `query_portal`, except that it returns a fallible iterator over the resulting rows rather than buffering
102-
/// the entire response in memory.
103-
pub fn query_portal_iter<'b>(
104-
&'b mut self,
105-
portal: &'b Portal,
101+
/// The maximally flexible version of `query_portal`.
102+
pub fn query_portal_raw(
103+
&mut self,
104+
portal: &Portal,
106105
max_rows: i32,
107-
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error> {
108-
Ok(Iter::new(self.0.query_portal(&portal, max_rows)))
106+
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error> {
107+
let stream = executor::block_on(self.0.query_portal_raw(portal, max_rows))?;
108+
Ok(Iter::new(stream))
109109
}
110110

111111
/// Like `Client::copy_in`.

tokio-postgres/src/query.rs

+13-17
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
44
use crate::types::{IsNull, ToSql};
55
use crate::{Error, Portal, Row, Statement};
6-
use futures::{ready, Stream, TryFutureExt};
6+
use futures::{ready, Stream};
77
use postgres_protocol::message::backend::Message;
88
use postgres_protocol::message::frontend;
99
use std::pin::Pin;
@@ -26,25 +26,21 @@ where
2626
})
2727
}
2828

29-
pub fn query_portal<'a>(
30-
client: &'a InnerClient,
31-
portal: &'a Portal,
29+
pub async fn query_portal(
30+
client: &InnerClient,
31+
portal: &Portal,
3232
max_rows: i32,
33-
) -> impl Stream<Item = Result<Row, Error>> + 'a {
34-
let start = async move {
35-
let mut buf = vec![];
36-
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
37-
frontend::sync(&mut buf);
38-
39-
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
33+
) -> Result<RowStream, Error> {
34+
let mut buf = vec![];
35+
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
36+
frontend::sync(&mut buf);
4037

41-
Ok(RowStream {
42-
statement: portal.statement().clone(),
43-
responses,
44-
})
45-
};
38+
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
4639

47-
start.try_flatten_stream()
40+
Ok(RowStream {
41+
statement: portal.statement().clone(),
42+
responses,
43+
})
4844
}
4945

5046
pub async fn execute<'a, I>(

tokio-postgres/src/transaction.rs

+14-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::{
1111
bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement,
1212
};
1313
use bytes::{Bytes, IntoBuf};
14-
use futures::{Stream, TryStream};
14+
use futures::{Stream, TryStream, TryStreamExt};
1515
use postgres_protocol::message::frontend;
1616
use std::error;
1717
use tokio::io::{AsyncRead, AsyncWrite};
@@ -177,12 +177,20 @@ impl<'a> Transaction<'a> {
177177
///
178178
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
179179
/// `query_portal`. If the requested number is negative or 0, all rows will be returned.
180-
pub fn query_portal<'b>(
181-
&'b self,
182-
portal: &'b Portal,
180+
pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
181+
self.query_portal_raw(portal, max_rows)
182+
.await?
183+
.try_collect()
184+
.await
185+
}
186+
187+
/// The maximally flexible version of `query_portal`.
188+
pub async fn query_portal_raw(
189+
&self,
190+
portal: &Portal,
183191
max_rows: i32,
184-
) -> impl Stream<Item = Result<Row, Error>> + 'b {
185-
query::query_portal(self.client.inner(), portal, max_rows)
192+
) -> Result<RowStream, Error> {
193+
query::query_portal(self.client.inner(), portal, max_rows).await
186194
}
187195

188196
/// Like `Client::copy_in`.

tokio-postgres/tests/test/main.rs

+8-23
Original file line numberDiff line numberDiff line change
@@ -335,10 +335,7 @@ async fn transaction_commit() {
335335
transaction.commit().await.unwrap();
336336

337337
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
338-
let rows = client
339-
.query(&stmt, &[])
340-
.await
341-
.unwrap();
338+
let rows = client.query(&stmt, &[]).await.unwrap();
342339

343340
assert_eq!(rows.len(), 1);
344341
assert_eq!(rows[0].get::<_, &str>(0), "steven");
@@ -366,10 +363,7 @@ async fn transaction_rollback() {
366363
transaction.rollback().await.unwrap();
367364

368365
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
369-
let rows = client
370-
.query(&stmt, &[])
371-
.await
372-
.unwrap();
366+
let rows = client.query(&stmt, &[]).await.unwrap();
373367

374368
assert_eq!(rows.len(), 0);
375369
}
@@ -396,10 +390,7 @@ async fn transaction_rollback_drop() {
396390
drop(transaction);
397391

398392
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
399-
let rows = client
400-
.query(&stmt, &[])
401-
.await
402-
.unwrap();
393+
let rows = client.query(&stmt, &[]).await.unwrap();
403394

404395
assert_eq!(rows.len(), 0);
405396
}
@@ -431,10 +422,7 @@ async fn copy_in() {
431422
.prepare("SELECT id, name FROM foo ORDER BY id")
432423
.await
433424
.unwrap();
434-
let rows = client
435-
.query(&stmt, &[])
436-
.await
437-
.unwrap();
425+
let rows = client.query(&stmt, &[]).await.unwrap();
438426

439427
assert_eq!(rows.len(), 2);
440428
assert_eq!(rows[0].get::<_, i32>(0), 1);
@@ -497,10 +485,7 @@ async fn copy_in_error() {
497485
.prepare("SELECT id, name FROM foo ORDER BY id")
498486
.await
499487
.unwrap();
500-
let rows = client
501-
.query(&stmt, &[])
502-
.await
503-
.unwrap();
488+
let rows = client.query(&stmt, &[]).await.unwrap();
504489
assert_eq!(rows.len(), 0);
505490
}
506491

@@ -583,9 +568,9 @@ async fn query_portal() {
583568
let transaction = client.transaction().await.unwrap();
584569

585570
let portal = transaction.bind(&stmt, &[]).await.unwrap();
586-
let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
587-
let f2 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
588-
let f3 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
571+
let f1 = transaction.query_portal(&portal, 2);
572+
let f2 = transaction.query_portal(&portal, 2);
573+
let f3 = transaction.query_portal(&portal, 2);
589574

590575
let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap();
591576

tokio-postgres/tests/test/runtime.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@ async fn smoke_test(s: &str) {
1616
let client = connect(s).await;
1717

1818
let stmt = client.prepare("SELECT $1::INT").await.unwrap();
19-
let rows = client
20-
.query(&stmt, &[&1i32])
21-
.await
22-
.unwrap();
19+
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
2320
assert_eq!(rows[0].get::<_, i32>(0), 1i32);
2421
}
2522

tokio-postgres/tests/test/types/mod.rs

+7-30
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,7 @@ async fn test_borrowed_text() {
195195
let client = connect("user=postgres").await;
196196

197197
let stmt = client.prepare("SELECT 'foo'").await.unwrap();
198-
let rows = client
199-
.query(&stmt, &[])
200-
.await
201-
.unwrap();
198+
let rows = client.query(&stmt, &[]).await.unwrap();
202199
let s: &str = rows[0].get(0);
203200
assert_eq!(s, "foo");
204201
}
@@ -298,10 +295,7 @@ async fn test_bytea_params() {
298295
async fn test_borrowed_bytea() {
299296
let client = connect("user=postgres").await;
300297
let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap();
301-
let rows = client
302-
.query(&stmt, &[])
303-
.await
304-
.unwrap();
298+
let rows = client.query(&stmt, &[]).await.unwrap();
305299
let s: &[u8] = rows[0].get(0);
306300
assert_eq!(s, b"foo");
307301
}
@@ -360,10 +354,7 @@ where
360354
.prepare(&format!("SELECT 'NaN'::{}", sql_type))
361355
.await
362356
.unwrap();
363-
let rows = client
364-
.query(&stmt, &[])
365-
.await
366-
.unwrap();
357+
let rows = client.query(&stmt, &[]).await.unwrap();
367358
let val: T = rows[0].get(0);
368359
assert!(val != val);
369360
}
@@ -385,10 +376,7 @@ async fn test_pg_database_datname() {
385376
.prepare("SELECT datname FROM pg_database")
386377
.await
387378
.unwrap();
388-
let rows = client
389-
.query(&stmt, &[])
390-
.await
391-
.unwrap();
379+
let rows = client.query(&stmt, &[]).await.unwrap();
392380
assert_eq!(rows[0].get::<_, &str>(0), "postgres");
393381
}
394382

@@ -439,11 +427,7 @@ async fn test_slice_wrong_type() {
439427
.prepare("SELECT * FROM foo WHERE id = ANY($1)")
440428
.await
441429
.unwrap();
442-
let err = client
443-
.query(&stmt, &[&&[&"hi"][..]])
444-
.await
445-
.err()
446-
.unwrap();
430+
let err = client.query(&stmt, &[&&[&"hi"][..]]).await.err().unwrap();
447431
match err.source() {
448432
Some(e) if e.is::<WrongType>() => {}
449433
_ => panic!("Unexpected error {:?}", err),
@@ -455,11 +439,7 @@ async fn test_slice_range() {
455439
let client = connect("user=postgres").await;
456440

457441
let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap();
458-
let err = client
459-
.query(&stmt, &[&&[&1i64][..]])
460-
.await
461-
.err()
462-
.unwrap();
442+
let err = client.query(&stmt, &[&&[&1i64][..]]).await.err().unwrap();
463443
match err.source() {
464444
Some(e) if e.is::<WrongType>() => {}
465445
_ => panic!("Unexpected error {:?}", err),
@@ -527,10 +507,7 @@ async fn domain() {
527507
client.execute(&stmt, &[&id]).await.unwrap();
528508

529509
let stmt = client.prepare("SELECT id FROM pg_temp.foo").await.unwrap();
530-
let rows = client
531-
.query(&stmt, &[])
532-
.await
533-
.unwrap();
510+
let rows = client.query(&stmt, &[]).await.unwrap();
534511
assert_eq!(id, rows[0].get(0));
535512
}
536513

0 commit comments

Comments
 (0)