Skip to content

Commit c2fb9c6

Browse files
committed
Move TLS logic to connect future
This way we can reuse it for query cancellation
1 parent 6edab70 commit c2fb9c6

File tree

4 files changed

+287
-264
lines changed

4 files changed

+287
-264
lines changed

tokio-postgres/src/proto/connect.rs

+276
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
use futures::{Async, Future, Poll};
2+
use futures_cpupool::{CpuFuture, CpuPool};
3+
use postgres_protocol::message::frontend;
4+
use state_machine_future::RentToOwn;
5+
use std::error::Error as StdError;
6+
use std::io;
7+
use std::net::{SocketAddr, ToSocketAddrs};
8+
use std::time::{Duration, Instant};
9+
use std::vec;
10+
use tokio_io::io::{read_exact, write_all, ReadExact, WriteAll};
11+
use tokio_tcp::{self, TcpStream};
12+
use tokio_timer::Delay;
13+
14+
#[cfg(unix)]
15+
use tokio_uds::{self, UnixStream};
16+
17+
use error::{self, Error};
18+
use params::{ConnectParams, Host};
19+
use proto::socket::Socket;
20+
use tls::{self, TlsConnect, TlsStream};
21+
use {bad_response, TlsMode};
22+
23+
lazy_static! {
24+
static ref DNS_POOL: CpuPool = CpuPool::new(2);
25+
}
26+
27+
#[derive(StateMachineFuture)]
28+
pub enum Connect {
29+
#[state_machine_future(start)]
30+
#[cfg_attr(unix, state_machine_future(transitions(ResolvingDns, ConnectingUnix)))]
31+
#[cfg_attr(not(unix), state_machine_future(transitions(ResolvingDns)))]
32+
Start { params: ConnectParams, tls: TlsMode },
33+
#[state_machine_future(transitions(ConnectingTcp))]
34+
ResolvingDns {
35+
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
36+
timeout: Option<Duration>,
37+
params: ConnectParams,
38+
tls: TlsMode,
39+
},
40+
#[state_machine_future(transitions(PreparingSsl))]
41+
ConnectingTcp {
42+
addrs: vec::IntoIter<SocketAddr>,
43+
future: tokio_tcp::ConnectFuture,
44+
timeout: Option<(Duration, Delay)>,
45+
params: ConnectParams,
46+
tls: TlsMode,
47+
},
48+
#[cfg(unix)]
49+
#[state_machine_future(transitions(PreparingSsl))]
50+
ConnectingUnix {
51+
future: tokio_uds::ConnectFuture,
52+
timeout: Option<Delay>,
53+
params: ConnectParams,
54+
tls: TlsMode,
55+
},
56+
#[state_machine_future(transitions(Ready, SendingSsl))]
57+
PreparingSsl {
58+
socket: Socket,
59+
params: ConnectParams,
60+
tls: TlsMode,
61+
},
62+
#[state_machine_future(transitions(ReadingSsl))]
63+
SendingSsl {
64+
future: WriteAll<Socket, Vec<u8>>,
65+
params: ConnectParams,
66+
connector: Box<TlsConnect>,
67+
required: bool,
68+
},
69+
#[state_machine_future(transitions(ConnectingTls, Ready))]
70+
ReadingSsl {
71+
future: ReadExact<Socket, [u8; 1]>,
72+
params: ConnectParams,
73+
connector: Box<TlsConnect>,
74+
required: bool,
75+
},
76+
#[state_machine_future(transitions(Ready))]
77+
ConnectingTls {
78+
future:
79+
Box<Future<Item = Box<TlsStream>, Error = Box<StdError + Sync + Send>> + Sync + Send>,
80+
params: ConnectParams,
81+
},
82+
#[state_machine_future(ready)]
83+
Ready(Box<TlsStream>),
84+
#[state_machine_future(error)]
85+
Failed(Error),
86+
}
87+
88+
impl PollConnect for Connect {
89+
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
90+
let state = state.take();
91+
92+
let timeout = state.params.connect_timeout();
93+
let port = state.params.port();
94+
95+
match state.params.host().clone() {
96+
Host::Tcp(host) => transition!(ResolvingDns {
97+
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
98+
params: state.params,
99+
tls: state.tls,
100+
timeout,
101+
}),
102+
#[cfg(unix)]
103+
Host::Unix(mut path) => {
104+
path.push(format!(".s.PGSQL.{}", port));
105+
transition!(ConnectingUnix {
106+
future: UnixStream::connect(path),
107+
timeout: timeout.map(|t| Delay::new(Instant::now() + t)),
108+
params: state.params,
109+
tls: state.tls,
110+
})
111+
}
112+
}
113+
}
114+
115+
fn poll_resolving_dns<'a>(
116+
state: &'a mut RentToOwn<'a, ResolvingDns>,
117+
) -> Poll<AfterResolvingDns, Error> {
118+
let mut addrs = try_ready!(state.future.poll());
119+
let state = state.take();
120+
121+
let addr = match addrs.next() {
122+
Some(addr) => addr,
123+
None => {
124+
return Err(io::Error::new(io::ErrorKind::Other, "resolved to 0 addresses").into())
125+
}
126+
};
127+
128+
transition!(ConnectingTcp {
129+
addrs,
130+
future: TcpStream::connect(&addr),
131+
timeout: state.timeout.map(|t| (t, Delay::new(Instant::now() + t))),
132+
params: state.params,
133+
tls: state.tls,
134+
})
135+
}
136+
137+
fn poll_connecting_tcp<'a>(
138+
state: &'a mut RentToOwn<'a, ConnectingTcp>,
139+
) -> Poll<AfterConnectingTcp, Error> {
140+
loop {
141+
let error = match state.future.poll() {
142+
Ok(Async::Ready(socket)) => {
143+
let state = state.take();
144+
transition!(PreparingSsl {
145+
socket: Socket::Tcp(socket),
146+
params: state.params,
147+
tls: state.tls,
148+
})
149+
}
150+
Ok(Async::NotReady) => match state.timeout {
151+
Some((_, ref mut delay)) => {
152+
try_ready!(
153+
delay
154+
.poll()
155+
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
156+
);
157+
io::Error::new(io::ErrorKind::TimedOut, "connection timed out")
158+
}
159+
None => return Ok(Async::NotReady),
160+
},
161+
Err(e) => e,
162+
};
163+
164+
let addr = match state.addrs.next() {
165+
Some(addr) => addr,
166+
None => return Err(error.into()),
167+
};
168+
169+
state.future = TcpStream::connect(&addr);
170+
if let Some((timeout, ref mut delay)) = state.timeout {
171+
delay.reset(Instant::now() + timeout);
172+
}
173+
}
174+
}
175+
176+
#[cfg(unix)]
177+
fn poll_connecting_unix<'a>(
178+
state: &'a mut RentToOwn<'a, ConnectingUnix>,
179+
) -> Poll<AfterConnectingUnix, Error> {
180+
match state.future.poll()? {
181+
Async::Ready(socket) => {
182+
let state = state.take();
183+
transition!(PreparingSsl {
184+
socket: Socket::Unix(socket),
185+
params: state.params,
186+
tls: state.tls,
187+
})
188+
}
189+
Async::NotReady => match state.timeout {
190+
Some(ref mut delay) => {
191+
try_ready!(
192+
delay
193+
.poll()
194+
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
195+
);
196+
Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out").into())
197+
}
198+
None => Ok(Async::NotReady),
199+
},
200+
}
201+
}
202+
203+
fn poll_preparing_ssl<'a>(
204+
state: &'a mut RentToOwn<'a, PreparingSsl>,
205+
) -> Poll<AfterPreparingSsl, Error> {
206+
let state = state.take();
207+
208+
let (connector, required) = match state.tls {
209+
TlsMode::None => {
210+
transition!(Ready(Box::new(state.socket)));
211+
}
212+
TlsMode::Prefer(connector) => (connector, false),
213+
TlsMode::Require(connector) => (connector, true),
214+
};
215+
216+
let mut buf = vec![];
217+
frontend::ssl_request(&mut buf);
218+
transition!(SendingSsl {
219+
future: write_all(state.socket, buf),
220+
params: state.params,
221+
connector,
222+
required,
223+
})
224+
}
225+
226+
fn poll_sending_ssl<'a>(
227+
state: &'a mut RentToOwn<'a, SendingSsl>,
228+
) -> Poll<AfterSendingSsl, Error> {
229+
let (stream, _) = try_ready!(state.future.poll());
230+
let state = state.take();
231+
transition!(ReadingSsl {
232+
future: read_exact(stream, [0]),
233+
params: state.params,
234+
connector: state.connector,
235+
required: state.required,
236+
})
237+
}
238+
239+
fn poll_reading_ssl<'a>(
240+
state: &'a mut RentToOwn<'a, ReadingSsl>,
241+
) -> Poll<AfterReadingSsl, Error> {
242+
let (stream, buf) = try_ready!(state.future.poll());
243+
let state = state.take();
244+
245+
match buf[0] {
246+
b'S' => {
247+
let future = match state.params.host() {
248+
Host::Tcp(domain) => state.connector.connect(domain, tls::Socket(stream)),
249+
Host::Unix(_) => {
250+
return Err(error::tls("TLS over unix sockets not supported".into()))
251+
}
252+
};
253+
transition!(ConnectingTls {
254+
future,
255+
params: state.params,
256+
})
257+
}
258+
b'N' if !state.required => transition!(Ready(Box::new(stream))),
259+
b'N' => Err(error::tls("TLS was required but not supported".into())),
260+
_ => Err(bad_response()),
261+
}
262+
}
263+
264+
fn poll_connecting_tls<'a>(
265+
state: &'a mut RentToOwn<'a, ConnectingTls>,
266+
) -> Poll<AfterConnectingTls, Error> {
267+
let stream = try_ready!(state.future.poll().map_err(error::tls));
268+
transition!(Ready(stream))
269+
}
270+
}
271+
272+
impl ConnectFuture {
273+
pub fn new(params: ConnectParams, tls: TlsMode) -> ConnectFuture {
274+
Connect::start(params, tls)
275+
}
276+
}

0 commit comments

Comments
 (0)