Skip to content

Commit 53657b8

Browse files
committed
Implement batch_execute
1 parent 08df4b3 commit 53657b8

File tree

5 files changed

+133
-61
lines changed

5 files changed

+133
-61
lines changed

tokio-postgres/src/lib.rs

+16
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ impl Client {
9191
pub fn query(&mut self, statement: &Statement, params: &[&ToSql]) -> Query {
9292
Query(self.0.query(&statement.0, params))
9393
}
94+
95+
pub fn batch_execute(&mut self, query: &str) -> BatchExecute {
96+
BatchExecute(self.0.batch_execute(query))
97+
}
9498
}
9599

96100
#[must_use = "futures do nothing unless polled"]
@@ -234,3 +238,15 @@ impl Row {
234238
self.0.try_get(idx)
235239
}
236240
}
241+
242+
#[must_use = "futures do nothing unless polled"]
243+
pub struct BatchExecute(proto::SimpleQueryFuture);
244+
245+
impl Future for BatchExecute {
246+
type Item = ();
247+
type Error = Error;
248+
249+
fn poll(&mut self) -> Poll<(), Error> {
250+
self.0.poll()
251+
}
252+
}

tokio-postgres/src/proto/client.rs

+10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use proto::connection::Request;
1212
use proto::execute::ExecuteFuture;
1313
use proto::prepare::PrepareFuture;
1414
use proto::query::QueryStream;
15+
use proto::simple_query::SimpleQueryFuture;
1516
use proto::statement::Statement;
1617
use types::{IsNull, Oid, ToSql, Type};
1718

@@ -99,6 +100,15 @@ impl Client {
99100
.map_err(|_| disconnected())
100101
}
101102

103+
pub fn batch_execute(&self, query: &str) -> SimpleQueryFuture {
104+
let pending = self.pending(|buf| {
105+
frontend::query(query, buf)?;
106+
Ok(())
107+
});
108+
109+
SimpleQueryFuture::new(self.clone(), pending)
110+
}
111+
102112
pub fn prepare(&self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture {
103113
let pending = self.pending(|buf| {
104114
frontend::parse(&name, query, param_types.iter().map(|t| t.oid()), buf)?;

tokio-postgres/src/proto/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ mod handshake;
1818
mod prepare;
1919
mod query;
2020
mod row;
21+
mod simple_query;
2122
mod socket;
2223
mod statement;
2324
mod typeinfo;
@@ -33,5 +34,6 @@ pub use proto::handshake::HandshakeFuture;
3334
pub use proto::prepare::PrepareFuture;
3435
pub use proto::query::QueryStream;
3536
pub use proto::row::Row;
37+
pub use proto::simple_query::SimpleQueryFuture;
3638
pub use proto::socket::Socket;
3739
pub use proto::statement::Statement;
+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use futures::sync::mpsc;
2+
use futures::{Poll, Stream};
3+
use postgres_protocol::message::backend::Message;
4+
use state_machine_future::RentToOwn;
5+
6+
use error::{self, Error};
7+
use proto::client::{Client, PendingRequest};
8+
use {bad_response, disconnected};
9+
10+
#[derive(StateMachineFuture)]
11+
pub enum SimpleQuery {
12+
#[state_machine_future(start, transitions(ReadResponse))]
13+
Start {
14+
client: Client,
15+
request: PendingRequest,
16+
},
17+
#[state_machine_future(transitions(Finished))]
18+
ReadResponse { receiver: mpsc::Receiver<Message> },
19+
#[state_machine_future(ready)]
20+
Finished(()),
21+
#[state_machine_future(error)]
22+
Failed(Error),
23+
}
24+
25+
impl PollSimpleQuery for SimpleQuery {
26+
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
27+
let state = state.take();
28+
let receiver = state.client.send(state.request)?;
29+
30+
transition!(ReadResponse { receiver })
31+
}
32+
33+
fn poll_read_response<'a>(
34+
state: &'a mut RentToOwn<'a, ReadResponse>,
35+
) -> Poll<AfterReadResponse, Error> {
36+
loop {
37+
let message = try_receive!(state.receiver.poll());
38+
39+
match message {
40+
Some(Message::CommandComplete(_))
41+
| Some(Message::RowDescription(_))
42+
| Some(Message::DataRow(_))
43+
| Some(Message::EmptyQueryResponse) => {}
44+
Some(Message::ErrorResponse(body)) => return Err(error::__db(body)),
45+
Some(Message::ReadyForQuery(_)) => transition!(Finished(())),
46+
Some(_) => return Err(bad_response()),
47+
None => return Err(disconnected()),
48+
}
49+
}
50+
}
51+
}
52+
53+
impl SimpleQueryFuture {
54+
pub fn new(client: Client, request: PendingRequest) -> SimpleQueryFuture {
55+
SimpleQuery::start(client, request)
56+
}
57+
}

tokio-postgres/tests/test.rs

+48-61
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,7 @@ fn insert_select() {
197197
runtime.handle().spawn(connection).unwrap();
198198

199199
runtime
200-
.block_on(
201-
client
202-
.prepare("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)")
203-
.and_then(|create| client.execute(&create, &[]))
204-
.map(|n| assert_eq!(n, 0)),
205-
)
200+
.block_on(client.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)"))
206201
.unwrap();
207202

208203
let insert = client.prepare("INSERT INTO foo (name) VALUES ($1), ($2)");
@@ -238,14 +233,13 @@ fn cancel_query() {
238233
let connection = connection.map_err(|e| panic!("{}", e));
239234
runtime.handle().spawn(connection).unwrap();
240235

241-
let sleep = client.prepare("SELECT pg_sleep(100)");
242-
let sleep = runtime.block_on(sleep).unwrap();
243-
244-
let sleep = client.execute(&sleep, &[]).then(|r| match r {
245-
Ok(_) => panic!("unexpected success"),
246-
Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()),
247-
Err(e) => panic!("unexpected error {}", e),
248-
});
236+
let sleep = client
237+
.batch_execute("SELECT pg_sleep(100)")
238+
.then(|r| match r {
239+
Ok(_) => panic!("unexpected success"),
240+
Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()),
241+
Err(e) => panic!("unexpected error {}", e),
242+
});
249243
let cancel = Delay::new(Instant::now() + Duration::from_millis(100))
250244
.then(|r| {
251245
r.unwrap();
@@ -276,17 +270,15 @@ fn custom_enum() {
276270
let connection = connection.map_err(|e| panic!("{}", e));
277271
runtime.handle().spawn(connection).unwrap();
278272

279-
let create_type = client.prepare(
280-
"CREATE TYPE pg_temp.mood AS ENUM (
281-
'sad',
282-
'ok',
283-
'happy'
284-
)",
285-
);
286-
let create_type = runtime.block_on(create_type).unwrap();
287-
288-
let create_type = client.execute(&create_type, &[]);
289-
runtime.block_on(create_type).unwrap();
273+
runtime
274+
.block_on(client.batch_execute(
275+
"CREATE TYPE pg_temp.mood AS ENUM (
276+
'sad',
277+
'ok',
278+
'happy'
279+
)",
280+
))
281+
.unwrap();
290282

291283
let select = client.prepare("SELECT $1::mood");
292284
let select = runtime.block_on(select).unwrap();
@@ -316,12 +308,11 @@ fn custom_domain() {
316308
let connection = connection.map_err(|e| panic!("{}", e));
317309
runtime.handle().spawn(connection).unwrap();
318310

319-
let create_type =
320-
client.prepare("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)");
321-
let create_type = runtime.block_on(create_type).unwrap();
322-
323-
let create_type = client.execute(&create_type, &[]);
324-
runtime.block_on(create_type).unwrap();
311+
runtime
312+
.block_on(client.batch_execute(
313+
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)",
314+
))
315+
.unwrap();
325316

326317
let select = client.prepare("SELECT $1::session_id");
327318
let select = runtime.block_on(select).unwrap();
@@ -371,17 +362,15 @@ fn custom_composite() {
371362
let connection = connection.map_err(|e| panic!("{}", e));
372363
runtime.handle().spawn(connection).unwrap();
373364

374-
let create_type = client.prepare(
375-
"CREATE TYPE pg_temp.inventory_item AS (
376-
name TEXT,
377-
supplier INTEGER,
378-
price NUMERIC
379-
)",
380-
);
381-
let create_type = runtime.block_on(create_type).unwrap();
382-
383-
let create_type = client.execute(&create_type, &[]);
384-
runtime.block_on(create_type).unwrap();
365+
runtime
366+
.block_on(client.batch_execute(
367+
"CREATE TYPE pg_temp.inventory_item AS (
368+
name TEXT,
369+
supplier INTEGER,
370+
price NUMERIC
371+
)",
372+
))
373+
.unwrap();
385374

386375
let select = client.prepare("SELECT $1::inventory_item");
387376
let select = runtime.block_on(select).unwrap();
@@ -414,16 +403,14 @@ fn custom_range() {
414403
let connection = connection.map_err(|e| panic!("{}", e));
415404
runtime.handle().spawn(connection).unwrap();
416405

417-
let create_type = client.prepare(
418-
"CREATE TYPE pg_temp.floatrange AS RANGE (
419-
subtype = float8,
420-
subtype_diff = float8mi
421-
)",
422-
);
423-
let create_type = runtime.block_on(create_type).unwrap();
424-
425-
let create_type = client.execute(&create_type, &[]);
426-
runtime.block_on(create_type).unwrap();
406+
runtime
407+
.block_on(client.batch_execute(
408+
"CREATE TYPE pg_temp.floatrange AS RANGE (
409+
subtype = float8,
410+
subtype_diff = float8mi
411+
)",
412+
))
413+
.unwrap();
427414

428415
let select = client.prepare("SELECT $1::floatrange");
429416
let select = runtime.block_on(select).unwrap();
@@ -479,17 +466,17 @@ fn notifications() {
479466
});
480467
runtime.handle().spawn(connection).unwrap();
481468

482-
let listen = client.prepare("LISTEN test_notifications");
483-
let listen = runtime.block_on(listen).unwrap();
484-
runtime.block_on(client.execute(&listen, &[])).unwrap();
469+
runtime
470+
.block_on(client.batch_execute("LISTEN test_notifications"))
471+
.unwrap();
485472

486-
let notify = client.prepare("NOTIFY test_notifications, 'hello'");
487-
let notify = runtime.block_on(notify).unwrap();
488-
runtime.block_on(client.execute(&notify, &[])).unwrap();
473+
runtime
474+
.block_on(client.batch_execute("NOTIFY test_notifications, 'hello'"))
475+
.unwrap();
489476

490-
let notify = client.prepare("NOTIFY test_notifications, 'world'");
491-
let notify = runtime.block_on(notify).unwrap();
492-
runtime.block_on(client.execute(&notify, &[])).unwrap();
477+
runtime
478+
.block_on(client.batch_execute("NOTIFY test_notifications, 'world'"))
479+
.unwrap();
493480

494481
drop(client);
495482
runtime.run().unwrap();

0 commit comments

Comments
 (0)