Skip to content

Commit 23b0d6e

Browse files
committed
Support multiple hosts when connecting
cc sfackler#399
1 parent 7e7ae96 commit 23b0d6e

File tree

3 files changed

+98
-21
lines changed

3 files changed

+98
-21
lines changed

tokio-postgres/src/error/mod.rs

+9
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ enum Kind {
354354
MissingHost,
355355
#[cfg(feature = "runtime")]
356356
InvalidPort,
357+
#[cfg(feature = "runtime")]
358+
InvalidPortCount,
357359
}
358360

359361
struct ErrorInner {
@@ -397,6 +399,8 @@ impl fmt::Display for Error {
397399
Kind::MissingHost => "host not provided",
398400
#[cfg(feature = "runtime")]
399401
Kind::InvalidPort => "invalid port",
402+
#[cfg(feature = "runtime")]
403+
Kind::InvalidPortCount => "wrong number of ports provided",
400404
};
401405
fmt.write_str(s)?;
402406
if let Some(ref cause) = self.0.cause {
@@ -514,4 +518,9 @@ impl Error {
514518
pub(crate) fn invalid_port(e: ParseIntError) -> Error {
515519
Error::new(Kind::InvalidPort, Some(Box::new(e)))
516520
}
521+
522+
#[cfg(feature = "runtime")]
523+
pub(crate) fn invalid_port_count() -> Error {
524+
Error::new(Kind::InvalidPortCount, None)
525+
}
517526
}

tokio-postgres/src/proto/connect.rs

+62-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
use futures::{try_ready, Future, Poll};
1+
use futures::{try_ready, Async, Future, Poll};
22
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
33
use std::collections::HashMap;
4+
use std::vec;
45

56
use crate::proto::{Client, ConnectOnceFuture, Connection};
67
use crate::{Error, MakeTlsMode, Socket};
@@ -20,11 +21,16 @@ where
2021
future: T::Future,
2122
host: String,
2223
port: u16,
24+
addrs: vec::IntoIter<(String, u16)>,
25+
make_tls_mode: T,
2326
params: HashMap<String, String>,
2427
},
25-
#[state_machine_future(transitions(Finished))]
28+
#[state_machine_future(transitions(MakingTlsMode, Finished))]
2629
Connecting {
2730
future: ConnectOnceFuture<T::TlsMode>,
31+
addrs: vec::IntoIter<(String, u16)>,
32+
make_tls_mode: T,
33+
params: HashMap<String, String>,
2834
},
2935
#[state_machine_future(ready)]
3036
Finished((Client, Connection<T::Stream>)),
@@ -43,16 +49,42 @@ where
4349
Some(host) => host,
4450
None => return Err(Error::missing_host()),
4551
};
52+
let mut addrs = host
53+
.split(',')
54+
.map(|s| (s.to_string(), 0u16))
55+
.collect::<Vec<_>>();
4656

47-
let port = match state.params.remove("port") {
48-
Some(port) => port.parse::<u16>().map_err(Error::invalid_port)?,
49-
None => 5432,
50-
};
57+
let port = state.params.remove("port").unwrap_or_else(String::new);
58+
let mut ports = port
59+
.split(',')
60+
.map(|s| {
61+
if s.is_empty() {
62+
Ok(5432)
63+
} else {
64+
s.parse::<u16>().map_err(Error::invalid_port)
65+
}
66+
})
67+
.collect::<Result<Vec<_>, _>>()?;
68+
if ports.len() == 1 {
69+
ports.resize(addrs.len(), ports[0]);
70+
}
71+
if addrs.len() != ports.len() {
72+
return Err(Error::invalid_port_count());
73+
}
74+
75+
for (addr, port) in addrs.iter_mut().zip(ports) {
76+
addr.1 = port;
77+
}
78+
79+
let mut addrs = addrs.into_iter();
80+
let (host, port) = addrs.next().expect("addrs cannot be empty");
5181

5282
transition!(MakingTlsMode {
5383
future: state.make_tls_mode.make_tls_mode(&host),
5484
host,
5585
port,
86+
addrs,
87+
make_tls_mode: state.make_tls_mode,
5688
params: state.params,
5789
})
5890
}
@@ -64,15 +96,36 @@ where
6496
let state = state.take();
6597

6698
transition!(Connecting {
67-
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params),
99+
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params.clone()),
100+
addrs: state.addrs,
101+
make_tls_mode: state.make_tls_mode,
102+
params: state.params,
68103
})
69104
}
70105

71106
fn poll_connecting<'a>(
72107
state: &'a mut RentToOwn<'a, Connecting<T>>,
73108
) -> Poll<AfterConnecting<T>, Error> {
74-
let r = try_ready!(state.future.poll());
75-
transition!(Finished(r))
109+
match state.future.poll() {
110+
Ok(Async::Ready(r)) => transition!(Finished(r)),
111+
Ok(Async::NotReady) => Ok(Async::NotReady),
112+
Err(e) => {
113+
let mut state = state.take();
114+
let (host, port) = match state.addrs.next() {
115+
Some(addr) => addr,
116+
None => return Err(e),
117+
};
118+
119+
transition!(MakingTlsMode {
120+
future: state.make_tls_mode.make_tls_mode(&host),
121+
host,
122+
port,
123+
addrs: state.addrs,
124+
make_tls_mode: state.make_tls_mode,
125+
params: state.params,
126+
})
127+
}
128+
}
76129
}
77130
}
78131

tokio-postgres/tests/test/runtime.rs

+27-12
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@ fn connect(s: &str) -> impl Future<Item = (Client, Connection<Socket>), Error =
66
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
77
}
88

9-
#[test]
10-
#[ignore] // FIXME doesn't work with our docker-based tests :(
11-
fn unix_socket() {
9+
fn smoke_test(s: &str) {
1210
let mut runtime = Runtime::new().unwrap();
13-
14-
let connect = connect("host=/var/run/postgresql port=5433 user=postgres");
11+
let connect = connect(s);
1512
let (mut client, connection) = runtime.block_on(connect).unwrap();
1613
let connection = connection.map_err(|e| panic!("{}", e));
1714
runtime.spawn(connection);
@@ -20,15 +17,33 @@ fn unix_socket() {
2017
runtime.block_on(execute).unwrap();
2118
}
2219

20+
#[test]
21+
#[ignore] // FIXME doesn't work with our docker-based tests :(
22+
fn unix_socket() {
23+
smoke_test("host=/var/run/postgresql port=5433 user=postgres");
24+
}
25+
2326
#[test]
2427
fn tcp() {
25-
let mut runtime = Runtime::new().unwrap();
28+
smoke_test("host=localhost port=5433 user=postgres")
29+
}
2630

27-
let connect = connect("host=localhost port=5433 user=postgres");
28-
let (mut client, connection) = runtime.block_on(connect).unwrap();
29-
let connection = connection.map_err(|e| panic!("{}", e));
30-
runtime.spawn(connection);
31+
#[test]
32+
fn multiple_hosts_one_port() {
33+
smoke_test("host=foobar.invalid,localhost port=5433 user=postgres");
34+
}
3135

32-
let execute = client.batch_execute("SELECT 1");
33-
runtime.block_on(execute).unwrap();
36+
#[test]
37+
fn multiple_hosts_multiple_ports() {
38+
smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres");
39+
}
40+
41+
#[test]
42+
fn wrong_port_count() {
43+
let mut runtime = Runtime::new().unwrap();
44+
let f = connect("host=localhost port=5433,5433 user=postgres");
45+
runtime.block_on(f).err().unwrap();
46+
47+
let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres");
48+
runtime.block_on(f).err().unwrap();
3449
}

0 commit comments

Comments
 (0)