Skip to content

Commit b74f5c8

Browse files
committed
copy in support
1 parent daeb538 commit b74f5c8

File tree

7 files changed

+483
-73
lines changed

7 files changed

+483
-73
lines changed

postgres-shared/src/error/mod.rs

+8
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,14 @@ pub fn __db(e: ErrorResponseBody) -> Error {
325325
}
326326
}
327327

328+
#[doc(hidden)]
329+
pub fn __user<T>(e: T) -> Error
330+
where
331+
T: Into<Box<error::Error + Sync + Send>>,
332+
{
333+
Error(Box::new(ErrorKind::Conversion(e.into())))
334+
}
335+
328336
#[doc(hidden)]
329337
pub fn io(e: io::Error) -> Error {
330338
Error(Box::new(ErrorKind::Io(e)))

tokio-postgres/src/lib.rs

+28
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ extern crate tokio_uds;
2424
use bytes::Bytes;
2525
use futures::{Async, Future, Poll, Stream};
2626
use postgres_shared::rows::RowIndex;
27+
use std::error::Error as StdError;
2728
use std::fmt;
2829
use std::io;
2930
use std::sync::atomic::{AtomicUsize, Ordering};
@@ -96,6 +97,14 @@ impl Client {
9697
Query(self.0.query(&statement.0, params))
9798
}
9899

100+
pub fn copy_in<S>(&mut self, statement: &Statement, params: &[&ToSql], stream: S) -> CopyIn<S>
101+
where
102+
S: Stream<Item = Vec<u8>>,
103+
S::Error: Into<Box<StdError + Sync + Send>>,
104+
{
105+
CopyIn(self.0.copy_in(&statement.0, params, stream))
106+
}
107+
99108
pub fn copy_out(&mut self, statement: &Statement, params: &[&ToSql]) -> CopyOut {
100109
CopyOut(self.0.copy_out(&statement.0, params))
101110
}
@@ -227,6 +236,25 @@ impl Stream for Query {
227236
}
228237
}
229238

239+
#[must_use = "futures do nothing unless polled"]
240+
pub struct CopyIn<S>(proto::CopyInFuture<S>)
241+
where
242+
S: Stream<Item = Vec<u8>>,
243+
S::Error: Into<Box<StdError + Sync + Send>>;
244+
245+
impl<S> Future for CopyIn<S>
246+
where
247+
S: Stream<Item = Vec<u8>>,
248+
S::Error: Into<Box<StdError + Sync + Send>>,
249+
{
250+
type Item = u64;
251+
type Error = Error;
252+
253+
fn poll(&mut self) -> Poll<u64, Error> {
254+
self.0.poll()
255+
}
256+
}
257+
230258
#[must_use = "streams do nothing unless polled"]
231259
pub struct CopyOut(proto::CopyOutStream);
232260

tokio-postgres/src/proto/client.rs

+61-31
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
use antidote::Mutex;
22
use futures::sync::mpsc;
3+
use futures::{AsyncSink, Sink, Stream};
34
use postgres_protocol;
45
use postgres_protocol::message::backend::Message;
56
use postgres_protocol::message::frontend;
67
use std::collections::HashMap;
8+
use std::error::Error as StdError;
79
use std::sync::{Arc, Weak};
810

911
use disconnected;
1012
use error::{self, Error};
11-
use proto::connection::Request;
13+
use proto::connection::{Request, RequestMessages};
14+
use proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
1215
use proto::copy_out::CopyOutStream;
1316
use proto::execute::ExecuteFuture;
1417
use proto::prepare::PrepareFuture;
@@ -17,7 +20,7 @@ use proto::simple_query::SimpleQueryFuture;
1720
use proto::statement::Statement;
1821
use types::{IsNull, Oid, ToSql, Type};
1922

20-
pub struct PendingRequest(Result<Vec<u8>, Error>);
23+
pub struct PendingRequest(Result<RequestMessages, Error>);
2124

2225
pub struct WeakClient(Weak<Inner>);
2326

@@ -122,17 +125,45 @@ impl Client {
122125
}
123126

124127
pub fn execute(&self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture {
125-
let pending = self.pending_execute(statement, params);
128+
let pending = PendingRequest(
129+
self.excecute_message(statement, params)
130+
.map(RequestMessages::Single),
131+
);
126132
ExecuteFuture::new(self.clone(), pending, statement.clone())
127133
}
128134

129135
pub fn query(&self, statement: &Statement, params: &[&ToSql]) -> QueryStream {
130-
let pending = self.pending_execute(statement, params);
136+
let pending = PendingRequest(
137+
self.excecute_message(statement, params)
138+
.map(RequestMessages::Single),
139+
);
131140
QueryStream::new(self.clone(), pending, statement.clone())
132141
}
133142

143+
pub fn copy_in<S>(&self, statement: &Statement, params: &[&ToSql], stream: S) -> CopyInFuture<S>
144+
where
145+
S: Stream<Item = Vec<u8>>,
146+
S::Error: Into<Box<StdError + Sync + Send>>,
147+
{
148+
let (mut sender, receiver) = mpsc::channel(0);
149+
let pending = PendingRequest(self.excecute_message(statement, params).map(|buf| {
150+
match sender.start_send(CopyMessage::Data(buf)) {
151+
Ok(AsyncSink::Ready) => {}
152+
_ => unreachable!("channel should have capacity"),
153+
}
154+
RequestMessages::CopyIn {
155+
receiver: CopyInReceiver::new(receiver),
156+
pending_message: None,
157+
}
158+
}));
159+
CopyInFuture::new(self.clone(), pending, statement.clone(), stream, sender)
160+
}
161+
134162
pub fn copy_out(&self, statement: &Statement, params: &[&ToSql]) -> CopyOutStream {
135-
let pending = self.pending_execute(statement, params);
163+
let pending = PendingRequest(
164+
self.excecute_message(statement, params)
165+
.map(RequestMessages::Single),
166+
);
136167
CopyOutStream::new(self.clone(), pending, statement.clone())
137168
}
138169

@@ -142,42 +173,41 @@ impl Client {
142173
frontend::sync(&mut buf);
143174
let (sender, _) = mpsc::channel(0);
144175
let _ = self.0.sender.unbounded_send(Request {
145-
messages: buf,
176+
messages: RequestMessages::Single(buf),
146177
sender,
147178
});
148179
}
149180

150-
fn pending_execute(&self, statement: &Statement, params: &[&ToSql]) -> PendingRequest {
151-
self.pending(|buf| {
152-
let r = frontend::bind(
153-
"",
154-
statement.name(),
155-
Some(1),
156-
params.iter().zip(statement.params()),
157-
|(param, ty), buf| match param.to_sql_checked(ty, buf) {
158-
Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
159-
Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
160-
Err(e) => Err(e),
161-
},
162-
Some(1),
163-
buf,
164-
);
165-
match r {
166-
Ok(()) => {}
167-
Err(frontend::BindError::Conversion(e)) => return Err(error::conversion(e)),
168-
Err(frontend::BindError::Serialization(e)) => return Err(Error::from(e)),
169-
}
170-
frontend::execute("", 0, buf)?;
171-
frontend::sync(buf);
172-
Ok(())
173-
})
181+
fn excecute_message(&self, statement: &Statement, params: &[&ToSql]) -> Result<Vec<u8>, Error> {
182+
let mut buf = vec![];
183+
let r = frontend::bind(
184+
"",
185+
statement.name(),
186+
Some(1),
187+
params.iter().zip(statement.params()),
188+
|(param, ty), buf| match param.to_sql_checked(ty, buf) {
189+
Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
190+
Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
191+
Err(e) => Err(e),
192+
},
193+
Some(1),
194+
&mut buf,
195+
);
196+
match r {
197+
Ok(()) => {}
198+
Err(frontend::BindError::Conversion(e)) => return Err(error::conversion(e)),
199+
Err(frontend::BindError::Serialization(e)) => return Err(Error::from(e)),
200+
}
201+
frontend::execute("", 0, &mut buf)?;
202+
frontend::sync(&mut buf);
203+
Ok(buf)
174204
}
175205

176206
fn pending<F>(&self, messages: F) -> PendingRequest
177207
where
178208
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
179209
{
180210
let mut buf = vec![];
181-
PendingRequest(messages(&mut buf).map(|()| buf))
211+
PendingRequest(messages(&mut buf).map(|()| RequestMessages::Single(buf)))
182212
}
183213
}

tokio-postgres/src/proto/connection.rs

+66-14
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,20 @@ use tokio_codec::Framed;
88

99
use error::{self, DbError, Error};
1010
use proto::codec::PostgresCodec;
11+
use proto::copy_in::CopyInReceiver;
1112
use tls::TlsStream;
1213
use {bad_response, disconnected, AsyncMessage, CancelData, Notification};
1314

15+
pub enum RequestMessages {
16+
Single(Vec<u8>),
17+
CopyIn {
18+
receiver: CopyInReceiver,
19+
pending_message: Option<Vec<u8>>,
20+
},
21+
}
22+
1423
pub struct Request {
15-
pub messages: Vec<u8>,
24+
pub messages: RequestMessages,
1625
pub sender: mpsc::Sender<Message>,
1726
}
1827

@@ -28,7 +37,7 @@ pub struct Connection {
2837
cancel_data: CancelData,
2938
parameters: HashMap<String, String>,
3039
receiver: mpsc::UnboundedReceiver<Request>,
31-
pending_request: Option<Vec<u8>>,
40+
pending_request: Option<RequestMessages>,
3241
pending_response: Option<Message>,
3342
responses: VecDeque<mpsc::Sender<Message>>,
3443
state: State,
@@ -140,7 +149,7 @@ impl Connection {
140149
}
141150
}
142151

143-
fn poll_request(&mut self) -> Poll<Option<Vec<u8>>, Error> {
152+
fn poll_request(&mut self) -> Poll<Option<RequestMessages>, Error> {
144153
if let Some(message) = self.pending_request.take() {
145154
trace!("retrying pending request");
146155
return Ok(Async::Ready(Some(message)));
@@ -170,7 +179,7 @@ impl Connection {
170179
self.state = State::Terminating;
171180
let mut request = vec![];
172181
frontend::terminate(&mut request);
173-
request
182+
RequestMessages::Single(request)
174183
}
175184
Async::Ready(None) => {
176185
trace!(
@@ -185,17 +194,60 @@ impl Connection {
185194
}
186195
};
187196

188-
match self.stream.start_send(request)? {
189-
AsyncSink::Ready => {
190-
if self.state == State::Terminating {
191-
trace!("poll_write: sent eof, closing");
192-
self.state = State::Closing;
197+
match request {
198+
RequestMessages::Single(request) => match self.stream.start_send(request)? {
199+
AsyncSink::Ready => {
200+
if self.state == State::Terminating {
201+
trace!("poll_write: sent eof, closing");
202+
self.state = State::Closing;
203+
}
193204
}
194-
}
195-
AsyncSink::NotReady(request) => {
196-
trace!("poll_write: waiting on socket");
197-
self.pending_request = Some(request);
198-
return Ok(false);
205+
AsyncSink::NotReady(request) => {
206+
trace!("poll_write: waiting on socket");
207+
self.pending_request = Some(RequestMessages::Single(request));
208+
return Ok(false);
209+
}
210+
},
211+
RequestMessages::CopyIn {
212+
mut receiver,
213+
mut pending_message,
214+
} => {
215+
let message = match pending_message.take() {
216+
Some(message) => message,
217+
None => match receiver.poll() {
218+
Ok(Async::Ready(Some(message))) => message,
219+
Ok(Async::Ready(None)) => {
220+
trace!("poll_write: finished copy_in request");
221+
continue;
222+
}
223+
Ok(Async::NotReady) => {
224+
trace!("poll_write: waiting on copy_in stream");
225+
self.pending_request = Some(RequestMessages::CopyIn {
226+
receiver,
227+
pending_message,
228+
});
229+
return Ok(true);
230+
}
231+
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
232+
},
233+
};
234+
235+
match self.stream.start_send(message)? {
236+
AsyncSink::Ready => {
237+
self.pending_request = Some(RequestMessages::CopyIn {
238+
receiver,
239+
pending_message: None,
240+
});
241+
}
242+
AsyncSink::NotReady(message) => {
243+
trace!("poll_write: waiting on socket");
244+
self.pending_request = Some(RequestMessages::CopyIn {
245+
receiver,
246+
pending_message: Some(message),
247+
});
248+
return Ok(false);
249+
}
250+
};
199251
}
200252
}
201253
}

0 commit comments

Comments
 (0)