Skip to content

Commit f458847

Browse files
committed
Support copy_in
1 parent 4afd523 commit f458847

File tree

8 files changed

+365
-163
lines changed

8 files changed

+365
-163
lines changed

tokio-postgres/src/cancel_query.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
use crate::client::SocketConfig;
2-
use crate::config::{SslMode, Host};
2+
use crate::config::{Host, SslMode};
3+
use crate::tls::MakeTlsConnect;
34
use crate::{cancel_query_raw, connect_socket, connect_tls, Error, Socket};
45
use std::io;
5-
use crate::tls::MakeTlsConnect;
66

77
pub(crate) async fn cancel_query<T>(
88
config: Option<SocketConfig>,
99
ssl_mode: SslMode,
1010
mut tls: T,
1111
process_id: i32,
1212
secret_key: i32,
13-
) -> Result<(), Error> where T: MakeTlsConnect<Socket> {
13+
) -> Result<(), Error>
14+
where
15+
T: MakeTlsConnect<Socket>,
16+
{
1417
let config = match config {
1518
Some(config) => config,
1619
None => {
@@ -27,7 +30,9 @@ pub(crate) async fn cancel_query<T>(
2730
#[cfg(unix)]
2831
Host::Unix(_) => "",
2932
};
30-
let tls = tls.make_tls_connect(hostname).map_err(|e| Error::tls(e.into()))?;
33+
let tls = tls
34+
.make_tls_connect(hostname)
35+
.map_err(|e| Error::tls(e.into()))?;
3136

3237
let socket = connect_socket::connect_socket(
3338
&config.host,

tokio-postgres/src/client.rs

+28-2
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,19 @@ use crate::tls::TlsConnect;
77
use crate::types::{Oid, ToSql, Type};
88
#[cfg(feature = "runtime")]
99
use crate::Socket;
10-
use crate::{cancel_query, cancel_query_raw, query, Transaction};
10+
use crate::{cancel_query, cancel_query_raw, copy_in, query, Transaction};
1111
use crate::{prepare, SimpleQueryMessage};
1212
use crate::{simple_query, Row};
1313
use crate::{Error, Statement};
14+
use bytes::IntoBuf;
1415
use fallible_iterator::FallibleIterator;
1516
use futures::channel::mpsc;
16-
use futures::{future, Stream};
17+
use futures::{future, Stream, TryStream};
1718
use futures::{ready, StreamExt};
1819
use parking_lot::Mutex;
1920
use postgres_protocol::message::backend::Message;
2021
use std::collections::HashMap;
22+
use std::error;
2123
use std::future::Future;
2224
use std::sync::Arc;
2325
use std::task::{Context, Poll};
@@ -240,6 +242,30 @@ impl Client {
240242
query::execute(self.inner(), buf)
241243
}
242244

245+
/// Executes a `COPY FROM STDIN` statement, returning the number of rows created.
246+
///
247+
/// The data in the provided stream is passed along to the server verbatim; it is the caller's responsibility to
248+
/// ensure it uses the proper format.
249+
///
250+
/// # Panics
251+
///
252+
/// Panics if the number of parameters provided does not match the number expected.
253+
pub fn copy_in<S>(
254+
&mut self,
255+
statement: &Statement,
256+
params: &[&dyn ToSql],
257+
stream: S,
258+
) -> impl Future<Output = Result<u64, Error>>
259+
where
260+
S: TryStream,
261+
S::Ok: IntoBuf,
262+
<S::Ok as IntoBuf>::Buf: 'static + Send,
263+
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
264+
{
265+
let buf = query::encode(statement, params.iter().cloned());
266+
copy_in::copy_in(self.inner(), buf, stream)
267+
}
268+
243269
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
244270
///
245271
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that

tokio-postgres/src/connection.rs

+20
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2+
use crate::copy_in::CopyInReceiver;
23
use crate::error::DbError;
34
use crate::maybe_tls_stream::MaybeTlsStream;
45
use crate::{AsyncMessage, Error, Notification};
@@ -17,6 +18,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
1718

1819
pub enum RequestMessages {
1920
Single(FrontendMessage),
21+
CopyIn(CopyInReceiver),
2022
}
2123

2224
pub struct Request {
@@ -237,6 +239,24 @@ where
237239
self.state = State::Closing;
238240
}
239241
}
242+
RequestMessages::CopyIn(mut receiver) => {
243+
let message = match receiver.poll_next_unpin(cx) {
244+
Poll::Ready(Some(message)) => message,
245+
Poll::Ready(None) => {
246+
trace!("poll_write: finished copy_in request");
247+
continue;
248+
}
249+
Poll::Pending => {
250+
trace!("poll_write: waiting on copy_in stream");
251+
self.pending_request = Some(RequestMessages::CopyIn(receiver));
252+
return Ok(true);
253+
}
254+
};
255+
Pin::new(&mut self.stream)
256+
.start_send(message)
257+
.map_err(Error::io)?;
258+
self.pending_request = Some(RequestMessages::CopyIn(receiver));
259+
}
240260
}
241261
}
242262
}

tokio-postgres/src/copy_in.rs

+155
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
use crate::client::InnerClient;
2+
use crate::codec::FrontendMessage;
3+
use crate::connection::RequestMessages;
4+
use crate::Error;
5+
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
6+
use futures::channel::mpsc;
7+
use futures::ready;
8+
use futures::{SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
9+
use pin_utils::pin_mut;
10+
use postgres_protocol::message::backend::Message;
11+
use postgres_protocol::message::frontend;
12+
use std::error;
13+
use std::pin::Pin;
14+
use std::sync::Arc;
15+
use std::task::{Context, Poll};
16+
use postgres_protocol::message::frontend::CopyData;
17+
18+
enum CopyInMessage {
19+
Message(FrontendMessage),
20+
Done,
21+
}
22+
23+
pub struct CopyInReceiver {
24+
receiver: mpsc::Receiver<CopyInMessage>,
25+
done: bool,
26+
}
27+
28+
impl CopyInReceiver {
29+
fn new(receiver: mpsc::Receiver<CopyInMessage>) -> CopyInReceiver {
30+
CopyInReceiver {
31+
receiver,
32+
done: false,
33+
}
34+
}
35+
}
36+
37+
impl Stream for CopyInReceiver {
38+
type Item = FrontendMessage;
39+
40+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
41+
if self.done {
42+
return Poll::Ready(None);
43+
}
44+
45+
match ready!(self.receiver.poll_next_unpin(cx)) {
46+
Some(CopyInMessage::Message(message)) => Poll::Ready(Some(message)),
47+
Some(CopyInMessage::Done) => {
48+
self.done = true;
49+
let mut buf = vec![];
50+
frontend::copy_done(&mut buf);
51+
frontend::sync(&mut buf);
52+
Poll::Ready(Some(FrontendMessage::Raw(buf)))
53+
}
54+
None => {
55+
self.done = true;
56+
let mut buf = vec![];
57+
frontend::copy_fail("", &mut buf).unwrap();
58+
frontend::sync(&mut buf);
59+
Poll::Ready(Some(FrontendMessage::Raw(buf)))
60+
}
61+
}
62+
}
63+
}
64+
65+
pub async fn copy_in<S>(
66+
client: Arc<InnerClient>,
67+
buf: Result<Vec<u8>, Error>,
68+
stream: S,
69+
) -> Result<u64, Error>
70+
where
71+
S: TryStream,
72+
S::Ok: IntoBuf,
73+
<S::Ok as IntoBuf>::Buf: 'static + Send,
74+
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
75+
{
76+
let buf = buf?;
77+
78+
let (mut sender, receiver) = mpsc::channel(1);
79+
let receiver = CopyInReceiver::new(receiver);
80+
let mut responses = client.send(RequestMessages::CopyIn(receiver))?;
81+
82+
sender
83+
.send(CopyInMessage::Message(FrontendMessage::Raw(buf)))
84+
.await
85+
.map_err(|_| Error::closed())?;
86+
87+
match responses.next().await? {
88+
Message::BindComplete => {}
89+
_ => return Err(Error::unexpected_message()),
90+
}
91+
92+
match responses.next().await? {
93+
Message::CopyInResponse(_) => {}
94+
_ => return Err(Error::unexpected_message()),
95+
}
96+
97+
let mut bytes = BytesMut::new();
98+
let stream = stream.into_stream();
99+
pin_mut!(stream);
100+
101+
while let Some(buf) = stream.try_next().await.map_err(Error::copy_in_stream)? {
102+
let buf = buf.into_buf();
103+
104+
let data: Box<dyn Buf + Send> = if buf.remaining() > 4096 {
105+
if bytes.is_empty() {
106+
Box::new(buf)
107+
} else {
108+
Box::new(bytes.take().freeze().into_buf().chain(buf))
109+
}
110+
} else {
111+
bytes.reserve(buf.remaining());
112+
bytes.put(buf);
113+
if bytes.len() > 4096 {
114+
Box::new(bytes.take().freeze().into_buf())
115+
} else {
116+
continue;
117+
}
118+
};
119+
120+
let data = CopyData::new(data).map_err(Error::encode)?;
121+
sender
122+
.send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
123+
.await
124+
.map_err(|_| Error::closed())?;
125+
}
126+
127+
if !bytes.is_empty() {
128+
let data: Box<dyn Buf + Send> = Box::new(bytes.freeze().into_buf());
129+
let data = CopyData::new(data).map_err(Error::encode)?;
130+
sender
131+
.send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
132+
.await
133+
.map_err(|_| Error::closed())?;
134+
}
135+
136+
sender
137+
.send(CopyInMessage::Done)
138+
.await
139+
.map_err(|_| Error::closed())?;
140+
141+
match responses.next().await? {
142+
Message::CommandComplete(body) => {
143+
let rows = body
144+
.tag()
145+
.map_err(Error::parse)?
146+
.rsplit(' ')
147+
.next()
148+
.unwrap()
149+
.parse()
150+
.unwrap_or(0);
151+
Ok(rows)
152+
}
153+
_ => Err(Error::unexpected_message()),
154+
}
155+
}

tokio-postgres/src/lib.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
//! use tokio_postgres::{NoTls, Error, Row};
1010
//!
1111
//! # #[cfg(feature = "runtime")]
12-
//! #[tokio::main]
12+
//! #[tokio::main] // By default, tokio_postgres uses the tokio crate as its runtime.
1313
//! async fn main() -> Result<(), Error> {
1414
//! // Connect to the database.
15-
//! let (mut client, connection) = tokio_postgres::connect("host=localhost user=postgres", NoTls).await?;
15+
//! let (mut client, connection) =
16+
//! tokio_postgres::connect("host=localhost user=postgres", NoTls).await?;
1617
//!
1718
//! // The connection object performs the actual communication with the database,
1819
//! // so spawn it off to run on its own.
@@ -108,7 +109,6 @@
108109

109110
pub use crate::client::Client;
110111
pub use crate::config::Config;
111-
pub use crate::transaction::Transaction;
112112
pub use crate::connection::Connection;
113113
use crate::error::DbError;
114114
pub use crate::error::Error;
@@ -118,6 +118,7 @@ pub use crate::socket::Socket;
118118
#[cfg(feature = "runtime")]
119119
use crate::tls::MakeTlsConnect;
120120
pub use crate::tls::NoTls;
121+
pub use crate::transaction::Transaction;
121122
pub use statement::{Column, Statement};
122123

123124
#[cfg(feature = "runtime")]
@@ -133,6 +134,7 @@ mod connect_raw;
133134
mod connect_socket;
134135
mod connect_tls;
135136
mod connection;
137+
mod copy_in;
136138
pub mod error;
137139
mod maybe_tls_stream;
138140
mod prepare;
@@ -142,8 +144,8 @@ mod simple_query;
142144
#[cfg(feature = "runtime")]
143145
mod socket;
144146
mod statement;
145-
mod transaction;
146147
pub mod tls;
148+
mod transaction;
147149
pub mod types;
148150

149151
/// A convenience function which parses a connection string and connects to the database.

tokio-postgres/src/transaction.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ use crate::types::{ToSql, Type};
77
#[cfg(feature = "runtime")]
88
use crate::Socket;
99
use crate::{query, Client, Error, Row, SimpleQueryMessage, Statement};
10-
use futures::Stream;
10+
use bytes::IntoBuf;
11+
use futures::{Stream, TryStream};
1112
use postgres_protocol::message::frontend;
13+
use std::error;
1214
use std::future::Future;
1315
use tokio::io::{AsyncRead, AsyncWrite};
1416

@@ -120,6 +122,22 @@ impl<'a> Transaction<'a> {
120122
query::execute(self.client.inner(), buf)
121123
}
122124

125+
/// Like `Client::copy_in`.
126+
pub fn copy_in<S>(
127+
&mut self,
128+
statement: &Statement,
129+
params: &[&dyn ToSql],
130+
stream: S,
131+
) -> impl Future<Output = Result<u64, Error>>
132+
where
133+
S: TryStream,
134+
S::Ok: IntoBuf,
135+
<S::Ok as IntoBuf>::Buf: 'static + Send,
136+
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
137+
{
138+
self.client.copy_in(statement, params, stream)
139+
}
140+
123141
/// Like `Client::simple_query`.
124142
pub fn simple_query(
125143
&mut self,

0 commit comments

Comments
 (0)