Skip to content

Commit 5517719

Browse files
committed
Binary copy out support
1 parent a94127c commit 5517719

File tree

3 files changed

+222
-20
lines changed

3 files changed

+222
-20
lines changed

tokio-postgres-binary-copy/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,3 @@ tokio-postgres = { version = "=0.5.0-alpha.1", default-features = false, path =
1414
[dev-dependencies]
1515
tokio = "=0.2.0-alpha.6"
1616
tokio-postgres = { version = "=0.5.0-alpha.1", path = "../tokio-postgres" }
17-

tokio-postgres-binary-copy/src/lib.rs

+148-15
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,63 @@
1-
use bytes::{BigEndian, BufMut, ByteOrder, Bytes, BytesMut};
2-
use futures::{future, Stream};
1+
use bytes::{BigEndian, BufMut, ByteOrder, Bytes, BytesMut, Buf};
2+
use futures::{future, ready, Stream};
33
use parking_lot::Mutex;
44
use pin_project_lite::pin_project;
55
use std::convert::TryFrom;
66
use std::error::Error;
77
use std::future::Future;
8+
use std::ops::Range;
89
use std::pin::Pin;
910
use std::sync::Arc;
1011
use std::task::{Context, Poll};
11-
use tokio_postgres::types::{IsNull, ToSql, Type};
12+
use tokio_postgres::types::{IsNull, ToSql, Type, FromSql, WrongType};
13+
use tokio_postgres::CopyStream;
14+
use std::io::Cursor;
1215

1316
#[cfg(test)]
1417
mod test;
1518

1619
const BLOCK_SIZE: usize = 4096;
20+
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
21+
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;
1722

1823
pin_project! {
19-
pub struct BinaryCopyStream<F> {
24+
pub struct BinaryCopyInStream<F> {
2025
#[pin]
2126
future: F,
2227
buf: Arc<Mutex<BytesMut>>,
2328
done: bool,
2429
}
2530
}
2631

27-
impl<F> BinaryCopyStream<F>
32+
impl<F> BinaryCopyInStream<F>
2833
where
2934
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
3035
{
31-
pub fn new<M>(types: &[Type], write_values: M) -> BinaryCopyStream<F>
36+
pub fn new<M>(types: &[Type], write_values: M) -> BinaryCopyInStream<F>
3237
where
33-
M: FnOnce(BinaryCopyWriter) -> F,
38+
M: FnOnce(BinaryCopyInWriter) -> F,
3439
{
3540
let mut buf = BytesMut::new();
36-
buf.reserve(11 + 4 + 4);
37-
buf.put_slice(b"PGCOPY\n\xff\r\n\0"); // magic
41+
buf.reserve(HEADER_LEN);
42+
buf.put_slice(MAGIC); // magic
3843
buf.put_i32_be(0); // flags
3944
buf.put_i32_be(0); // header extension
4045

4146
let buf = Arc::new(Mutex::new(buf));
42-
let writer = BinaryCopyWriter {
47+
let writer = BinaryCopyInWriter {
4348
buf: buf.clone(),
4449
types: types.to_vec(),
4550
};
4651

47-
BinaryCopyStream {
52+
BinaryCopyInStream {
4853
future: write_values(writer),
4954
buf,
5055
done: false,
5156
}
5257
}
5358
}
5459

55-
impl<F> Stream for BinaryCopyStream<F>
60+
impl<F> Stream for BinaryCopyInStream<F>
5661
where
5762
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
5863
{
@@ -81,12 +86,12 @@ where
8186
}
8287

8388
// FIXME this should really just take a reference to the buffer, but that requires HKT :(
84-
pub struct BinaryCopyWriter {
89+
pub struct BinaryCopyInWriter {
8590
buf: Arc<Mutex<BytesMut>>,
8691
types: Vec<Type>,
8792
}
8893

89-
impl BinaryCopyWriter {
94+
impl BinaryCopyInWriter {
9095
pub async fn write(
9196
&mut self,
9297
values: &[&(dyn ToSql + Send)],
@@ -119,7 +124,7 @@ impl BinaryCopyWriter {
119124
let mut buf = self.buf.lock();
120125

121126
buf.reserve(2);
122-
buf.put_i16_be(self.types.len() as i16);
127+
buf.put_u16_be(self.types.len() as u16);
123128

124129
for (value, type_) in values.zip(&self.types) {
125130
let idx = buf.len();
@@ -135,3 +140,131 @@ impl BinaryCopyWriter {
135140
Ok(())
136141
}
137142
}
143+
144+
/// Information retained from the binary copy header once it has been parsed.
struct Header {
    // Set when bit 16 of the header flags word is set; the stream then bumps
    // the per-tuple field count by one before comparing it to the column types.
    has_oids: bool,
}
147+
148+
pin_project! {
    /// A stream that decodes the chunks of a binary-format copy-out into rows.
    pub struct BinaryCopyOutStream {
        // Underlying stream of raw copy-out chunks; pinned because it is polled in place.
        #[pin]
        stream: CopyStream,
        // Column types, shared with every row that the stream yields.
        types: Arc<Vec<Type>>,
        // `None` until the header has been parsed from the first chunk.
        header: Option<Header>,
    }
}
156+
157+
impl BinaryCopyOutStream {
158+
pub fn new(types: &[Type], stream: CopyStream) -> BinaryCopyOutStream {
159+
BinaryCopyOutStream {
160+
stream,
161+
types: Arc::new(types.to_vec()),
162+
header: None,
163+
}
164+
}
165+
}
166+
167+
impl Stream for BinaryCopyOutStream {
168+
type Item = Result<BinaryCopyOutRow, Box<dyn Error + Sync + Send>>;
169+
170+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
171+
let this = self.project();
172+
173+
let chunk = match ready!(this.stream.poll_next(cx)) {
174+
Some(Ok(chunk)) => chunk,
175+
Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
176+
None => return Poll::Ready(Some(Err("unexpected EOF".into()))),
177+
};
178+
let mut chunk= Cursor::new(chunk);
179+
180+
let has_oids = match &this.header {
181+
Some(header) => header.has_oids,
182+
None => {
183+
check_remaining(&chunk, HEADER_LEN)?;
184+
if &chunk.bytes()[..MAGIC.len()] != MAGIC {
185+
return Poll::Ready(Some(Err("invalid magic value".into())));
186+
}
187+
chunk.advance(MAGIC.len());
188+
189+
let flags = chunk.get_i32_be();
190+
let has_oids = (flags & (1 << 16)) != 0;
191+
192+
let header_extension = chunk.get_u32_be() as usize;
193+
check_remaining(&chunk, header_extension)?;
194+
chunk.advance(header_extension);
195+
196+
*this.header = Some(Header { has_oids });
197+
has_oids
198+
}
199+
};
200+
201+
check_remaining(&chunk, 2)?;
202+
let mut len = chunk.get_i16_be();
203+
if len == -1 {
204+
return Poll::Ready(None);
205+
}
206+
207+
if has_oids {
208+
len += 1;
209+
}
210+
if len as usize != this.types.len() {
211+
return Poll::Ready(Some(Err("unexpected tuple size".into())));
212+
}
213+
214+
let mut ranges = vec![];
215+
for _ in 0..len {
216+
check_remaining(&chunk, 4)?;
217+
let len = chunk.get_i32_be();
218+
if len == -1 {
219+
ranges.push(None);
220+
} else {
221+
let len = len as usize;
222+
check_remaining(&chunk, len)?;
223+
let start = chunk.position() as usize;
224+
ranges.push(Some(start..start + len));
225+
chunk.advance(len);
226+
}
227+
}
228+
229+
Poll::Ready(Some(Ok(BinaryCopyOutRow {
230+
buf: chunk.into_inner(),
231+
ranges,
232+
types: this.types.clone(),
233+
})))
234+
}
235+
}
236+
237+
fn check_remaining(buf: &impl Buf, len: usize) -> Result<(), Box<dyn Error + Sync + Send>> {
238+
if buf.remaining() < len {
239+
Err("unexpected EOF".into())
240+
} else {
241+
Ok(())
242+
}
243+
}
244+
245+
pub struct BinaryCopyOutRow {
246+
buf: Bytes,
247+
ranges: Vec<Option<Range<usize>>>,
248+
types: Arc<Vec<Type>>,
249+
}
250+
251+
impl BinaryCopyOutRow {
252+
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Box<dyn Error + Sync + Send>> where T: FromSql<'a> {
253+
let type_ = &self.types[idx];
254+
if !T::accepts(type_) {
255+
return Err(WrongType::new::<T>(type_.clone()).into());
256+
}
257+
258+
match &self.ranges[idx] {
259+
Some(range) => T::from_sql(type_, &self.buf[range.clone()]).map_err(Into::into),
260+
None => T::from_sql_null(type_).map_err(Into::into)
261+
}
262+
}
263+
264+
pub fn get<'a, T>(&'a self, idx: usize) -> T where T: FromSql<'a> {
265+
match self.try_get(idx) {
266+
Ok(value) => value,
267+
Err(e) => panic!("error retrieving column {}: {}", idx, e),
268+
}
269+
}
270+
}

tokio-postgres-binary-copy/src/test.rs

+74-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
use crate::BinaryCopyStream;
1+
use crate::{BinaryCopyInStream, BinaryCopyOutStream};
22
use tokio_postgres::types::Type;
33
use tokio_postgres::{Client, NoTls};
4+
use futures::TryStreamExt;
45

56
async fn connect() -> Client {
67
let (client, connection) =
@@ -22,7 +23,7 @@ async fn write_basic() {
2223
.await
2324
.unwrap();
2425

25-
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| {
26+
let stream = BinaryCopyInStream::new(&[Type::INT4, Type::TEXT], |mut w| {
2627
async move {
2728
w.write(&[&1i32, &"foobar"]).await?;
2829
w.write(&[&2i32, &None::<&str>]).await?;
@@ -56,7 +57,7 @@ async fn write_many_rows() {
5657
.await
5758
.unwrap();
5859

59-
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| {
60+
let stream = BinaryCopyInStream::new(&[Type::INT4, Type::TEXT], |mut w| {
6061
async move {
6162
for i in 0..10_000i32 {
6263
w.write(&[&i, &format!("the value for {}", i)]).await?;
@@ -90,7 +91,7 @@ async fn write_big_rows() {
9091
.await
9192
.unwrap();
9293

93-
let stream = BinaryCopyStream::new(&[Type::INT4, Type::BYTEA], |mut w| {
94+
let stream = BinaryCopyInStream::new(&[Type::INT4, Type::BYTEA], |mut w| {
9495
async move {
9596
for i in 0..2i32 {
9697
w.write(&[&i, &vec![i as u8; 128 * 1024]]).await?;
@@ -114,3 +115,72 @@ async fn write_big_rows() {
114115
assert_eq!(row.get::<_, &[u8]>(1), &*vec![i as u8; 128 * 1024]);
115116
}
116117
}
118+
119+
#[tokio::test]
120+
async fn read_basic() {
121+
let client = connect().await;
122+
123+
client
124+
.batch_execute(
125+
"
126+
CREATE TEMPORARY TABLE foo (id INT, bar TEXT);
127+
INSERT INTO foo (id, bar) VALUES (1, 'foobar'), (2, NULL);
128+
"
129+
)
130+
.await
131+
.unwrap();
132+
133+
let stream = client.copy_out("COPY foo (id, bar) TO STDIN BINARY", &[]).await.unwrap();
134+
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream).try_collect::<Vec<_>>().await.unwrap();
135+
assert_eq!(rows.len(), 2);
136+
137+
assert_eq!(rows[0].get::<i32>(0), 1);
138+
assert_eq!(rows[0].get::<Option<&str>>(1), Some("foobar"));
139+
assert_eq!(rows[1].get::<i32>(0), 2);
140+
assert_eq!(rows[1].get::<Option<&str>>(1), None);
141+
}
142+
143+
#[tokio::test]
144+
async fn read_many_rows() {
145+
let client = connect().await;
146+
147+
client
148+
.batch_execute(
149+
"
150+
CREATE TEMPORARY TABLE foo (id INT, bar TEXT);
151+
INSERT INTO foo (id, bar) SELECT i, 'the value for ' || i FROM generate_series(0, 9999) i;"
152+
)
153+
.await
154+
.unwrap();
155+
156+
let stream = client.copy_out("COPY foo (id, bar) TO STDIN BINARY", &[]).await.unwrap();
157+
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream).try_collect::<Vec<_>>().await.unwrap();
158+
assert_eq!(rows.len(), 10_000);
159+
160+
for (i, row) in rows.iter().enumerate() {
161+
assert_eq!(row.get::<i32>(0), i as i32);
162+
assert_eq!(row.get::<&str>(1), format!("the value for {}", i));
163+
}
164+
}
165+
166+
#[tokio::test]
167+
async fn read_big_rows() {
168+
let client = connect().await;
169+
170+
client
171+
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
172+
.await
173+
.unwrap();
174+
for i in 0..2i32 {
175+
client.execute("INSERT INTO foo (id, bar) VALUES ($1, $2)", &[&i, &vec![i as u8; 128 * 1024]]).await.unwrap();
176+
}
177+
178+
let stream = client.copy_out("COPY foo (id, bar) TO STDIN BINARY", &[]).await.unwrap();
179+
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::BYTEA], stream).try_collect::<Vec<_>>().await.unwrap();
180+
assert_eq!(rows.len(), 2);
181+
182+
for (i, row) in rows.iter().enumerate() {
183+
assert_eq!(row.get::<i32>(0), i as i32);
184+
assert_eq!(row.get::<&[u8]>(1), &vec![i as u8; 128 * 1024][..]);
185+
}
186+
}

0 commit comments

Comments
 (0)