Skip to content

Commit defe764

Browse files
committed
Support connect_timeout
1 parent d5104bc commit defe764

File tree

3 files changed

+81
-2
lines changed

3 files changed

+81
-2
lines changed

tokio-postgres/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ circle-ci = { repository = "sfackler/rust-postgres" }
2828

2929
[features]
3030
default = ["runtime"]
31-
runtime = ["tokio-tcp", "tokio-uds", "futures-cpupool", "lazy_static"]
31+
runtime = ["tokio-tcp", "tokio-timer", "tokio-uds", "futures-cpupool", "lazy_static"]
3232

3333
"with-bit-vec-0.5" = ["bit-vec-05"]
3434
"with-chrono-0.4" = ["chrono-04"]
@@ -53,6 +53,7 @@ void = "1.0"
5353
tokio-tcp = { version = "0.1", optional = true }
5454
futures-cpupool = { version = "0.1", optional = true }
5555
lazy_static = { version = "1.0", optional = true }
56+
tokio-timer = { version = "0.2", optional = true }
5657

5758
bit-vec-05 = { version = "0.5", package = "bit-vec", optional = true }
5859
chrono-04 = { version = "0.4", package = "chrono", optional = true }

tokio-postgres/src/error/mod.rs

+27
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,12 @@ enum Kind {
356356
InvalidPort,
357357
#[cfg(feature = "runtime")]
358358
InvalidPortCount,
359+
#[cfg(feature = "runtime")]
360+
InvalidConnectTimeout,
361+
#[cfg(feature = "runtime")]
362+
Timer,
363+
#[cfg(feature = "runtime")]
364+
ConnectTimeout,
359365
}
360366

361367
struct ErrorInner {
@@ -401,6 +407,12 @@ impl fmt::Display for Error {
401407
Kind::InvalidPort => "invalid port",
402408
#[cfg(feature = "runtime")]
403409
Kind::InvalidPortCount => "wrong number of ports provided",
410+
#[cfg(feature = "runtime")]
411+
Kind::InvalidConnectTimeout => "invalid connect_timeout",
412+
#[cfg(feature = "runtime")]
413+
Kind::Timer => "timer error",
414+
#[cfg(feature = "runtime")]
415+
Kind::ConnectTimeout => "timed out connecting to server",
404416
};
405417
fmt.write_str(s)?;
406418
if let Some(ref cause) = self.0.cause {
@@ -523,4 +535,19 @@ impl Error {
523535
pub(crate) fn invalid_port_count() -> Error {
524536
Error::new(Kind::InvalidPortCount, None)
525537
}
538+
539+
#[cfg(feature = "runtime")]
540+
pub(crate) fn invalid_connect_timeout(e: ParseIntError) -> Error {
541+
Error::new(Kind::InvalidConnectTimeout, Some(Box::new(e)))
542+
}
543+
544+
#[cfg(feature = "runtime")]
545+
pub(crate) fn timer(e: tokio_timer::Error) -> Error {
546+
Error::new(Kind::Timer, Some(Box::new(e)))
547+
}
548+
549+
#[cfg(feature = "runtime")]
550+
pub(crate) fn connect_timeout() -> Error {
551+
Error::new(Kind::ConnectTimeout, None)
552+
}
526553
}

tokio-postgres/src/proto/connect_once.rs

+52-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ use std::io;
77
use std::net::{SocketAddr, ToSocketAddrs};
88
#[cfg(unix)]
99
use std::path::Path;
10+
use std::time::{Duration, Instant};
1011
use std::vec;
1112
use tokio_tcp::TcpStream;
13+
use tokio_timer::Delay;
1214
#[cfg(unix)]
1315
use tokio_uds::UnixStream;
1416

@@ -40,19 +42,25 @@ where
4042
#[state_machine_future(transitions(Handshaking))]
4143
ConnectingUnix {
4244
future: tokio_uds::ConnectFuture,
45+
connect_timeout: Option<Duration>,
46+
timeout: Option<Delay>,
4347
tls_mode: T,
4448
params: HashMap<String, String>,
4549
},
4650
#[state_machine_future(transitions(ConnectingTcp))]
4751
ResolvingDns {
4852
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
53+
connect_timeout: Option<Duration>,
54+
timeout: Option<Delay>,
4955
tls_mode: T,
5056
params: HashMap<String, String>,
5157
},
5258
#[state_machine_future(transitions(Handshaking))]
5359
ConnectingTcp {
5460
future: tokio_tcp::ConnectFuture,
5561
addrs: vec::IntoIter<SocketAddr>,
62+
connect_timeout: Option<Duration>,
63+
timeout: Option<Delay>,
5664
tls_mode: T,
5765
params: HashMap<String, String>,
5866
},
@@ -69,14 +77,29 @@ where
6977
T: TlsMode<Socket>,
7078
{
7179
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
72-
let state = state.take();
80+
let mut state = state.take();
81+
82+
let connect_timeout = match state.params.remove("connect_timeout") {
83+
Some(s) => {
84+
let seconds = s.parse::<i64>().map_err(Error::invalid_connect_timeout)?;
85+
if seconds <= 0 {
86+
None
87+
} else {
88+
Some(Duration::from_secs(seconds as u64))
89+
}
90+
}
91+
None => None,
92+
};
93+
let timeout = connect_timeout.map(|d| Delay::new(Instant::now() + d));
7394

7495
#[cfg(unix)]
7596
{
7697
if state.host.starts_with('/') {
7798
let path = Path::new(&state.host).join(format!(".s.PGSQL.{}", state.port));
7899
transition!(ConnectingUnix {
79100
future: UnixStream::connect(path),
101+
connect_timeout,
102+
timeout,
80103
tls_mode: state.tls_mode,
81104
params: state.params,
82105
})
@@ -87,6 +110,8 @@ where
87110
let port = state.port;
88111
transition!(ResolvingDns {
89112
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
113+
connect_timeout,
114+
timeout,
90115
tls_mode: state.tls_mode,
91116
params: state.params,
92117
})
@@ -96,6 +121,14 @@ where
96121
fn poll_connecting_unix<'a>(
97122
state: &'a mut RentToOwn<'a, ConnectingUnix<T>>,
98123
) -> Poll<AfterConnectingUnix<T>, Error> {
124+
if let Some(timeout) = &mut state.timeout {
125+
match timeout.poll() {
126+
Ok(Async::Ready(())) => return Err(Error::connect_timeout()),
127+
Ok(Async::NotReady) => {}
128+
Err(e) => return Err(Error::timer(e)),
129+
}
130+
}
131+
99132
let stream = try_ready!(state.future.poll().map_err(Error::connect));
100133
let stream = Socket::new_unix(stream);
101134
let state = state.take();
@@ -108,6 +141,14 @@ where
108141
fn poll_resolving_dns<'a>(
109142
state: &'a mut RentToOwn<'a, ResolvingDns<T>>,
110143
) -> Poll<AfterResolvingDns<T>, Error> {
144+
if let Some(timeout) = &mut state.timeout {
145+
match timeout.poll() {
146+
Ok(Async::Ready(())) => return Err(Error::connect_timeout()),
147+
Ok(Async::NotReady) => {}
148+
Err(e) => return Err(Error::timer(e)),
149+
}
150+
}
151+
111152
let mut addrs = try_ready!(state.future.poll().map_err(Error::connect));
112153
let state = state.take();
113154

@@ -124,6 +165,8 @@ where
124165
transition!(ConnectingTcp {
125166
future: TcpStream::connect(&addr),
126167
addrs,
168+
connect_timeout: state.connect_timeout,
169+
timeout: state.timeout,
127170
tls_mode: state.tls_mode,
128171
params: state.params,
129172
})
@@ -132,6 +175,14 @@ where
132175
fn poll_connecting_tcp<'a>(
133176
state: &'a mut RentToOwn<'a, ConnectingTcp<T>>,
134177
) -> Poll<AfterConnectingTcp<T>, Error> {
178+
if let Some(timeout) = &mut state.timeout {
179+
match timeout.poll() {
180+
Ok(Async::Ready(())) => return Err(Error::connect_timeout()),
181+
Ok(Async::NotReady) => {}
182+
Err(e) => return Err(Error::timer(e)),
183+
}
184+
}
185+
135186
let stream = loop {
136187
match state.future.poll() {
137188
Ok(Async::Ready(stream)) => break stream,

0 commit comments

Comments
 (0)