Skip to content

Commit 5169820

Browse files
committed
Return iterators from query in sync API
1 parent 45593f5 commit 5169820

File tree

7 files changed

+138
-22
lines changed

7 files changed

+138
-22
lines changed

postgres/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ runtime = ["tokio-postgres/runtime", "tokio", "lazy_static", "log"]
1111

1212
[dependencies]
1313
bytes = "0.4"
14+
fallible-iterator = "0.1"
1415
futures = "0.1"
1516
tokio-postgres = { version = "0.3", path = "../tokio-postgres", default-features = false }
1617

postgres/src/client.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ use futures::{Async, Future, Poll, Stream};
44
use std::io::{self, BufRead, Cursor, Read};
55
use std::marker::PhantomData;
66
use tokio_postgres::types::{ToSql, Type};
7-
use tokio_postgres::{Error, Row};
7+
use tokio_postgres::Error;
88
#[cfg(feature = "runtime")]
99
use tokio_postgres::{MakeTlsMode, Socket, TlsMode};
1010

1111
#[cfg(feature = "runtime")]
1212
use crate::Builder;
13-
use crate::{Statement, ToStatement, Transaction};
13+
use crate::{Query, Statement, ToStatement, Transaction};
1414

1515
pub struct Client(tokio_postgres::Client);
1616

@@ -48,12 +48,12 @@ impl Client {
4848
self.0.execute(&statement.0, params).wait()
4949
}
5050

51-
pub fn query<T>(&mut self, query: &T, params: &[&dyn ToSql]) -> Result<Vec<Row>, Error>
51+
pub fn query<T>(&mut self, query: &T, params: &[&dyn ToSql]) -> Result<Query<'_>, Error>
5252
where
5353
T: ?Sized + ToStatement,
5454
{
5555
let statement = query.__statement(self)?;
56-
self.0.query(&statement.0, params).collect().wait()
56+
Ok(Query::new(self.0.query(&statement.0, params)))
5757
}
5858

5959
pub fn copy_in<T, R>(

postgres/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use tokio::runtime::{self, Runtime};
77
mod builder;
88
mod client;
99
mod portal;
10+
mod query;
11+
mod query_portal;
1012
mod statement;
1113
mod to_statement;
1214
mod transaction;
@@ -19,6 +21,8 @@ mod test;
1921
pub use crate::builder::*;
2022
pub use crate::client::*;
2123
pub use crate::portal::*;
24+
pub use crate::query::*;
25+
pub use crate::query_portal::*;
2226
pub use crate::statement::*;
2327
pub use crate::to_statement::*;
2428
pub use crate::transaction::*;

postgres/src/query.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
use fallible_iterator::FallibleIterator;
2+
use futures::stream::{self, Stream};
3+
use std::marker::PhantomData;
4+
use tokio_postgres::{Error, Row};
5+
6+
pub struct Query<'a> {
7+
it: stream::Wait<tokio_postgres::Query>,
8+
_p: PhantomData<&'a mut ()>,
9+
}
10+
11+
// no-op impl to extend the borrow until drop
12+
impl<'a> Drop for Query<'a> {
13+
fn drop(&mut self) {}
14+
}
15+
16+
impl<'a> Query<'a> {
17+
pub(crate) fn new(stream: tokio_postgres::Query) -> Query<'a> {
18+
Query {
19+
it: stream.wait(),
20+
_p: PhantomData,
21+
}
22+
}
23+
}
24+
25+
impl<'a> FallibleIterator for Query<'a> {
26+
type Item = Row;
27+
type Error = Error;
28+
29+
fn next(&mut self) -> Result<Option<Row>, Error> {
30+
match self.it.next() {
31+
Some(Ok(row)) => Ok(Some(row)),
32+
Some(Err(e)) => Err(e),
33+
None => Ok(None),
34+
}
35+
}
36+
}

postgres/src/query_portal.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
use fallible_iterator::FallibleIterator;
2+
use futures::stream::{self, Stream};
3+
use std::marker::PhantomData;
4+
use tokio_postgres::{Error, Row};
5+
6+
pub struct QueryPortal<'a> {
7+
it: stream::Wait<tokio_postgres::QueryPortal>,
8+
_p: PhantomData<&'a mut ()>,
9+
}
10+
11+
// no-op impl to extend the borrow until drop
12+
impl<'a> Drop for QueryPortal<'a> {
13+
fn drop(&mut self) {}
14+
}
15+
16+
impl<'a> QueryPortal<'a> {
17+
pub(crate) fn new(stream: tokio_postgres::QueryPortal) -> QueryPortal<'a> {
18+
QueryPortal {
19+
it: stream.wait(),
20+
_p: PhantomData,
21+
}
22+
}
23+
}
24+
25+
impl<'a> FallibleIterator for QueryPortal<'a> {
26+
type Item = Row;
27+
type Error = Error;
28+
29+
fn next(&mut self) -> Result<Option<Row>, Error> {
30+
match self.it.next() {
31+
Some(Ok(row)) => Ok(Some(row)),
32+
Some(Err(e)) => Err(e),
33+
None => Ok(None),
34+
}
35+
}
36+
}

postgres/src/test.rs

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use fallible_iterator::FallibleIterator;
12
use std::io::Read;
23
use tokio_postgres::types::Type;
34
use tokio_postgres::NoTls;
@@ -20,7 +21,11 @@ fn query_prepared() {
2021
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
2122

2223
let stmt = client.prepare("SELECT $1::TEXT").unwrap();
23-
let rows = client.query(&stmt, &[&"hello"]).unwrap();
24+
let rows = client
25+
.query(&stmt, &[&"hello"])
26+
.unwrap()
27+
.collect::<Vec<_>>()
28+
.unwrap();
2429
assert_eq!(rows.len(), 1);
2530
assert_eq!(rows[0].get::<_, &str>(0), "hello");
2631
}
@@ -29,7 +34,11 @@ fn query_prepared() {
2934
fn query_unprepared() {
3035
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
3136

32-
let rows = client.query("SELECT $1::TEXT", &[&"hello"]).unwrap();
37+
let rows = client
38+
.query("SELECT $1::TEXT", &[&"hello"])
39+
.unwrap()
40+
.collect::<Vec<_>>()
41+
.unwrap();
3342
assert_eq!(rows.len(), 1);
3443
assert_eq!(rows[0].get::<_, &str>(0), "hello");
3544
}
@@ -50,7 +59,11 @@ fn transaction_commit() {
5059

5160
transaction.commit().unwrap();
5261

53-
let rows = client.query("SELECT * FROM foo", &[]).unwrap();
62+
let rows = client
63+
.query("SELECT * FROM foo", &[])
64+
.unwrap()
65+
.collect::<Vec<_>>()
66+
.unwrap();
5467
assert_eq!(rows.len(), 1);
5568
assert_eq!(rows[0].get::<_, i32>(0), 1);
5669
}
@@ -71,7 +84,11 @@ fn transaction_rollback() {
7184

7285
transaction.rollback().unwrap();
7386

74-
let rows = client.query("SELECT * FROM foo", &[]).unwrap();
87+
let rows = client
88+
.query("SELECT * FROM foo", &[])
89+
.unwrap()
90+
.collect::<Vec<_>>()
91+
.unwrap();
7592
assert_eq!(rows.len(), 0);
7693
}
7794

@@ -91,7 +108,11 @@ fn transaction_drop() {
91108

92109
drop(transaction);
93110

94-
let rows = client.query("SELECT * FROM foo", &[]).unwrap();
111+
let rows = client
112+
.query("SELECT * FROM foo", &[])
113+
.unwrap()
114+
.collect::<Vec<_>>()
115+
.unwrap();
95116
assert_eq!(rows.len(), 0);
96117
}
97118

@@ -119,6 +140,8 @@ fn nested_transactions() {
119140

120141
let rows = transaction
121142
.query("SELECT id FROM foo ORDER BY id", &[])
143+
.unwrap()
144+
.collect::<Vec<_>>()
122145
.unwrap();
123146
assert_eq!(rows.len(), 1);
124147
assert_eq!(rows[0].get::<_, i32>(0), 1);
@@ -139,7 +162,11 @@ fn nested_transactions() {
139162
transaction3.commit().unwrap();
140163
transaction.commit().unwrap();
141164

142-
let rows = client.query("SELECT id FROM foo ORDER BY id", &[]).unwrap();
165+
let rows = client
166+
.query("SELECT id FROM foo ORDER BY id", &[])
167+
.unwrap()
168+
.collect::<Vec<_>>()
169+
.unwrap();
143170
assert_eq!(rows.len(), 3);
144171
assert_eq!(rows[0].get::<_, i32>(0), 1);
145172
assert_eq!(rows[1].get::<_, i32>(0), 3);
@@ -164,6 +191,8 @@ fn copy_in() {
164191

165192
let rows = client
166193
.query("SELECT id, name FROM foo ORDER BY id", &[])
194+
.unwrap()
195+
.collect::<Vec<_>>()
167196
.unwrap();
168197

169198
assert_eq!(rows.len(), 2);
@@ -219,12 +248,20 @@ fn portal() {
219248
.bind("SELECT * FROM foo ORDER BY id", &[])
220249
.unwrap();
221250

222-
let rows = transaction.query_portal(&portal, 2).unwrap();
251+
let rows = transaction
252+
.query_portal(&portal, 2)
253+
.unwrap()
254+
.collect::<Vec<_>>()
255+
.unwrap();
223256
assert_eq!(rows.len(), 2);
224257
assert_eq!(rows[0].get::<_, i32>(0), 1);
225258
assert_eq!(rows[1].get::<_, i32>(0), 2);
226259

227-
let rows = transaction.query_portal(&portal, 2).unwrap();
260+
let rows = transaction
261+
.query_portal(&portal, 2)
262+
.unwrap()
263+
.collect::<Vec<_>>()
264+
.unwrap();
228265
assert_eq!(rows.len(), 1);
229266
assert_eq!(rows[0].get::<_, i32>(0), 3);
230267
}

postgres/src/transaction.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use futures::{Future, Stream};
1+
use futures::Future;
22
use std::io::Read;
33
use tokio_postgres::types::{ToSql, Type};
4-
use tokio_postgres::{Error, Row};
4+
use tokio_postgres::Error;
55

6-
use crate::{Client, CopyOutReader, Portal, Statement, ToStatement};
6+
use crate::{Client, CopyOutReader, Portal, Query, QueryPortal, Statement, ToStatement};
77

88
pub struct Transaction<'a> {
99
client: &'a mut Client,
@@ -67,7 +67,7 @@ impl<'a> Transaction<'a> {
6767
self.client.execute(query, params)
6868
}
6969

70-
pub fn query<T>(&mut self, query: &T, params: &[&dyn ToSql]) -> Result<Vec<Row>, Error>
70+
pub fn query<T>(&mut self, query: &T, params: &[&dyn ToSql]) -> Result<Query<'_>, Error>
7171
where
7272
T: ?Sized + ToStatement,
7373
{
@@ -86,12 +86,14 @@ impl<'a> Transaction<'a> {
8686
.map(Portal)
8787
}
8888

89-
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
90-
self.client
91-
.get_mut()
92-
.query_portal(&portal.0, max_rows)
93-
.collect()
94-
.wait()
89+
pub fn query_portal(
90+
&mut self,
91+
portal: &Portal,
92+
max_rows: i32,
93+
) -> Result<QueryPortal<'_>, Error> {
94+
Ok(QueryPortal::new(
95+
self.client.get_mut().query_portal(&portal.0, max_rows),
96+
))
9597
}
9698

9799
pub fn copy_in<T, R>(

0 commit comments

Comments
 (0)