Skip to content

Commit 635e638

Browse files
committed
A less stringy builder
This allows us to support things like non-utf8 passwords and unix socket directories.
1 parent e80e1fc commit 635e638

File tree

8 files changed

+278
-201
lines changed

8 files changed

+278
-201
lines changed

postgres/src/builder.rs

+34
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use futures::sync::oneshot;
22
use futures::Future;
33
use log::error;
4+
use std::path::Path;
45
use std::str::FromStr;
6+
use std::time::Duration;
57
use tokio_postgres::{Error, MakeTlsMode, Socket, TlsMode};
68

79
use crate::{Client, RUNTIME};
@@ -19,11 +21,43 @@ impl Builder {
1921
Builder(tokio_postgres::Builder::new())
2022
}
2123

24+
pub fn host(&mut self, host: &str) -> &mut Builder {
25+
self.0.host(host);
26+
self
27+
}
28+
29+
#[cfg(unix)]
30+
pub fn host_path<T>(&mut self, host: T) -> &mut Builder
31+
where
32+
T: AsRef<Path>,
33+
{
34+
self.0.host_path(host);
35+
self
36+
}
37+
38+
pub fn port(&mut self, port: u16) -> &mut Builder {
39+
self.0.port(port);
40+
self
41+
}
42+
2243
pub fn param(&mut self, key: &str, value: &str) -> &mut Builder {
2344
self.0.param(key, value);
2445
self
2546
}
2647

48+
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Builder {
49+
self.0.connect_timeout(connect_timeout);
50+
self
51+
}
52+
53+
pub fn password<T>(&mut self, password: T) -> &mut Builder
54+
where
55+
T: AsRef<[u8]>,
56+
{
57+
self.0.password(password);
58+
self
59+
}
60+
2761
pub fn connect<T>(&self, tls_mode: T) -> Result<Client, Error>
2862
where
2963
T: MakeTlsMode<Socket> + 'static + Send,

tokio-postgres-native-tls/src/test.rs

+6-11
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
66

77
use crate::TlsConnector;
88

9-
fn smoke_test<T>(builder: &tokio_postgres::Builder, tls: T)
9+
fn smoke_test<T>(s: &str, tls: T)
1010
where
1111
T: TlsMode<TcpStream>,
1212
T::Stream: 'static,
1313
{
1414
let mut runtime = Runtime::new().unwrap();
1515

16+
let builder = s.parse::<tokio_postgres::Builder>().unwrap();
17+
1618
let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
1719
.map_err(|e| panic!("{}", e))
1820
.and_then(|s| builder.handshake(s, tls));
@@ -42,9 +44,7 @@ fn require() {
4244
.build()
4345
.unwrap();
4446
smoke_test(
45-
tokio_postgres::Builder::new()
46-
.user("ssl_user")
47-
.dbname("postgres"),
47+
"user=ssl_user dbname=postgres",
4848
RequireTls(TlsConnector::with_connector(connector, "localhost")),
4949
);
5050
}
@@ -58,9 +58,7 @@ fn prefer() {
5858
.build()
5959
.unwrap();
6060
smoke_test(
61-
tokio_postgres::Builder::new()
62-
.user("ssl_user")
63-
.dbname("postgres"),
61+
"user=ssl_user dbname=postgres",
6462
PreferTls(TlsConnector::with_connector(connector, "localhost")),
6563
);
6664
}
@@ -74,10 +72,7 @@ fn scram_user() {
7472
.build()
7573
.unwrap();
7674
smoke_test(
77-
tokio_postgres::Builder::new()
78-
.user("scram_user")
79-
.password("password")
80-
.dbname("postgres"),
75+
"user=scram_user password=password dbname=postgres",
8176
RequireTls(TlsConnector::with_connector(connector, "localhost")),
8277
);
8378
}

tokio-postgres-openssl/src/test.rs

+6-11
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
66

77
use super::*;
88

9-
fn smoke_test<T>(builder: &tokio_postgres::Builder, tls: T)
9+
fn smoke_test<T>(s: &str, tls: T)
1010
where
1111
T: TlsMode<TcpStream>,
1212
T::Stream: 'static,
1313
{
1414
let mut runtime = Runtime::new().unwrap();
1515

16+
let builder = s.parse::<tokio_postgres::Builder>().unwrap();
17+
1618
let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
1719
.map_err(|e| panic!("{}", e))
1820
.and_then(|s| builder.handshake(s, tls));
@@ -39,9 +41,7 @@ fn require() {
3941
builder.set_ca_file("../test/server.crt").unwrap();
4042
let ctx = builder.build();
4143
smoke_test(
42-
tokio_postgres::Builder::new()
43-
.user("ssl_user")
44-
.dbname("postgres"),
44+
"user=ssl_user dbname=postgres",
4545
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
4646
);
4747
}
@@ -52,9 +52,7 @@ fn prefer() {
5252
builder.set_ca_file("../test/server.crt").unwrap();
5353
let ctx = builder.build();
5454
smoke_test(
55-
tokio_postgres::Builder::new()
56-
.user("ssl_user")
57-
.dbname("postgres"),
55+
"user=ssl_user dbname=postgres",
5856
PreferTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
5957
);
6058
}
@@ -65,10 +63,7 @@ fn scram_user() {
6563
builder.set_ca_file("../test/server.crt").unwrap();
6664
let ctx = builder.build();
6765
smoke_test(
68-
tokio_postgres::Builder::new()
69-
.user("scram_user")
70-
.password("password")
71-
.dbname("postgres"),
66+
"user=scram_user password=password dbname=postgres",
7267
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
7368
);
7469
}

tokio-postgres/src/builder.rs

+104-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
use std::collections::hash_map::{self, HashMap};
22
use std::iter;
3+
#[cfg(all(feature = "runtime", unix))]
4+
use std::path::{Path, PathBuf};
35
use std::str::{self, FromStr};
6+
#[cfg(feature = "runtime")]
7+
use std::time::Duration;
48
use tokio_io::{AsyncRead, AsyncWrite};
59

610
#[cfg(feature = "runtime")]
@@ -10,9 +14,24 @@ use crate::proto::HandshakeFuture;
1014
use crate::{Connect, MakeTlsMode, Socket};
1115
use crate::{Error, Handshake, TlsMode};
1216

13-
#[derive(Clone)]
17+
#[cfg(feature = "runtime")]
18+
#[derive(Debug, Clone, PartialEq)]
19+
pub(crate) enum Host {
20+
Tcp(String),
21+
#[cfg(unix)]
22+
Unix(PathBuf),
23+
}
24+
25+
#[derive(Debug, Clone, PartialEq)]
1426
pub struct Builder {
15-
params: HashMap<String, String>,
27+
pub(crate) params: HashMap<String, String>,
28+
pub(crate) password: Option<Vec<u8>>,
29+
#[cfg(feature = "runtime")]
30+
pub(crate) host: Vec<Host>,
31+
#[cfg(feature = "runtime")]
32+
pub(crate) port: Vec<u16>,
33+
#[cfg(feature = "runtime")]
34+
pub(crate) connect_timeout: Option<Duration>,
1635
}
1736

1837
impl Default for Builder {
@@ -27,45 +46,80 @@ impl Builder {
2746
params.insert("client_encoding".to_string(), "UTF8".to_string());
2847
params.insert("timezone".to_string(), "GMT".to_string());
2948

30-
Builder { params }
49+
Builder {
50+
params,
51+
password: None,
52+
#[cfg(feature = "runtime")]
53+
host: vec![],
54+
#[cfg(feature = "runtime")]
55+
port: vec![],
56+
#[cfg(feature = "runtime")]
57+
connect_timeout: None,
58+
}
3159
}
3260

33-
pub fn user(&mut self, user: &str) -> &mut Builder {
34-
self.param("user", user)
61+
#[cfg(feature = "runtime")]
62+
pub fn host(&mut self, host: &str) -> &mut Builder {
63+
#[cfg(unix)]
64+
{
65+
if host.starts_with('/') {
66+
self.host.push(Host::Unix(PathBuf::from(host)));
67+
return self;
68+
}
69+
}
70+
71+
self.host.push(Host::Tcp(host.to_string()));
72+
self
3573
}
3674

37-
pub fn dbname(&mut self, database: &str) -> &mut Builder {
38-
self.param("dbname", database)
75+
#[cfg(all(feature = "runtime", unix))]
76+
pub fn host_path<T>(&mut self, host: T) -> &mut Builder
77+
where
78+
T: AsRef<Path>,
79+
{
80+
self.host.push(Host::Unix(host.as_ref().to_path_buf()));
81+
self
3982
}
4083

41-
pub fn password(&mut self, password: &str) -> &mut Builder {
42-
self.param("password", password)
84+
#[cfg(feature = "runtime")]
85+
pub fn port(&mut self, port: u16) -> &mut Builder {
86+
self.port.push(port);
87+
self
4388
}
4489

45-
pub fn param(&mut self, key: &str, value: &str) -> &mut Builder {
46-
self.params.insert(key.to_string(), value.to_string());
90+
#[cfg(feature = "runtime")]
91+
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Builder {
92+
self.connect_timeout = Some(connect_timeout);
4793
self
4894
}
4995

50-
/// FIXME do we want this?
51-
pub fn iter(&self) -> Iter<'_> {
52-
Iter(self.params.iter())
96+
pub fn password<T>(&mut self, password: T) -> &mut Builder
97+
where
98+
T: AsRef<[u8]>,
99+
{
100+
self.password = Some(password.as_ref().to_vec());
101+
self
102+
}
103+
104+
pub fn param(&mut self, key: &str, value: &str) -> &mut Builder {
105+
self.params.insert(key.to_string(), value.to_string());
106+
self
53107
}
54108

55109
pub fn handshake<S, T>(&self, stream: S, tls_mode: T) -> Handshake<S, T>
56110
where
57111
S: AsyncRead + AsyncWrite,
58112
T: TlsMode<S>,
59113
{
60-
Handshake(HandshakeFuture::new(stream, tls_mode, self.params.clone()))
114+
Handshake(HandshakeFuture::new(stream, tls_mode, self.clone()))
61115
}
62116

63117
#[cfg(feature = "runtime")]
64118
pub fn connect<T>(&self, make_tls_mode: T) -> Connect<T>
65119
where
66120
T: MakeTlsMode<Socket>,
67121
{
68-
Connect(ConnectFuture::new(make_tls_mode, self.params.clone()))
122+
Connect(ConnectFuture::new(make_tls_mode, self.clone()))
69123
}
70124
}
71125

@@ -77,7 +131,40 @@ impl FromStr for Builder {
77131
let mut builder = Builder::new();
78132

79133
while let Some((key, value)) = parser.parameter()? {
80-
builder.params.insert(key.to_string(), value);
134+
match key {
135+
"password" => {
136+
builder.password(value);
137+
}
138+
#[cfg(feature = "runtime")]
139+
"host" => {
140+
for host in value.split(',') {
141+
builder.host(host);
142+
}
143+
}
144+
#[cfg(feature = "runtime")]
145+
"port" => {
146+
for port in value.split(',') {
147+
let port = if port.is_empty() {
148+
5432
149+
} else {
150+
port.parse().map_err(Error::invalid_port)?
151+
};
152+
builder.port(port);
153+
}
154+
}
155+
#[cfg(feature = "runtime")]
156+
"connect_timeout" => {
157+
let timeout = value
158+
.parse::<i64>()
159+
.map_err(Error::invalid_connect_timeout)?;
160+
if timeout > 0 {
161+
builder.connect_timeout(Duration::from_secs(timeout as u64));
162+
}
163+
}
164+
key => {
165+
builder.param(key, &value);
166+
}
167+
}
81168
}
82169

83170
Ok(builder)

0 commit comments

Comments
 (0)