1
- use bytes:: { BigEndian , BufMut , ByteOrder , Bytes , BytesMut } ;
2
- use futures:: { future, Stream } ;
1
+ use bytes:: { BigEndian , BufMut , ByteOrder , Bytes , BytesMut , Buf } ;
2
+ use futures:: { future, ready , Stream } ;
3
3
use parking_lot:: Mutex ;
4
4
use pin_project_lite:: pin_project;
5
5
use std:: convert:: TryFrom ;
6
6
use std:: error:: Error ;
7
7
use std:: future:: Future ;
8
+ use std:: ops:: Range ;
8
9
use std:: pin:: Pin ;
9
10
use std:: sync:: Arc ;
10
11
use std:: task:: { Context , Poll } ;
11
- use tokio_postgres:: types:: { IsNull , ToSql , Type } ;
12
+ use tokio_postgres:: types:: { IsNull , ToSql , Type , FromSql , WrongType } ;
13
+ use tokio_postgres:: CopyStream ;
14
+ use std:: io:: Cursor ;
12
15
13
16
#[ cfg( test) ]
14
17
mod test;
15
18
16
19
const BLOCK_SIZE : usize = 4096 ;
20
+ const MAGIC : & [ u8 ] = b"PGCOPY\n \xff \r \n \0 " ;
21
+ const HEADER_LEN : usize = MAGIC . len ( ) + 4 + 4 ;
17
22
18
23
pin_project ! {
19
- pub struct BinaryCopyStream <F > {
24
+ pub struct BinaryCopyInStream <F > {
20
25
#[ pin]
21
26
future: F ,
22
27
buf: Arc <Mutex <BytesMut >>,
23
28
done: bool ,
24
29
}
25
30
}
26
31
27
- impl < F > BinaryCopyStream < F >
32
+ impl < F > BinaryCopyInStream < F >
28
33
where
29
34
F : Future < Output = Result < ( ) , Box < dyn Error + Sync + Send > > > ,
30
35
{
31
- pub fn new < M > ( types : & [ Type ] , write_values : M ) -> BinaryCopyStream < F >
36
+ pub fn new < M > ( types : & [ Type ] , write_values : M ) -> BinaryCopyInStream < F >
32
37
where
33
- M : FnOnce ( BinaryCopyWriter ) -> F ,
38
+ M : FnOnce ( BinaryCopyInWriter ) -> F ,
34
39
{
35
40
let mut buf = BytesMut :: new ( ) ;
36
- buf. reserve ( 11 + 4 + 4 ) ;
37
- buf. put_slice ( b"PGCOPY \n \xff \r \n \0 " ) ; // magic
41
+ buf. reserve ( HEADER_LEN ) ;
42
+ buf. put_slice ( MAGIC ) ; // magic
38
43
buf. put_i32_be ( 0 ) ; // flags
39
44
buf. put_i32_be ( 0 ) ; // header extension
40
45
41
46
let buf = Arc :: new ( Mutex :: new ( buf) ) ;
42
- let writer = BinaryCopyWriter {
47
+ let writer = BinaryCopyInWriter {
43
48
buf : buf. clone ( ) ,
44
49
types : types. to_vec ( ) ,
45
50
} ;
46
51
47
- BinaryCopyStream {
52
+ BinaryCopyInStream {
48
53
future : write_values ( writer) ,
49
54
buf,
50
55
done : false ,
51
56
}
52
57
}
53
58
}
54
59
55
- impl < F > Stream for BinaryCopyStream < F >
60
+ impl < F > Stream for BinaryCopyInStream < F >
56
61
where
57
62
F : Future < Output = Result < ( ) , Box < dyn Error + Sync + Send > > > ,
58
63
{
@@ -81,12 +86,12 @@ where
81
86
}
82
87
83
88
// FIXME this should really just take a reference to the buffer, but that requires HKT :(
84
- pub struct BinaryCopyWriter {
89
+ pub struct BinaryCopyInWriter {
85
90
buf : Arc < Mutex < BytesMut > > ,
86
91
types : Vec < Type > ,
87
92
}
88
93
89
- impl BinaryCopyWriter {
94
+ impl BinaryCopyInWriter {
90
95
pub async fn write (
91
96
& mut self ,
92
97
values : & [ & ( dyn ToSql + Send ) ] ,
@@ -119,7 +124,7 @@ impl BinaryCopyWriter {
119
124
let mut buf = self . buf . lock ( ) ;
120
125
121
126
buf. reserve ( 2 ) ;
122
- buf. put_i16_be ( self . types . len ( ) as i16 ) ;
127
+ buf. put_u16_be ( self . types . len ( ) as u16 ) ;
123
128
124
129
for ( value, type_) in values. zip ( & self . types ) {
125
130
let idx = buf. len ( ) ;
@@ -135,3 +140,131 @@ impl BinaryCopyWriter {
135
140
Ok ( ( ) )
136
141
}
137
142
}
143
+
144
+ struct Header {
145
+ has_oids : bool ,
146
+ }
147
+
148
+ pin_project ! {
149
+ pub struct BinaryCopyOutStream {
150
+ #[ pin]
151
+ stream: CopyStream ,
152
+ types: Arc <Vec <Type >>,
153
+ header: Option <Header >,
154
+ }
155
+ }
156
+
157
+ impl BinaryCopyOutStream {
158
+ pub fn new ( types : & [ Type ] , stream : CopyStream ) -> BinaryCopyOutStream {
159
+ BinaryCopyOutStream {
160
+ stream,
161
+ types : Arc :: new ( types. to_vec ( ) ) ,
162
+ header : None ,
163
+ }
164
+ }
165
+ }
166
+
167
+ impl Stream for BinaryCopyOutStream {
168
+ type Item = Result < BinaryCopyOutRow , Box < dyn Error + Sync + Send > > ;
169
+
170
+ fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
171
+ let this = self . project ( ) ;
172
+
173
+ let chunk = match ready ! ( this. stream. poll_next( cx) ) {
174
+ Some ( Ok ( chunk) ) => chunk,
175
+ Some ( Err ( e) ) => return Poll :: Ready ( Some ( Err ( e. into ( ) ) ) ) ,
176
+ None => return Poll :: Ready ( Some ( Err ( "unexpected EOF" . into ( ) ) ) ) ,
177
+ } ;
178
+ let mut chunk= Cursor :: new ( chunk) ;
179
+
180
+ let has_oids = match & this. header {
181
+ Some ( header) => header. has_oids ,
182
+ None => {
183
+ check_remaining ( & chunk, HEADER_LEN ) ?;
184
+ if & chunk. bytes ( ) [ ..MAGIC . len ( ) ] != MAGIC {
185
+ return Poll :: Ready ( Some ( Err ( "invalid magic value" . into ( ) ) ) ) ;
186
+ }
187
+ chunk. advance ( MAGIC . len ( ) ) ;
188
+
189
+ let flags = chunk. get_i32_be ( ) ;
190
+ let has_oids = ( flags & ( 1 << 16 ) ) != 0 ;
191
+
192
+ let header_extension = chunk. get_u32_be ( ) as usize ;
193
+ check_remaining ( & chunk, header_extension) ?;
194
+ chunk. advance ( header_extension) ;
195
+
196
+ * this. header = Some ( Header { has_oids } ) ;
197
+ has_oids
198
+ }
199
+ } ;
200
+
201
+ check_remaining ( & chunk, 2 ) ?;
202
+ let mut len = chunk. get_i16_be ( ) ;
203
+ if len == -1 {
204
+ return Poll :: Ready ( None ) ;
205
+ }
206
+
207
+ if has_oids {
208
+ len += 1 ;
209
+ }
210
+ if len as usize != this. types . len ( ) {
211
+ return Poll :: Ready ( Some ( Err ( "unexpected tuple size" . into ( ) ) ) ) ;
212
+ }
213
+
214
+ let mut ranges = vec ! [ ] ;
215
+ for _ in 0 ..len {
216
+ check_remaining ( & chunk, 4 ) ?;
217
+ let len = chunk. get_i32_be ( ) ;
218
+ if len == -1 {
219
+ ranges. push ( None ) ;
220
+ } else {
221
+ let len = len as usize ;
222
+ check_remaining ( & chunk, len) ?;
223
+ let start = chunk. position ( ) as usize ;
224
+ ranges. push ( Some ( start..start + len) ) ;
225
+ chunk. advance ( len) ;
226
+ }
227
+ }
228
+
229
+ Poll :: Ready ( Some ( Ok ( BinaryCopyOutRow {
230
+ buf : chunk. into_inner ( ) ,
231
+ ranges,
232
+ types : this. types . clone ( ) ,
233
+ } ) ) )
234
+ }
235
+ }
236
+
237
+ fn check_remaining ( buf : & impl Buf , len : usize ) -> Result < ( ) , Box < dyn Error + Sync + Send > > {
238
+ if buf. remaining ( ) < len {
239
+ Err ( "unexpected EOF" . into ( ) )
240
+ } else {
241
+ Ok ( ( ) )
242
+ }
243
+ }
244
+
245
+ pub struct BinaryCopyOutRow {
246
+ buf : Bytes ,
247
+ ranges : Vec < Option < Range < usize > > > ,
248
+ types : Arc < Vec < Type > > ,
249
+ }
250
+
251
+ impl BinaryCopyOutRow {
252
+ pub fn try_get < ' a , T > ( & ' a self , idx : usize ) -> Result < T , Box < dyn Error + Sync + Send > > where T : FromSql < ' a > {
253
+ let type_ = & self . types [ idx] ;
254
+ if !T :: accepts ( type_) {
255
+ return Err ( WrongType :: new :: < T > ( type_. clone ( ) ) . into ( ) ) ;
256
+ }
257
+
258
+ match & self . ranges [ idx] {
259
+ Some ( range) => T :: from_sql ( type_, & self . buf [ range. clone ( ) ] ) . map_err ( Into :: into) ,
260
+ None => T :: from_sql_null ( type_) . map_err ( Into :: into)
261
+ }
262
+ }
263
+
264
+ pub fn get < ' a , T > ( & ' a self , idx : usize ) -> T where T : FromSql < ' a > {
265
+ match self . try_get ( idx) {
266
+ Ok ( value) => value,
267
+ Err ( e) => panic ! ( "error retrieving column {}: {}" , idx, e) ,
268
+ }
269
+ }
270
+ }
0 commit comments