forked from sfackler/rust-postgres
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconnect_socket.rs
93 lines (85 loc) · 3.05 KB
/
connect_socket.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use crate::config::Host;
use crate::{Error, Socket};
use socket2::{Domain, Protocol, Type};
use std::future::Future;
use std::io;
use std::net::SocketAddr;
#[cfg(unix)]
use std::os::unix::io::{FromRawFd, IntoRawFd};
#[cfg(windows)]
use std::os::windows::io::{FromRawSocket, IntoRawSocket};
use std::time::Duration;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::net::{self, TcpSocket};
use tokio::time;
pub(crate) async fn connect_socket(
host: &Host,
port: u16,
connect_timeout: Option<Duration>,
keepalives: bool,
keepalives_idle: Duration,
) -> Result<Socket, Error> {
match host {
Host::Tcp(host) => {
let addrs = net::lookup_host((&**host, port))
.await
.map_err(Error::connect)?;
let mut last_err = None;
for addr in addrs {
let domain = match addr {
SocketAddr::V4(_) => Domain::ipv4(),
SocketAddr::V6(_) => Domain::ipv6(),
};
let socket = socket2::Socket::new(domain, Type::stream(), Some(Protocol::tcp()))
.map_err(Error::connect)?;
socket.set_nonblocking(true).map_err(Error::connect)?;
socket.set_nodelay(true).map_err(Error::connect)?;
if keepalives {
socket
.set_keepalive(Some(keepalives_idle))
.map_err(Error::connect)?;
}
#[cfg(unix)]
let socket = unsafe { TcpSocket::from_raw_fd(socket.into_raw_fd()) };
#[cfg(windows)]
let socket = unsafe { TcpSocket::from_raw_socket(socket.into_raw_socket()) };
match connect_with_timeout(socket.connect(addr), connect_timeout).await {
Ok(socket) => return Ok(Socket::new_tcp(socket)),
Err(e) => last_err = Some(e),
}
}
Err(last_err.unwrap_or_else(|| {
Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}))
}
#[cfg(unix)]
Host::Unix(path) => {
let path = path.join(format!(".s.PGSQL.{}", port));
let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?;
Ok(Socket::new_unix(socket))
}
}
}
async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
where
F: Future<Output = io::Result<T>>,
{
match timeout {
Some(timeout) => match time::timeout(timeout, connect).await {
Ok(Ok(socket)) => Ok(socket),
Ok(Err(e)) => Err(Error::connect(e)),
Err(_) => Err(Error::connect(io::Error::new(
io::ErrorKind::TimedOut,
"connection timed out",
))),
},
None => match connect.await {
Ok(socket) => Ok(socket),
Err(e) => Err(Error::connect(e)),
},
}
}