Skip to content

Commit f6189a9

Browse files
committed
Fix transaction not being rolled back on Client::transaction() Future dropped before completion
1 parent 0adcf58 commit f6189a9

File tree

2 files changed

+159
-4
lines changed

2 files changed

+159
-4
lines changed

tokio-postgres/src/client.rs

+38-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::codec::BackendMessages;
1+
use crate::codec::{BackendMessages, FrontendMessage};
22
use crate::config::{Host, SslMode};
33
use crate::connection::{Request, RequestMessages};
44
use crate::copy_out::CopyOutStream;
@@ -19,7 +19,7 @@ use fallible_iterator::FallibleIterator;
1919
use futures::channel::mpsc;
2020
use futures::{future, pin_mut, ready, StreamExt, TryStreamExt};
2121
use parking_lot::Mutex;
22-
use postgres_protocol::message::backend::Message;
22+
use postgres_protocol::message::{backend::Message, frontend};
2323
use postgres_types::BorrowToSql;
2424
use std::collections::HashMap;
2525
use std::fmt;
@@ -488,7 +488,42 @@ impl Client {
488488
///
489489
/// The transaction will roll back by default - use the `commit` method to commit it.
490490
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
491-
self.batch_execute("BEGIN").await?;
491+
struct RollbackIfNotDone<'me> {
492+
client: &'me Client,
493+
done: bool,
494+
}
495+
496+
impl<'a> Drop for RollbackIfNotDone<'a> {
497+
fn drop(&mut self) {
498+
if self.done {
499+
return;
500+
}
501+
502+
let buf = self.client.inner().with_buf(|buf| {
503+
frontend::query("ROLLBACK", buf).unwrap();
504+
buf.split().freeze()
505+
});
506+
let _ = self
507+
.client
508+
.inner()
509+
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
510+
}
511+
}
512+
513+
// This is done, as `Future` created by this method can be dropped after
514+
// `RequestMessages` is synchronously send to the `Connection` by
515+
// `batch_execute()`, but before `Responses` is asynchronously polled to
516+
// completion. In that case `Transaction` won't be created and thus
517+
// won't be rolled back.
518+
{
519+
let mut cleaner = RollbackIfNotDone {
520+
client: self,
521+
done: false,
522+
};
523+
self.batch_execute("BEGIN").await?;
524+
cleaner.done = true;
525+
}
526+
492527
Ok(Transaction::new(self))
493528
}
494529

tokio-postgres/tests/test/main.rs

+121-1
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
use bytes::{Bytes, BytesMut};
44
use futures::channel::mpsc;
55
use futures::{
6-
future, join, pin_mut, stream, try_join, FutureExt, SinkExt, StreamExt, TryStreamExt,
6+
future, join, pin_mut, stream, try_join, Future, FutureExt, SinkExt, StreamExt, TryStreamExt,
77
};
8+
use pin_project_lite::pin_project;
89
use std::fmt::Write;
10+
use std::pin::Pin;
11+
use std::task::{Context, Poll};
912
use std::time::Duration;
1013
use tokio::net::TcpStream;
1114
use tokio::time;
@@ -22,6 +25,35 @@ mod parse;
2225
mod runtime;
2326
mod types;
2427

28+
pin_project! {
29+
/// Polls `F` at most `polls_left` times returning `Some(F::Output)` if
30+
/// [`Future`] returned [`Poll::Ready`] or [`None`] otherwise.
31+
struct Cancellable<F> {
32+
#[pin]
33+
fut: F,
34+
polls_left: usize,
35+
}
36+
}
37+
38+
impl<F: Future> Future for Cancellable<F> {
39+
type Output = Option<F::Output>;
40+
41+
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
42+
let this = self.project();
43+
match this.fut.poll(ctx) {
44+
Poll::Ready(r) => Poll::Ready(Some(r)),
45+
Poll::Pending => {
46+
*this.polls_left = this.polls_left.saturating_sub(1);
47+
if *this.polls_left == 0 {
48+
Poll::Ready(None)
49+
} else {
50+
Poll::Pending
51+
}
52+
}
53+
}
54+
}
55+
}
56+
2557
async fn connect_raw(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
2658
let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap();
2759
let config = s.parse::<Config>().unwrap();
@@ -35,6 +67,20 @@ async fn connect(s: &str) -> Client {
3567
client
3668
}
3769

70+
async fn current_transaction_id(client: &Client) -> i64 {
71+
client
72+
.query("SELECT txid_current()", &[])
73+
.await
74+
.unwrap()
75+
.pop()
76+
.unwrap()
77+
.get::<_, i64>("txid_current")
78+
}
79+
80+
async fn in_transaction(client: &Client) -> bool {
81+
current_transaction_id(client).await == current_transaction_id(client).await
82+
}
83+
3884
#[tokio::test]
3985
async fn plain_password_missing() {
4086
connect_raw("user=pass_user dbname=postgres")
@@ -377,6 +423,80 @@ async fn transaction_rollback() {
377423
assert_eq!(rows.len(), 0);
378424
}
379425

426+
#[tokio::test]
427+
async fn transaction_future_cancellation() {
428+
let mut client = connect("user=postgres").await;
429+
430+
for i in 0.. {
431+
let done = {
432+
let txn = client.transaction();
433+
let fut = Cancellable {
434+
fut: txn,
435+
polls_left: i,
436+
};
437+
fut.await
438+
.map(|res| res.expect("transaction failed"))
439+
.is_some()
440+
};
441+
442+
assert!(!in_transaction(&client).await);
443+
444+
if done {
445+
break;
446+
}
447+
}
448+
}
449+
450+
#[tokio::test]
451+
async fn transaction_commit_future_cancellation() {
452+
let mut client = connect("user=postgres").await;
453+
454+
for i in 0.. {
455+
let done = {
456+
let txn = client.transaction().await.unwrap();
457+
let commit = txn.commit();
458+
let fut = Cancellable {
459+
fut: commit,
460+
polls_left: i,
461+
};
462+
fut.await
463+
.map(|res| res.expect("transaction failed"))
464+
.is_some()
465+
};
466+
467+
assert!(!in_transaction(&client).await);
468+
469+
if done {
470+
break;
471+
}
472+
}
473+
}
474+
475+
#[tokio::test]
476+
async fn transaction_rollback_future_cancellation() {
477+
let mut client = connect("user=postgres").await;
478+
479+
for i in 0.. {
480+
let done = {
481+
let txn = client.transaction().await.unwrap();
482+
let rollback = txn.rollback();
483+
let fut = Cancellable {
484+
fut: rollback,
485+
polls_left: i,
486+
};
487+
fut.await
488+
.map(|res| res.expect("transaction failed"))
489+
.is_some()
490+
};
491+
492+
assert!(!in_transaction(&client).await);
493+
494+
if done {
495+
break;
496+
}
497+
}
498+
}
499+
380500
#[tokio::test]
381501
async fn transaction_rollback_drop() {
382502
let mut client = connect("user=postgres").await;

0 commit comments

Comments
 (0)