3
3
use bytes:: { Bytes , BytesMut } ;
4
4
use futures:: channel:: mpsc;
5
5
use futures:: {
6
- future, join, pin_mut, stream, try_join, FutureExt , SinkExt , StreamExt , TryStreamExt ,
6
+ future, join, pin_mut, stream, try_join, Future , FutureExt , SinkExt , StreamExt , TryStreamExt ,
7
7
} ;
8
+ use pin_project_lite:: pin_project;
8
9
use std:: fmt:: Write ;
10
+ use std:: pin:: Pin ;
11
+ use std:: task:: { Context , Poll } ;
9
12
use std:: time:: Duration ;
10
13
use tokio:: net:: TcpStream ;
11
14
use tokio:: time;
@@ -22,6 +25,35 @@ mod parse;
22
25
mod runtime;
23
26
mod types;
24
27
28
+ pin_project ! {
29
+ /// Polls `F` at most `polls_left` times returning `Some(F::Output)` if
30
+ /// [`Future`] returned [`Poll::Ready`] or [`None`] otherwise.
31
+ struct Cancellable <F > {
32
+ #[ pin]
33
+ fut: F ,
34
+ polls_left: usize ,
35
+ }
36
+ }
37
+
38
+ impl < F : Future > Future for Cancellable < F > {
39
+ type Output = Option < F :: Output > ;
40
+
41
+ fn poll ( self : Pin < & mut Self > , ctx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
42
+ let this = self . project ( ) ;
43
+ match this. fut . poll ( ctx) {
44
+ Poll :: Ready ( r) => Poll :: Ready ( Some ( r) ) ,
45
+ Poll :: Pending => {
46
+ * this. polls_left = this. polls_left . saturating_sub ( 1 ) ;
47
+ if * this. polls_left == 0 {
48
+ Poll :: Ready ( None )
49
+ } else {
50
+ Poll :: Pending
51
+ }
52
+ }
53
+ }
54
+ }
55
+ }
56
+
25
57
async fn connect_raw ( s : & str ) -> Result < ( Client , Connection < TcpStream , NoTlsStream > ) , Error > {
26
58
let socket = TcpStream :: connect ( "127.0.0.1:5433" ) . await . unwrap ( ) ;
27
59
let config = s. parse :: < Config > ( ) . unwrap ( ) ;
@@ -35,6 +67,20 @@ async fn connect(s: &str) -> Client {
35
67
client
36
68
}
37
69
70
+ async fn current_transaction_id ( client : & Client ) -> i64 {
71
+ client
72
+ . query ( "SELECT txid_current()" , & [ ] )
73
+ . await
74
+ . unwrap ( )
75
+ . pop ( )
76
+ . unwrap ( )
77
+ . get :: < _ , i64 > ( "txid_current" )
78
+ }
79
+
80
+ async fn in_transaction ( client : & Client ) -> bool {
81
+ current_transaction_id ( client) . await == current_transaction_id ( client) . await
82
+ }
83
+
38
84
#[ tokio:: test]
39
85
async fn plain_password_missing ( ) {
40
86
connect_raw ( "user=pass_user dbname=postgres" )
@@ -377,6 +423,80 @@ async fn transaction_rollback() {
377
423
assert_eq ! ( rows. len( ) , 0 ) ;
378
424
}
379
425
426
+ #[ tokio:: test]
427
+ async fn transaction_future_cancellation ( ) {
428
+ let mut client = connect ( "user=postgres" ) . await ;
429
+
430
+ for i in 0 .. {
431
+ let done = {
432
+ let txn = client. transaction ( ) ;
433
+ let fut = Cancellable {
434
+ fut : txn,
435
+ polls_left : i,
436
+ } ;
437
+ fut. await
438
+ . map ( |res| res. expect ( "transaction failed" ) )
439
+ . is_some ( )
440
+ } ;
441
+
442
+ assert ! ( !in_transaction( & client) . await ) ;
443
+
444
+ if done {
445
+ break ;
446
+ }
447
+ }
448
+ }
449
+
450
+ #[ tokio:: test]
451
+ async fn transaction_commit_future_cancellation ( ) {
452
+ let mut client = connect ( "user=postgres" ) . await ;
453
+
454
+ for i in 0 .. {
455
+ let done = {
456
+ let txn = client. transaction ( ) . await . unwrap ( ) ;
457
+ let commit = txn. commit ( ) ;
458
+ let fut = Cancellable {
459
+ fut : commit,
460
+ polls_left : i,
461
+ } ;
462
+ fut. await
463
+ . map ( |res| res. expect ( "transaction failed" ) )
464
+ . is_some ( )
465
+ } ;
466
+
467
+ assert ! ( !in_transaction( & client) . await ) ;
468
+
469
+ if done {
470
+ break ;
471
+ }
472
+ }
473
+ }
474
+
475
+ #[ tokio:: test]
476
+ async fn transaction_rollback_future_cancellation ( ) {
477
+ let mut client = connect ( "user=postgres" ) . await ;
478
+
479
+ for i in 0 .. {
480
+ let done = {
481
+ let txn = client. transaction ( ) . await . unwrap ( ) ;
482
+ let rollback = txn. rollback ( ) ;
483
+ let fut = Cancellable {
484
+ fut : rollback,
485
+ polls_left : i,
486
+ } ;
487
+ fut. await
488
+ . map ( |res| res. expect ( "transaction failed" ) )
489
+ . is_some ( )
490
+ } ;
491
+
492
+ assert ! ( !in_transaction( & client) . await ) ;
493
+
494
+ if done {
495
+ break ;
496
+ }
497
+ }
498
+ }
499
+
380
500
#[ tokio:: test]
381
501
async fn transaction_rollback_drop ( ) {
382
502
let mut client = connect ( "user=postgres" ) . await ;
0 commit comments