Skip to content

Commit 2200286

Browse files
authored
Merge pull request sfackler#345 from sfackler/prepare-typed
Add Connection::prepare_typed
2 parents 198bf07 + 44222e5 commit 2200286

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed

postgres/src/lib.rs

+55-9
Original file line numberDiff line numberDiff line change
@@ -484,11 +484,22 @@ impl InnerConnection {
484484
mem::replace(&mut self.notice_handler, handler)
485485
}
486486

487-
fn raw_prepare(&mut self, stmt_name: &str, query: &str) -> Result<(Vec<Type>, Vec<Column>)> {
487+
fn raw_prepare(
488+
&mut self,
489+
stmt_name: &str,
490+
query: &str,
491+
types: &[Option<Type>],
492+
) -> Result<(Vec<Type>, Vec<Column>)> {
488493
debug!("preparing query with name `{}`: {}", stmt_name, query);
489494

490-
self.stream
491-
.write_message(|buf| frontend::parse(stmt_name, query, None, buf))?;
495+
self.stream.write_message(|buf| {
496+
frontend::parse(
497+
stmt_name,
498+
query,
499+
types.iter().map(|t| t.as_ref().map_or(0, |t| t.oid())),
500+
buf,
501+
)
502+
})?;
492503
self.stream
493504
.write_message(|buf| frontend::describe(b'S', stmt_name, buf))?;
494505
self.stream
@@ -657,9 +668,14 @@ impl InnerConnection {
657668
stmt_name
658669
}
659670

660-
fn prepare<'a>(&mut self, query: &str, conn: &'a Connection) -> Result<Statement<'a>> {
671+
fn prepare_typed<'a>(
672+
&mut self,
673+
query: &str,
674+
types: &[Option<Type>],
675+
conn: &'a Connection,
676+
) -> Result<Statement<'a>> {
661677
let stmt_name = self.make_stmt_name();
662-
let (param_types, columns) = self.raw_prepare(&stmt_name, query)?;
678+
let (param_types, columns) = self.raw_prepare(&stmt_name, query, types)?;
663679
let info = Arc::new(StatementInfo {
664680
name: stmt_name,
665681
param_types: param_types,
@@ -675,7 +691,7 @@ impl InnerConnection {
675691
Some(info) => info,
676692
None => {
677693
let stmt_name = self.make_stmt_name();
678-
let (param_types, columns) = self.raw_prepare(&stmt_name, query)?;
694+
let (param_types, columns) = self.raw_prepare(&stmt_name, query, &[])?;
679695
let info = Arc::new(StatementInfo {
680696
name: stmt_name,
681697
param_types: param_types,
@@ -734,6 +750,7 @@ impl InnerConnection {
734750
INNER JOIN pg_catalog.pg_namespace n ON \
735751
t.typnamespace = n.oid \
736752
WHERE t.oid = $1",
753+
&[],
737754
) {
738755
Ok(..) => {}
739756
// Range types weren't added until Postgres 9.2, so pg_range may not exist
@@ -746,6 +763,7 @@ impl InnerConnection {
746763
INNER JOIN pg_catalog.pg_namespace n \
747764
ON t.typnamespace = n.oid \
748765
WHERE t.oid = $1",
766+
&[],
749767
)?;
750768
}
751769
Err(e) => return Err(e),
@@ -811,6 +829,7 @@ impl InnerConnection {
811829
FROM pg_catalog.pg_enum \
812830
WHERE enumtypid = $1 \
813831
ORDER BY enumsortorder",
832+
&[],
814833
) {
815834
Ok(..) => {}
816835
// Postgres 9.0 doesn't have enumsortorder
@@ -821,6 +840,7 @@ impl InnerConnection {
821840
FROM pg_catalog.pg_enum \
822841
WHERE enumtypid = $1 \
823842
ORDER BY oid",
843+
&[],
824844
)?;
825845
}
826846
Err(e) => return Err(e),
@@ -858,6 +878,7 @@ impl InnerConnection {
858878
AND NOT attisdropped \
859879
AND attnum > 0 \
860880
ORDER BY attnum",
881+
&[],
861882
)?;
862883

863884
self.has_typeinfo_composite_query = true;
@@ -1055,7 +1076,7 @@ impl Connection {
10551076
/// println!("{} rows updated", rows_updated);
10561077
/// ```
10571078
pub fn execute(&self, query: &str, params: &[&ToSql]) -> Result<u64> {
1058-
let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query)?;
1079+
let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query, &[])?;
10591080
let info = Arc::new(StatementInfo {
10601081
name: String::new(),
10611082
param_types: param_types,
@@ -1091,7 +1112,7 @@ impl Connection {
10911112
/// }
10921113
/// ```
10931114
pub fn query(&self, query: &str, params: &[&ToSql]) -> Result<Rows> {
1094-
let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query)?;
1115+
let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query, &[])?;
10951116
let info = Arc::new(StatementInfo {
10961117
name: String::new(),
10971118
param_types: param_types,
@@ -1167,7 +1188,32 @@ impl Connection {
11671188
/// }
11681189
/// ```
11691190
pub fn prepare<'a>(&'a self, query: &str) -> Result<Statement<'a>> {
1170-
self.0.borrow_mut().prepare(query, self)
1191+
self.prepare_typed(query, &[])
1192+
}
1193+
1194+
/// Like `prepare`, but allows for the types of query parameters to be explicitly specified.
1195+
///
1196+
/// Postgres will normally infer the types of paramters, but this function offers more control
1197+
/// of that behavior. `None` will cause Postgres to infer the type. The list of types can be
1198+
/// shorter than the number of parameters in the query; it will act as if padded out with `None`
1199+
/// values.
1200+
///
1201+
/// # Example
1202+
///
1203+
/// ```rust,no_run
1204+
/// # use postgres::{Connection, TlsMode};
1205+
/// # use postgres::types::Type;
1206+
/// # let conn = Connection::connect("", TlsMode::None).unwrap();
1207+
/// // $1 would normally be assigned the type INT4, but we can override that to INT8
1208+
/// let stmt = conn.prepare_typed("SELECT $1::INT4", &[Some(Type::INT8)]).unwrap();
1209+
/// assert_eq!(stmt.param_types()[0], Type::INT8);
1210+
/// ```
1211+
pub fn prepare_typed<'a>(
1212+
&'a self,
1213+
query: &str,
1214+
types: &[Option<Type>],
1215+
) -> Result<Statement<'a>> {
1216+
self.0.borrow_mut().prepare_typed(query, types, self)
11711217
}
11721218

11731219
/// Creates a cached prepared statement.

postgres/tests/test.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1480,3 +1480,11 @@ fn keepalive() {
14801480

14811481
Connection::connect(params, TlsMode::None).unwrap();
14821482
}
1483+
1484+
#[test]
1485+
fn explicit_types() {
1486+
let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap();
1487+
let stmt = conn.prepare_typed("SELECT $1::INT4", &[Some(Type::INT8)])
1488+
.unwrap();
1489+
assert_eq!(stmt.param_types()[0], Type::INT8);
1490+
}

0 commit comments

Comments
 (0)