Skip to content

Commit cf3ead0

Browse files
Add savepoint method to Transaction
This change creates a `Transaction.savepoint` method, which is equivalent to `Transaction.transaction`, but takes a custom name for the nested transaction's savepoint name.
1 parent c0512e0 commit cf3ead0

File tree

2 files changed

+39
-18
lines changed

2 files changed

+39
-18
lines changed

src/transaction.rs

+23-11
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ impl Config {
151151
pub struct Transaction<'conn> {
152152
conn: &'conn Connection,
153153
depth: u32,
154+
savepoint_name: Option<&'conn str>,
154155
commit: Cell<bool>,
155156
finished: bool,
156157
}
@@ -177,6 +178,7 @@ impl<'conn> TransactionInternals<'conn> for Transaction<'conn> {
177178
Transaction {
178179
conn: conn,
179180
depth: depth,
181+
savepoint_name: None,
180182
commit: Cell::new(false),
181183
finished: false,
182184
}
@@ -195,14 +197,13 @@ impl<'conn> Transaction<'conn> {
195197
fn finish_inner(&mut self) -> Result<()> {
196198
let mut conn = self.conn.conn.borrow_mut();
197199
debug_assert!(self.depth == conn.trans_depth);
198-
let query = match (self.commit.get(), self.depth != 1) {
199-
(false, true) => "ROLLBACK TO sp",
200-
(false, false) => "ROLLBACK",
201-
(true, true) => "RELEASE sp",
202-
(true, false) => "COMMIT",
203-
};
204200
conn.trans_depth -= 1;
205-
conn.quick_query(query).map(|_| ())
201+
match (self.commit.get(), self.savepoint_name) {
202+
(false, Some(savepoint_name)) => conn.quick_query(&format!("ROLLBACK TO {}", savepoint_name)),
203+
(false, None) => conn.quick_query("ROLLBACK"),
204+
(true, Some(savepoint_name)) => conn.quick_query(&format!("RELEASE {}", savepoint_name)),
205+
(true, None) => conn.quick_query("COMMIT"),
206+
}.map(|_| ())
206207
}
207208

208209
/// Like `Connection::prepare`.
@@ -233,22 +234,33 @@ impl<'conn> Transaction<'conn> {
233234
self.conn.batch_execute(query)
234235
}
235236

236-
/// Like `Connection::transaction`.
237+
/// Like `Connection::transaction`, but creates a nested transaction.
237238
///
238239
/// # Panics
239240
///
240241
/// Panics if there is an active nested transaction.
241242
pub fn transaction<'a>(&'a self) -> Result<Transaction<'a>> {
243+
self.savepoint("sp")
244+
}
245+
246+
/// Like `Connection::transaction`, but creates a nested transaction
247+
/// with the provided name.
248+
///
249+
/// # Panics
250+
///
251+
/// Panics if there is an active nested transaction.
252+
pub fn savepoint<'a>(&'a self, name: &'a str) -> Result<Transaction<'a>> {
242253
let mut conn = self.conn.conn.borrow_mut();
243254
check_desync!(conn);
244255
assert!(conn.trans_depth == self.depth,
245-
"`transaction` may only be called on the active transaction");
246-
try!(conn.quick_query("SAVEPOINT sp"));
256+
"`savepoint` may only be called on the active transaction");
257+
try!(conn.quick_query(&format!("SAVEPOINT {}", name)));
247258
conn.trans_depth += 1;
248259
Ok(Transaction {
249260
conn: self.conn,
250-
commit: Cell::new(false),
251261
depth: self.depth + 1,
262+
savepoint_name: Some(name),
263+
commit: Cell::new(false),
252264
finished: false,
253265
})
254266
}

tests/test.rs

+16-7
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ fn test_nested_transactions() {
202202
}
203203

204204
{
205-
let trans3 = or_panic!(trans2.transaction());
206-
or_panic!(trans3.execute("INSERT INTO foo (id) VALUES (6)", &[]));
207-
assert!(trans3.commit().is_ok());
205+
let sp = or_panic!(trans2.savepoint("custom"));
206+
or_panic!(sp.execute("INSERT INTO foo (id) VALUES (6)", &[]));
207+
assert!(sp.commit().is_ok());
208208
}
209209

210210
assert!(trans2.commit().is_ok());
@@ -250,10 +250,10 @@ fn test_nested_transactions_finish() {
250250
}
251251

252252
{
253-
let trans3 = or_panic!(trans2.transaction());
254-
or_panic!(trans3.execute("INSERT INTO foo (id) VALUES (6)", &[]));
255-
trans3.set_commit();
256-
assert!(trans3.finish().is_ok());
253+
let sp = or_panic!(trans2.savepoint("custom"));
254+
or_panic!(sp.execute("INSERT INTO foo (id) VALUES (6)", &[]));
255+
sp.set_commit();
256+
assert!(sp.finish().is_ok());
257257
}
258258

259259
trans2.set_commit();
@@ -294,6 +294,15 @@ fn test_trans_with_nested_trans() {
294294
trans.transaction().unwrap();
295295
}
296296

297+
#[test]
298+
#[should_panic(expected = "active transaction")]
299+
fn test_trans_with_savepoints() {
300+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", SslMode::None));
301+
let trans = or_panic!(conn.transaction());
302+
let _sp = or_panic!(trans.savepoint("custom"));
303+
trans.savepoint("custom2").unwrap();
304+
}
305+
297306
#[test]
298307
fn test_stmt_execute_after_transaction() {
299308
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", SslMode::None));

0 commit comments

Comments
 (0)