Skip to content

Commit 7056e3e

Browse files
committed
Copy out support
1 parent 9e399aa commit 7056e3e

File tree

6 files changed

+207
-41
lines changed

6 files changed

+207
-41
lines changed

postgres-protocol/src/message/backend.rs

+40-39
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#![allow(missing_docs)]
22

3-
use byteorder::{ReadBytesExt, BigEndian};
3+
use byteorder::{BigEndian, ReadBytesExt};
44
use bytes::{Bytes, BytesMut};
55
use fallible_iterator::FallibleIterator;
66
use memchr::memchr;
@@ -148,45 +148,41 @@ impl Message {
148148
let storage = buf.read_all();
149149
Message::NoticeResponse(NoticeResponseBody { storage: storage })
150150
}
151-
b'R' => {
152-
match buf.read_i32::<BigEndian>()? {
153-
0 => Message::AuthenticationOk,
154-
2 => Message::AuthenticationKerberosV5,
155-
3 => Message::AuthenticationCleartextPassword,
156-
5 => {
157-
let mut salt = [0; 4];
158-
buf.read_exact(&mut salt)?;
159-
Message::AuthenticationMd5Password(
160-
AuthenticationMd5PasswordBody { salt: salt },
161-
)
162-
}
163-
6 => Message::AuthenticationScmCredential,
164-
7 => Message::AuthenticationGss,
165-
8 => {
166-
let storage = buf.read_all();
167-
Message::AuthenticationGssContinue(AuthenticationGssContinueBody(storage))
168-
}
169-
9 => Message::AuthenticationSspi,
170-
10 => {
171-
let storage = buf.read_all();
172-
Message::AuthenticationSasl(AuthenticationSaslBody(storage))
173-
}
174-
11 => {
175-
let storage = buf.read_all();
176-
Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage))
177-
}
178-
12 => {
179-
let storage = buf.read_all();
180-
Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage))
181-
}
182-
tag => {
183-
return Err(io::Error::new(
184-
io::ErrorKind::InvalidInput,
185-
format!("unknown authentication tag `{}`", tag),
186-
));
187-
}
151+
b'R' => match buf.read_i32::<BigEndian>()? {
152+
0 => Message::AuthenticationOk,
153+
2 => Message::AuthenticationKerberosV5,
154+
3 => Message::AuthenticationCleartextPassword,
155+
5 => {
156+
let mut salt = [0; 4];
157+
buf.read_exact(&mut salt)?;
158+
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt: salt })
188159
}
189-
}
160+
6 => Message::AuthenticationScmCredential,
161+
7 => Message::AuthenticationGss,
162+
8 => {
163+
let storage = buf.read_all();
164+
Message::AuthenticationGssContinue(AuthenticationGssContinueBody(storage))
165+
}
166+
9 => Message::AuthenticationSspi,
167+
10 => {
168+
let storage = buf.read_all();
169+
Message::AuthenticationSasl(AuthenticationSaslBody(storage))
170+
}
171+
11 => {
172+
let storage = buf.read_all();
173+
Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage))
174+
}
175+
12 => {
176+
let storage = buf.read_all();
177+
Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage))
178+
}
179+
tag => {
180+
return Err(io::Error::new(
181+
io::ErrorKind::InvalidInput,
182+
format!("unknown authentication tag `{}`", tag),
183+
));
184+
}
185+
},
190186
b's' => Message::PortalSuspended,
191187
b'S' => {
192188
let name = buf.read_cstr()?;
@@ -394,6 +390,11 @@ impl CopyDataBody {
394390
pub fn data(&self) -> &[u8] {
395391
&self.storage
396392
}
393+
394+
#[inline]
395+
pub fn into_bytes(self) -> Bytes {
396+
self.storage
397+
}
397398
}
398399

399400
pub struct CopyInResponseBody {

tokio-postgres/src/lib.rs

+17
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ extern crate state_machine_future;
2121
#[cfg(unix)]
2222
extern crate tokio_uds;
2323

24+
use bytes::Bytes;
2425
use futures::{Async, Future, Poll, Stream};
2526
use postgres_shared::rows::RowIndex;
2627
use std::fmt;
@@ -95,6 +96,10 @@ impl Client {
9596
Query(self.0.query(&statement.0, params))
9697
}
9798

99+
pub fn copy_out(&mut self, statement: &Statement, params: &[&ToSql]) -> CopyOut {
100+
CopyOut(self.0.copy_out(&statement.0, params))
101+
}
102+
98103
pub fn transaction<T>(&mut self, future: T) -> Transaction<T>
99104
where
100105
T: Future,
@@ -222,6 +227,18 @@ impl Stream for Query {
222227
}
223228
}
224229

230+
#[must_use = "streams do nothing unless polled"]
231+
pub struct CopyOut(proto::CopyOutStream);
232+
233+
impl Stream for CopyOut {
234+
type Item = Bytes;
235+
type Error = Error;
236+
237+
fn poll(&mut self) -> Poll<Option<Bytes>, Error> {
238+
self.0.poll()
239+
}
240+
}
241+
225242
pub struct Row(proto::Row);
226243

227244
impl Row {

tokio-postgres/src/proto/client.rs

+6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use std::sync::{Arc, Weak};
99
use disconnected;
1010
use error::{self, Error};
1111
use proto::connection::Request;
12+
use proto::copy_out::CopyOutStream;
1213
use proto::execute::ExecuteFuture;
1314
use proto::prepare::PrepareFuture;
1415
use proto::query::QueryStream;
@@ -130,6 +131,11 @@ impl Client {
130131
QueryStream::new(self.clone(), pending, statement.clone())
131132
}
132133

134+
pub fn copy_out(&self, statement: &Statement, params: &[&ToSql]) -> CopyOutStream {
135+
let pending = self.pending_execute(statement, params);
136+
CopyOutStream::new(self.clone(), pending, statement.clone())
137+
}
138+
133139
pub fn close_statement(&self, name: &str) {
134140
let mut buf = vec![];
135141
frontend::close(b'S', name, &mut buf).expect("statement name not valid");

tokio-postgres/src/proto/copy_out.rs

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
use bytes::Bytes;
2+
use futures::sync::mpsc;
3+
use futures::{Async, Poll, Stream};
4+
use postgres_protocol::message::backend::Message;
5+
use std::mem;
6+
7+
use error::{self, Error};
8+
use proto::client::{Client, PendingRequest};
9+
use proto::statement::Statement;
10+
use {bad_response, disconnected};
11+
12+
enum State {
13+
Start {
14+
client: Client,
15+
request: PendingRequest,
16+
statement: Statement,
17+
},
18+
ReadingCopyOutResponse {
19+
receiver: mpsc::Receiver<Message>,
20+
},
21+
ReadingCopyData {
22+
receiver: mpsc::Receiver<Message>,
23+
},
24+
Done,
25+
}
26+
27+
pub struct CopyOutStream(State);
28+
29+
impl Stream for CopyOutStream {
30+
type Item = Bytes;
31+
type Error = Error;
32+
33+
fn poll(&mut self) -> Poll<Option<Bytes>, Error> {
34+
loop {
35+
match mem::replace(&mut self.0, State::Done) {
36+
State::Start {
37+
client,
38+
request,
39+
statement,
40+
} => {
41+
let receiver = client.send(request)?;
42+
// it's ok for the statement to close now that we've queued the query
43+
drop(statement);
44+
self.0 = State::ReadingCopyOutResponse { receiver };
45+
}
46+
State::ReadingCopyOutResponse { mut receiver } => {
47+
let message = match receiver.poll() {
48+
Ok(Async::Ready(message)) => message,
49+
Ok(Async::NotReady) => {
50+
self.0 = State::ReadingCopyOutResponse { receiver };
51+
break Ok(Async::NotReady);
52+
}
53+
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
54+
};
55+
56+
match message {
57+
Some(Message::BindComplete) => {
58+
self.0 = State::ReadingCopyOutResponse { receiver };
59+
}
60+
Some(Message::CopyOutResponse(_)) => {
61+
self.0 = State::ReadingCopyData { receiver };
62+
}
63+
Some(Message::ErrorResponse(body)) => break Err(error::__db(body)),
64+
Some(_) => break Err(bad_response()),
65+
None => break Err(disconnected()),
66+
}
67+
}
68+
State::ReadingCopyData { mut receiver } => {
69+
let message = match receiver.poll() {
70+
Ok(Async::Ready(message)) => message,
71+
Ok(Async::NotReady) => {
72+
self.0 = State::ReadingCopyData { receiver };
73+
break Ok(Async::NotReady);
74+
}
75+
Err(()) => unreachable!("mpsc::Reciever doesn't return errors"),
76+
};
77+
78+
match message {
79+
Some(Message::CopyData(body)) => {
80+
self.0 = State::ReadingCopyData { receiver };
81+
break Ok(Async::Ready(Some(body.into_bytes())));
82+
}
83+
Some(Message::CopyDone) | Some(Message::CommandComplete(_)) => {
84+
self.0 = State::ReadingCopyData { receiver };
85+
}
86+
Some(Message::ReadyForQuery(_)) => break Ok(Async::Ready(None)),
87+
Some(Message::ErrorResponse(body)) => break Err(error::__db(body)),
88+
Some(_) => break Err(bad_response()),
89+
None => break Err(disconnected()),
90+
}
91+
}
92+
State::Done => break Ok(Async::Ready(None)),
93+
}
94+
}
95+
}
96+
}
97+
98+
impl CopyOutStream {
99+
pub fn new(client: Client, request: PendingRequest, statement: Statement) -> CopyOutStream {
100+
CopyOutStream(State::Start {
101+
client,
102+
request,
103+
statement,
104+
})
105+
}
106+
}

tokio-postgres/src/proto/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ mod client;
1313
mod codec;
1414
mod connect;
1515
mod connection;
16+
mod copy_out;
1617
mod execute;
1718
mod handshake;
1819
mod prepare;
@@ -30,6 +31,7 @@ pub use proto::cancel::CancelFuture;
3031
pub use proto::client::Client;
3132
pub use proto::codec::PostgresCodec;
3233
pub use proto::connection::Connection;
34+
pub use proto::copy_out::CopyOutStream;
3335
pub use proto::execute::ExecuteFuture;
3436
pub use proto::handshake::HandshakeFuture;
3537
pub use proto::prepare::PrepareFuture;

tokio-postgres/tests/test.rs

+36-2
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ fn notifications() {
480480
}
481481

482482
#[test]
483-
fn test_transaction_commit() {
483+
fn transaction_commit() {
484484
let _ = env_logger::try_init();
485485
let mut runtime = Runtime::new().unwrap();
486486

@@ -518,7 +518,7 @@ fn test_transaction_commit() {
518518
}
519519

520520
#[test]
521-
fn test_transaction_abort() {
521+
fn transaction_abort() {
522522
let _ = env_logger::try_init();
523523
let mut runtime = Runtime::new().unwrap();
524524

@@ -556,3 +556,37 @@ fn test_transaction_abort() {
556556

557557
assert_eq!(rows.len(), 0);
558558
}
559+
560+
#[test]
561+
fn copy_out() {
562+
let _ = env_logger::try_init();
563+
let mut runtime = Runtime::new().unwrap();
564+
565+
let (mut client, connection) = runtime
566+
.block_on(tokio_postgres::connect(
567+
"postgres://postgres@localhost:5433".parse().unwrap(),
568+
TlsMode::None,
569+
))
570+
.unwrap();
571+
let connection = connection.map_err(|e| panic!("{}", e));
572+
runtime.handle().spawn(connection).unwrap();
573+
574+
runtime
575+
.block_on(client.batch_execute(
576+
"CREATE TEMPORARY TABLE foo (
577+
id SERIAL,
578+
name TEXT
579+
);
580+
INSERT INTO foo (name) VALUES ('jim'), ('joe');",
581+
))
582+
.unwrap();
583+
584+
let data = runtime
585+
.block_on(
586+
client
587+
.prepare("COPY foo TO STDOUT")
588+
.and_then(|s| client.copy_out(&s, &[]).concat2()),
589+
)
590+
.unwrap();
591+
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
592+
}

0 commit comments

Comments
 (0)