Skip to content

Commit 9e399aa

Browse files
committed
Basic transaction support
1 parent bf06336 commit 9e399aa

File tree

4 files changed

+212
-0
lines changed

4 files changed

+212
-0
lines changed

tokio-postgres/src/lib.rs

+27
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,14 @@ impl Client {
9595
Query(self.0.query(&statement.0, params))
9696
}
9797

98+
pub fn transaction<T>(&mut self, future: T) -> Transaction<T>
99+
where
100+
T: Future,
101+
T::Error: From<Error>,
102+
{
103+
Transaction(proto::TransactionFuture::new(self.0.clone(), future))
104+
}
105+
98106
pub fn batch_execute(&mut self, query: &str) -> BatchExecute {
99107
BatchExecute(self.0.batch_execute(query))
100108
}
@@ -242,6 +250,25 @@ impl Row {
242250
}
243251
}
244252

253+
#[must_use = "futures do nothing unless polled"]
254+
pub struct Transaction<T>(proto::TransactionFuture<T, T::Item, T::Error>)
255+
where
256+
T: Future,
257+
T::Error: From<Error>;
258+
259+
impl<T> Future for Transaction<T>
260+
where
261+
T: Future,
262+
T::Error: From<Error>,
263+
{
264+
type Item = T::Item;
265+
type Error = T::Error;
266+
267+
fn poll(&mut self) -> Poll<T::Item, T::Error> {
268+
self.0.poll()
269+
}
270+
}
271+
245272
#[must_use = "futures do nothing unless polled"]
246273
pub struct BatchExecute(proto::SimpleQueryFuture);
247274

tokio-postgres/src/proto/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod row;
2121
mod simple_query;
2222
mod socket;
2323
mod statement;
24+
mod transaction;
2425
mod typeinfo;
2526
mod typeinfo_composite;
2627
mod typeinfo_enum;
@@ -37,3 +38,4 @@ pub use proto::row::Row;
3738
pub use proto::simple_query::SimpleQueryFuture;
3839
pub use proto::socket::Socket;
3940
pub use proto::statement::Statement;
41+
pub use proto::transaction::TransactionFuture;
+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use futures::{Async, Future, Poll};
2+
use proto::client::Client;
3+
use proto::simple_query::SimpleQueryFuture;
4+
use state_machine_future::RentToOwn;
5+
6+
use Error;
7+
8+
#[derive(StateMachineFuture)]
9+
pub enum Transaction<F, T, E>
10+
where
11+
F: Future<Item = T, Error = E>,
12+
E: From<Error>,
13+
{
14+
#[state_machine_future(start, transitions(Beginning))]
15+
Start { client: Client, future: F },
16+
#[state_machine_future(transitions(Running))]
17+
Beginning {
18+
client: Client,
19+
begin: SimpleQueryFuture,
20+
future: F,
21+
},
22+
#[state_machine_future(transitions(Finishing))]
23+
Running { client: Client, future: F },
24+
#[state_machine_future(transitions(Finished))]
25+
Finishing {
26+
future: SimpleQueryFuture,
27+
result: Result<T, E>,
28+
},
29+
#[state_machine_future(ready)]
30+
Finished(T),
31+
#[state_machine_future(error)]
32+
Failed(E),
33+
}
34+
35+
impl<F, T, E> PollTransaction<F, T, E> for Transaction<F, T, E>
36+
where
37+
F: Future<Item = T, Error = E>,
38+
E: From<Error>,
39+
{
40+
fn poll_start<'a>(
41+
state: &'a mut RentToOwn<'a, Start<F, T, E>>,
42+
) -> Poll<AfterStart<F, T, E>, E> {
43+
let state = state.take();
44+
transition!(Beginning {
45+
begin: state.client.batch_execute("BEGIN"),
46+
client: state.client,
47+
future: state.future,
48+
})
49+
}
50+
51+
fn poll_beginning<'a>(
52+
state: &'a mut RentToOwn<'a, Beginning<F, T, E>>,
53+
) -> Poll<AfterBeginning<F, T, E>, E> {
54+
try_ready!(state.begin.poll());
55+
let state = state.take();
56+
transition!(Running {
57+
client: state.client,
58+
future: state.future,
59+
})
60+
}
61+
62+
fn poll_running<'a>(
63+
state: &'a mut RentToOwn<'a, Running<F, T, E>>,
64+
) -> Poll<AfterRunning<T, E>, E> {
65+
match state.future.poll() {
66+
Ok(Async::NotReady) => return Ok(Async::NotReady),
67+
Ok(Async::Ready(t)) => transition!(Finishing {
68+
future: state.client.batch_execute("COMMIT"),
69+
result: Ok(t),
70+
}),
71+
Err(e) => transition!(Finishing {
72+
future: state.client.batch_execute("ROLLBACK"),
73+
result: Err(e),
74+
}),
75+
}
76+
}
77+
78+
fn poll_finishing<'a>(
79+
state: &'a mut RentToOwn<'a, Finishing<T, E>>,
80+
) -> Poll<AfterFinishing<T>, E> {
81+
match state.future.poll() {
82+
Ok(Async::NotReady) => return Ok(Async::NotReady),
83+
Ok(Async::Ready(())) => {
84+
let t = state.take().result?;
85+
transition!(Finished(t))
86+
}
87+
Err(e) => match state.take().result {
88+
Ok(_) => Err(e.into()),
89+
// prioritize the future's error over the rollback error
90+
Err(e) => Err(e),
91+
},
92+
}
93+
}
94+
}
95+
96+
impl<F, T, E> TransactionFuture<F, T, E>
97+
where
98+
F: Future<Item = T, Error = E>,
99+
E: From<Error>,
100+
{
101+
pub fn new(client: Client, future: F) -> TransactionFuture<F, T, E> {
102+
Transaction::start(client, future)
103+
}
104+
}

tokio-postgres/tests/test.rs

+79
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ extern crate log;
99

1010
use futures::future;
1111
use futures::sync::mpsc;
12+
use std::error::Error;
1213
use std::time::{Duration, Instant};
1314
use tokio::prelude::*;
1415
use tokio::runtime::current_thread::Runtime;
@@ -477,3 +478,81 @@ fn notifications() {
477478
assert_eq!(notifications[1].channel, "test_notifications");
478479
assert_eq!(notifications[1].payload, "world");
479480
}
481+
482+
#[test]
483+
fn test_transaction_commit() {
484+
let _ = env_logger::try_init();
485+
let mut runtime = Runtime::new().unwrap();
486+
487+
let (mut client, connection) = runtime
488+
.block_on(tokio_postgres::connect(
489+
"postgres://postgres@localhost:5433".parse().unwrap(),
490+
TlsMode::None,
491+
))
492+
.unwrap();
493+
let connection = connection.map_err(|e| panic!("{}", e));
494+
runtime.handle().spawn(connection).unwrap();
495+
496+
runtime
497+
.block_on(client.batch_execute(
498+
"CREATE TEMPORARY TABLE foo (
499+
id SERIAL,
500+
name TEXT
501+
)",
502+
))
503+
.unwrap();
504+
505+
let f = client.batch_execute("INSERT INTO foo (name) VALUES ('steven')");
506+
runtime.block_on(client.transaction(f)).unwrap();
507+
508+
let rows = runtime
509+
.block_on(
510+
client
511+
.prepare("SELECT name FROM foo")
512+
.and_then(|s| client.query(&s, &[]).collect()),
513+
)
514+
.unwrap();
515+
516+
assert_eq!(rows.len(), 1);
517+
assert_eq!(rows[0].get::<_, &str>(0), "steven");
518+
}
519+
520+
#[test]
521+
fn test_transaction_abort() {
522+
let _ = env_logger::try_init();
523+
let mut runtime = Runtime::new().unwrap();
524+
525+
let (mut client, connection) = runtime
526+
.block_on(tokio_postgres::connect(
527+
"postgres://postgres@localhost:5433".parse().unwrap(),
528+
TlsMode::None,
529+
))
530+
.unwrap();
531+
let connection = connection.map_err(|e| panic!("{}", e));
532+
runtime.handle().spawn(connection).unwrap();
533+
534+
runtime
535+
.block_on(client.batch_execute(
536+
"CREATE TEMPORARY TABLE foo (
537+
id SERIAL,
538+
name TEXT
539+
)",
540+
))
541+
.unwrap();
542+
543+
let f = client
544+
.batch_execute("INSERT INTO foo (name) VALUES ('steven')")
545+
.map_err(|e| Box::new(e) as Box<Error>)
546+
.and_then(|_| Err::<(), _>(Box::<Error>::from("")));
547+
runtime.block_on(client.transaction(f)).unwrap_err();
548+
549+
let rows = runtime
550+
.block_on(
551+
client
552+
.prepare("SELECT name FROM foo")
553+
.and_then(|s| client.query(&s, &[]).collect()),
554+
)
555+
.unwrap();
556+
557+
assert_eq!(rows.len(), 0);
558+
}

0 commit comments

Comments
 (0)