Skip to content

Commit 932a7b1

Browse files
committed
Add a connect timeout
cc sfackler#246
1 parent 5518e0d commit 932a7b1

File tree

4 files changed

+95
-72
lines changed

4 files changed

+95
-72
lines changed

postgres-shared/src/params/mod.rs

+52-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
use std::error::Error;
33
use std::path::PathBuf;
44
use std::mem;
5+
use std::time::Duration;
56

67
use params::url::Url;
78

@@ -43,6 +44,7 @@ pub struct ConnectParams {
4344
user: Option<User>,
4445
database: Option<String>,
4546
options: Vec<(String, String)>,
47+
connect_timeout: Option<Duration>,
4648
}
4749

4850
impl ConnectParams {
@@ -79,6 +81,11 @@ impl ConnectParams {
7981
pub fn options(&self) -> &[(String, String)] {
8082
&self.options
8183
}
84+
85+
/// A timeout to apply to each socket-level connection attempt.
86+
pub fn connect_timeout(&self) -> Option<Duration> {
87+
self.connect_timeout
88+
}
8289
}
8390

8491
/// A builder for `ConnectParams`.
@@ -87,6 +94,7 @@ pub struct Builder {
8794
user: Option<User>,
8895
database: Option<String>,
8996
options: Vec<(String, String)>,
97+
connect_timeout: Option<Duration>,
9098
}
9199

92100
impl Builder {
@@ -97,6 +105,7 @@ impl Builder {
97105
user: None,
98106
database: None,
99107
options: vec![],
108+
connect_timeout: None,
100109
}
101110
}
102111

@@ -127,6 +136,12 @@ impl Builder {
127136
self
128137
}
129138

139+
/// Sets the connection timeout.
140+
pub fn connect_timeout(&mut self, connect_timeout: Option<Duration>) -> &mut Builder {
141+
self.connect_timeout = connect_timeout;
142+
self
143+
}
144+
130145
/// Constructs a `ConnectParams` from the builder.
131146
pub fn build(&mut self, host: Host) -> ConnectParams {
132147
ConnectParams {
@@ -135,6 +150,7 @@ impl Builder {
135150
user: self.user.take(),
136151
database: self.database.take(),
137152
options: mem::replace(&mut self.options, vec![]),
153+
connect_timeout: self.connect_timeout,
138154
}
139155
}
140156
}
@@ -196,7 +212,16 @@ impl IntoConnectParams for Url {
196212
}
197213

198214
for (name, value) in options {
199-
builder.option(&name, &value);
215+
match &*name {
216+
"connect_timeout" => {
217+
let timeout = value.parse().map_err(|_| "invalid connect_timeout")?;
218+
let timeout = Duration::from_secs(timeout);
219+
builder.connect_timeout(Some(timeout));
220+
}
221+
_ => {
222+
builder.option(&name, &value);
223+
}
224+
}
200225
}
201226

202227
let maybe_path = url::decode_component(&host)?;
@@ -209,3 +234,29 @@ impl IntoConnectParams for Url {
209234
Ok(builder.build(host))
210235
}
211236
}
237+
238+
#[cfg(test)]
239+
mod test {
240+
use super::*;
241+
242+
#[test]
243+
fn parse_url() {
244+
let params = "postgres://user@host:44/dbname?connect_timeout=10&application_name=foo";
245+
let params = params.into_connect_params().unwrap();
246+
assert_eq!(
247+
params.user(),
248+
Some(&User {
249+
name: "user".to_string(),
250+
password: None,
251+
})
252+
);
253+
assert_eq!(params.host(), &Host::Tcp("host".to_string()));
254+
assert_eq!(params.port(), 44);
255+
assert_eq!(params.database(), Some("dbname"));
256+
assert_eq!(
257+
params.options(),
258+
&[("application_name".to_string(), "foo".to_string())][..]
259+
);
260+
assert_eq!(params.connect_timeout(), Some(Duration::from_secs(10)));
261+
}
262+
}

postgres/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ no-logging = []
5959
bytes = "0.4"
6060
fallible-iterator = "0.1.3"
6161
log = "0.3"
62+
socket2 = "0.2"
6263

6364
openssl = { version = "0.9.2", optional = true }
6465
native-tls = { version = "0.1", optional = true }

postgres/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ extern crate fallible_iterator;
7777
extern crate log;
7878
extern crate postgres_protocol;
7979
extern crate postgres_shared;
80+
extern crate socket2;
8081

8182
use fallible_iterator::FallibleIterator;
8283
use std::cell::{Cell, RefCell};

postgres/src/priv_io.rs

+41-71
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
use std::io::{self, BufWriter, Read, Write};
2-
use std::fmt;
3-
use std::net::TcpStream;
2+
use std::net::{ToSocketAddrs, SocketAddr};
43
use std::time::Duration;
54
use std::result;
65
use bytes::{BufMut, BytesMut};
76
#[cfg(unix)]
87
use std::os::unix::net::UnixStream;
98
#[cfg(unix)]
10-
use std::os::unix::io::{AsRawFd, RawFd};
9+
use std::os::unix::io::{AsRawFd, RawFd, FromRawFd, IntoRawFd};
1110
#[cfg(windows)]
1211
use std::os::windows::io::{AsRawSocket, RawSocket};
1312
use postgres_protocol::message::frontend;
1413
use postgres_protocol::message::backend;
14+
use socket2::{Socket, SockAddr, Domain, Type};
1515

1616
use {Result, TlsMode};
1717
use error;
@@ -118,37 +118,20 @@ impl MessageStream {
118118
}
119119

120120
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
121-
match self.stream.get_ref().get_ref().0 {
122-
InternalStream::Tcp(ref s) => s.set_read_timeout(timeout),
123-
#[cfg(unix)]
124-
InternalStream::Unix(ref s) => s.set_read_timeout(timeout),
125-
}
121+
self.stream.get_ref().get_ref().0.set_read_timeout(timeout)
126122
}
127123

128124
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()> {
129-
match self.stream.get_ref().get_ref().0 {
130-
InternalStream::Tcp(ref s) => s.set_nonblocking(nonblock),
131-
#[cfg(unix)]
132-
InternalStream::Unix(ref s) => s.set_nonblocking(nonblock),
133-
}
125+
self.stream.get_ref().get_ref().0.set_nonblocking(nonblock)
134126
}
135127
}
136128

137129
/// A connection to the Postgres server.
138130
///
139131
/// It implements `Read`, `Write` and `TlsStream`, as well as `AsRawFd` on
140132
/// Unix platforms and `AsRawSocket` on Windows platforms.
141-
pub struct Stream(InternalStream);
142-
143-
impl fmt::Debug for Stream {
144-
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
145-
match self.0 {
146-
InternalStream::Tcp(ref s) => fmt::Debug::fmt(s, fmt),
147-
#[cfg(unix)]
148-
InternalStream::Unix(ref s) => fmt::Debug::fmt(s, fmt),
149-
}
150-
}
151-
}
133+
#[derive(Debug)]
134+
pub struct Stream(Socket);
152135

153136
impl Read for Stream {
154137
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
@@ -179,69 +162,56 @@ impl TlsStream for Stream {
179162
#[cfg(unix)]
180163
impl AsRawFd for Stream {
181164
fn as_raw_fd(&self) -> RawFd {
182-
match self.0 {
183-
InternalStream::Tcp(ref s) => s.as_raw_fd(),
184-
InternalStream::Unix(ref s) => s.as_raw_fd(),
185-
}
165+
self.0.as_raw_fd()
186166
}
187167
}
188168

189169
#[cfg(windows)]
190170
impl AsRawSocket for Stream {
191171
fn as_raw_socket(&self) -> RawSocket {
192-
// Unix sockets aren't supported on windows, so no need to match
193-
match self.0 {
194-
InternalStream::Tcp(ref s) => s.as_raw_socket(),
195-
}
196-
}
197-
}
198-
199-
enum InternalStream {
200-
Tcp(TcpStream),
201-
#[cfg(unix)]
202-
Unix(UnixStream),
203-
}
204-
205-
impl Read for InternalStream {
206-
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
207-
match *self {
208-
InternalStream::Tcp(ref mut s) => s.read(buf),
209-
#[cfg(unix)]
210-
InternalStream::Unix(ref mut s) => s.read(buf),
211-
}
172+
self.0.as_raw_socket()
212173
}
213174
}
214175

215-
impl Write for InternalStream {
216-
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
217-
match *self {
218-
InternalStream::Tcp(ref mut s) => s.write(buf),
219-
#[cfg(unix)]
220-
InternalStream::Unix(ref mut s) => s.write(buf),
221-
}
222-
}
223-
224-
fn flush(&mut self) -> io::Result<()> {
225-
match *self {
226-
InternalStream::Tcp(ref mut s) => s.flush(),
227-
#[cfg(unix)]
228-
InternalStream::Unix(ref mut s) => s.flush(),
229-
}
230-
}
231-
}
232-
233-
fn open_socket(params: &ConnectParams) -> Result<InternalStream> {
176+
fn open_socket(params: &ConnectParams) -> Result<Socket> {
234177
let port = params.port();
235178
match *params.host() {
236179
Host::Tcp(ref host) => {
237-
Ok(TcpStream::connect(&(&**host, port)).map(
238-
InternalStream::Tcp,
239-
)?)
180+
let mut error = None;
181+
for addr in (&**host, port).to_socket_addrs()? {
182+
let domain = match addr {
183+
SocketAddr::V4(_) => Domain::ipv4(),
184+
SocketAddr::V6(_) => Domain::ipv6(),
185+
};
186+
let socket = Socket::new(domain, Type::stream(), None)?;
187+
let addr = SockAddr::from(addr);
188+
let r = match params.connect_timeout() {
189+
Some(timeout) => socket.connect_timeout(&addr, timeout),
190+
None => socket.connect(&addr),
191+
};
192+
match r {
193+
Ok(()) => return Ok(socket),
194+
Err(e) => error = Some(e),
195+
}
196+
}
197+
198+
Err(
199+
error
200+
.unwrap_or_else(|| {
201+
io::Error::new(
202+
io::ErrorKind::InvalidInput,
203+
"could not resolve any addresses",
204+
)
205+
})
206+
.into(),
207+
)
240208
}
241209
#[cfg(unix)]
242210
Host::Unix(ref path) => {
243211
let path = path.join(&format!(".s.PGSQL.{}", port));
244-
Ok(UnixStream::connect(&path).map(InternalStream::Unix)?)
212+
Ok(UnixStream::connect(&path).map(|s| unsafe {
213+
Socket::from_raw_fd(s.into_raw_fd())
214+
})?)
245215
}
246216
#[cfg(not(unix))]
247217
Host::Unix(..) => {

0 commit comments

Comments
 (0)