Skip to content

Commit e5e03b0

Browse files
committed
Change the copy_in interface
Rather than taking in a Stream and advancing it internally, return a Sink that can be advanced by the calling code. This significantly simplifies encoding logic for things like tokio-postgres-binary-copy. Similarly, the blocking interface returns a Writer. Closes sfackler#489
1 parent a5428e6 commit e5e03b0

File tree

16 files changed

+367
-335
lines changed

16 files changed

+367
-335
lines changed

postgres/src/client.rs

+16-16
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
1+
use crate::iter::Iter;
2+
#[cfg(feature = "runtime")]
3+
use crate::Config;
4+
use crate::{CopyInWriter, CopyOutReader, Statement, ToStatement, Transaction};
15
use fallible_iterator::FallibleIterator;
26
use futures::executor;
3-
use std::io::{BufRead, Read};
47
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
58
use tokio_postgres::types::{ToSql, Type};
69
#[cfg(feature = "runtime")]
710
use tokio_postgres::Socket;
811
use tokio_postgres::{Error, Row, SimpleQueryMessage};
912

10-
use crate::copy_in_stream::CopyInStream;
11-
use crate::copy_out_reader::CopyOutReader;
12-
use crate::iter::Iter;
13-
#[cfg(feature = "runtime")]
14-
use crate::Config;
15-
use crate::{Statement, ToStatement, Transaction};
16-
1713
/// A synchronous PostgreSQL client.
1814
///
1915
/// This is a lightweight wrapper over the asynchronous tokio_postgres `Client`.
@@ -264,29 +260,33 @@ impl Client {
264260
/// The `query` argument can either be a `Statement`, or a raw query string. The data in the provided reader is
265261
/// passed along to the server verbatim; it is the caller's responsibility to ensure it uses the proper format.
266262
///
263+
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
264+
///
267265
/// # Examples
268266
///
269267
/// ```no_run
270268
/// use postgres::{Client, NoTls};
269+
/// use std::io::Write;
271270
///
272-
/// # fn main() -> Result<(), postgres::Error> {
271+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
273272
/// let mut client = Client::connect("host=localhost user=postgres", NoTls)?;
274273
///
275-
/// client.copy_in("COPY people FROM stdin", &[], &mut "1\tjohn\n2\tjane\n".as_bytes())?;
274+
/// let mut writer = client.copy_in("COPY people FROM stdin", &[])?;
275+
/// writer.write_all(b"1\tjohn\n2\tjane\n")?;
276+
/// writer.finish()?;
276277
/// # Ok(())
277278
/// # }
278279
/// ```
279-
pub fn copy_in<T, R>(
280+
pub fn copy_in<T>(
280281
&mut self,
281282
query: &T,
282283
params: &[&(dyn ToSql + Sync)],
283-
reader: R,
284-
) -> Result<u64, Error>
284+
) -> Result<CopyInWriter<'_>, Error>
285285
where
286286
T: ?Sized + ToStatement,
287-
R: Read + Unpin,
288287
{
289-
executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
288+
let sink = executor::block_on(self.0.copy_in(query, params))?;
289+
Ok(CopyInWriter::new(sink))
290290
}
291291

292292
/// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data.
@@ -312,7 +312,7 @@ impl Client {
312312
&mut self,
313313
query: &T,
314314
params: &[&(dyn ToSql + Sync)],
315-
) -> Result<impl BufRead, Error>
315+
) -> Result<CopyOutReader<'_>, Error>
316316
where
317317
T: ?Sized + ToStatement,
318318
{

postgres/src/copy_in_stream.rs

-24
This file was deleted.

postgres/src/copy_in_writer.rs

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
use bytes::{Bytes, BytesMut};
2+
use futures::{executor, SinkExt};
3+
use std::io;
4+
use std::io::Write;
5+
use std::marker::PhantomData;
6+
use std::pin::Pin;
7+
use tokio_postgres::{CopyInSink, Error};
8+
9+
/// The writer returned by the `copy_in` method.
10+
///
11+
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
12+
pub struct CopyInWriter<'a> {
13+
sink: Pin<Box<CopyInSink<Bytes>>>,
14+
buf: BytesMut,
15+
_p: PhantomData<&'a mut ()>,
16+
}
17+
18+
// no-op impl to extend borrow until drop
19+
impl Drop for CopyInWriter<'_> {
20+
fn drop(&mut self) {}
21+
}
22+
23+
impl<'a> CopyInWriter<'a> {
24+
pub(crate) fn new(sink: CopyInSink<Bytes>) -> CopyInWriter<'a> {
25+
CopyInWriter {
26+
sink: Box::pin(sink),
27+
buf: BytesMut::new(),
28+
_p: PhantomData,
29+
}
30+
}
31+
32+
/// Completes the copy, returning the number of rows written.
33+
///
34+
/// If this is not called, the copy will be aborted.
35+
pub fn finish(mut self) -> Result<u64, Error> {
36+
self.flush_inner()?;
37+
executor::block_on(self.sink.as_mut().finish())
38+
}
39+
40+
fn flush_inner(&mut self) -> Result<(), Error> {
41+
if self.buf.is_empty() {
42+
return Ok(());
43+
}
44+
45+
executor::block_on(self.sink.as_mut().send(self.buf.split().freeze()))
46+
}
47+
}
48+
49+
impl Write for CopyInWriter<'_> {
50+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
51+
if self.buf.len() > 4096 {
52+
self.flush()?;
53+
}
54+
55+
self.buf.extend_from_slice(buf);
56+
Ok(buf.len())
57+
}
58+
59+
fn flush(&mut self) -> io::Result<()> {
60+
self.flush_inner()
61+
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
62+
}
63+
}

postgres/src/copy_out_reader.rs

+9-24
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,24 @@
11
use bytes::{Buf, Bytes};
2-
use futures::{executor, Stream};
2+
use futures::executor;
33
use std::io::{self, BufRead, Cursor, Read};
44
use std::marker::PhantomData;
55
use std::pin::Pin;
6-
use tokio_postgres::Error;
6+
use tokio_postgres::{CopyStream, Error};
77

88
/// The reader returned by the `copy_out` method.
9-
pub struct CopyOutReader<'a, S>
10-
where
11-
S: Stream,
12-
{
13-
it: executor::BlockingStream<Pin<Box<S>>>,
9+
pub struct CopyOutReader<'a> {
10+
it: executor::BlockingStream<Pin<Box<CopyStream>>>,
1411
cur: Cursor<Bytes>,
1512
_p: PhantomData<&'a mut ()>,
1613
}
1714

1815
// no-op impl to extend borrow until drop
19-
impl<'a, S> Drop for CopyOutReader<'a, S>
20-
where
21-
S: Stream,
22-
{
16+
impl Drop for CopyOutReader<'_> {
2317
fn drop(&mut self) {}
2418
}
2519

26-
impl<'a, S> CopyOutReader<'a, S>
27-
where
28-
S: Stream<Item = Result<Bytes, Error>>,
29-
{
30-
pub(crate) fn new(stream: S) -> Result<CopyOutReader<'a, S>, Error> {
20+
impl<'a> CopyOutReader<'a> {
21+
pub(crate) fn new(stream: CopyStream) -> Result<CopyOutReader<'a>, Error> {
3122
let mut it = executor::block_on_stream(Box::pin(stream));
3223
let cur = match it.next() {
3324
Some(Ok(cur)) => cur,
@@ -43,10 +34,7 @@ where
4334
}
4435
}
4536

46-
impl<'a, S> Read for CopyOutReader<'a, S>
47-
where
48-
S: Stream<Item = Result<Bytes, Error>>,
49-
{
37+
impl Read for CopyOutReader<'_> {
5038
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
5139
let b = self.fill_buf()?;
5240
let len = usize::min(buf.len(), b.len());
@@ -56,10 +44,7 @@ where
5644
}
5745
}
5846

59-
impl<'a, S> BufRead for CopyOutReader<'a, S>
60-
where
61-
S: Stream<Item = Result<Bytes, Error>>,
62-
{
47+
impl BufRead for CopyOutReader<'_> {
6348
fn fill_buf(&mut self) -> io::Result<&[u8]> {
6449
if self.cur.remaining() == 0 {
6550
match self.it.next() {

postgres/src/lib.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ pub use tokio_postgres::{
6969
pub use crate::client::*;
7070
#[cfg(feature = "runtime")]
7171
pub use crate::config::Config;
72+
pub use crate::copy_in_writer::CopyInWriter;
73+
pub use crate::copy_out_reader::CopyOutReader;
7274
#[doc(no_inline)]
7375
pub use crate::error::Error;
7476
#[doc(no_inline)]
@@ -80,7 +82,7 @@ pub use crate::transaction::*;
8082
mod client;
8183
#[cfg(feature = "runtime")]
8284
pub mod config;
83-
mod copy_in_stream;
85+
mod copy_in_writer;
8486
mod copy_out_reader;
8587
mod iter;
8688
mod transaction;

postgres/src/test.rs

+23-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::io::Read;
1+
use std::io::{Read, Write};
22
use tokio_postgres::types::Type;
33
use tokio_postgres::NoTls;
44

@@ -154,13 +154,9 @@ fn copy_in() {
154154
.simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
155155
.unwrap();
156156

157-
client
158-
.copy_in(
159-
"COPY foo FROM stdin",
160-
&[],
161-
&mut &b"1\tsteven\n2\ttimothy"[..],
162-
)
163-
.unwrap();
157+
let mut writer = client.copy_in("COPY foo FROM stdin", &[]).unwrap();
158+
writer.write_all(b"1\tsteven\n2\ttimothy").unwrap();
159+
writer.finish().unwrap();
164160

165161
let rows = client
166162
.query("SELECT id, name FROM foo ORDER BY id", &[])
@@ -173,6 +169,25 @@ fn copy_in() {
173169
assert_eq!(rows[1].get::<_, &str>(1), "timothy");
174170
}
175171

172+
#[test]
173+
fn copy_in_abort() {
174+
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
175+
176+
client
177+
.simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
178+
.unwrap();
179+
180+
let mut writer = client.copy_in("COPY foo FROM stdin", &[]).unwrap();
181+
writer.write_all(b"1\tsteven\n2\ttimothy").unwrap();
182+
drop(writer);
183+
184+
let rows = client
185+
.query("SELECT id, name FROM foo ORDER BY id", &[])
186+
.unwrap();
187+
188+
assert_eq!(rows.len(), 0);
189+
}
190+
176191
#[test]
177192
fn copy_out() {
178193
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();

postgres/src/transaction.rs

+7-12
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1+
use crate::iter::Iter;
2+
use crate::{CopyInWriter, CopyOutReader, Portal, Statement, ToStatement};
13
use fallible_iterator::FallibleIterator;
24
use futures::executor;
3-
use std::io::{BufRead, Read};
45
use tokio_postgres::types::{ToSql, Type};
56
use tokio_postgres::{Error, Row, SimpleQueryMessage};
67

7-
use crate::copy_in_stream::CopyInStream;
8-
use crate::copy_out_reader::CopyOutReader;
9-
use crate::iter::Iter;
10-
use crate::{Portal, Statement, ToStatement};
11-
128
/// A representation of a PostgreSQL database transaction.
139
///
1410
/// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made
@@ -117,25 +113,24 @@ impl<'a> Transaction<'a> {
117113
}
118114

119115
/// Like `Client::copy_in`.
120-
pub fn copy_in<T, R>(
116+
pub fn copy_in<T>(
121117
&mut self,
122118
query: &T,
123119
params: &[&(dyn ToSql + Sync)],
124-
reader: R,
125-
) -> Result<u64, Error>
120+
) -> Result<CopyInWriter<'_>, Error>
126121
where
127122
T: ?Sized + ToStatement,
128-
R: Read + Unpin,
129123
{
130-
executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
124+
let sink = executor::block_on(self.0.copy_in(query, params))?;
125+
Ok(CopyInWriter::new(sink))
131126
}
132127

133128
/// Like `Client::copy_out`.
134129
pub fn copy_out<T>(
135130
&mut self,
136131
query: &T,
137132
params: &[&(dyn ToSql + Sync)],
138-
) -> Result<impl BufRead, Error>
133+
) -> Result<CopyOutReader<'_>, Error>
139134
where
140135
T: ?Sized + ToStatement,
141136
{

tokio-postgres-binary-copy/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ edition = "2018"
88
byteorder = "1.0"
99
bytes = "0.5"
1010
futures = "0.3"
11-
parking_lot = "0.10"
1211
pin-project-lite = "0.1"
1312
tokio-postgres = { version = "=0.5.0-alpha.2", default-features = false, path = "../tokio-postgres" }
1413

0 commit comments

Comments
 (0)