Skip to content

Commit 70758bc

Browse files
committed
tokio-postgres TLS setup
1 parent 5fbe20f commit 70758bc

File tree

5 files changed

+230
-33
lines changed

5 files changed

+230
-33
lines changed

tokio-postgres/src/lib.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ pub use postgres_shared::{CancelData, Notification};
3535

3636
use error::Error;
3737
use params::ConnectParams;
38+
use tls::TlsConnect;
3839
use types::{FromSql, ToSql, Type};
3940

4041
mod proto;
42+
pub mod tls;
4143

4244
static NEXT_STATEMENT_ID: AtomicUsize = AtomicUsize::new(0);
4345

@@ -55,8 +57,14 @@ fn disconnected() -> Error {
5557
))
5658
}
5759

58-
pub fn connect(params: ConnectParams) -> Handshake {
59-
Handshake(proto::HandshakeFuture::new(params))
60+
pub enum TlsMode {
61+
None,
62+
Prefer(Box<TlsConnect>),
63+
Require(Box<TlsConnect>),
64+
}
65+
66+
pub fn connect(params: ConnectParams, tls: TlsMode) -> Handshake {
67+
Handshake(proto::HandshakeFuture::new(params, tls))
6068
}
6169

6270
pub struct Client(proto::Client);

tokio-postgres/src/proto/connection.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use tokio_codec::Framed;
99
use disconnected;
1010
use error::{self, Error};
1111
use proto::codec::PostgresCodec;
12-
use proto::socket::Socket;
12+
use tls::TlsStream;
1313
use {bad_response, CancelData};
1414

1515
pub struct Request {
@@ -25,7 +25,7 @@ enum State {
2525
}
2626

2727
pub struct Connection {
28-
stream: Framed<Socket, PostgresCodec>,
28+
stream: Framed<Box<TlsStream>, PostgresCodec>,
2929
cancel_data: CancelData,
3030
parameters: HashMap<String, String>,
3131
receiver: mpsc::UnboundedReceiver<Request>,
@@ -37,7 +37,7 @@ pub struct Connection {
3737

3838
impl Connection {
3939
pub fn new(
40-
stream: Framed<Socket, PostgresCodec>,
40+
stream: Framed<Box<TlsStream>, PostgresCodec>,
4141
cancel_data: CancelData,
4242
parameters: HashMap<String, String>,
4343
receiver: mpsc::UnboundedReceiver<Request>,

tokio-postgres/src/proto/handshake.rs

+120-15
Original file line numberDiff line numberDiff line change
@@ -8,55 +8,84 @@ use postgres_protocol::message::backend::Message;
88
use postgres_protocol::message::frontend;
99
use state_machine_future::RentToOwn;
1010
use std::collections::HashMap;
11+
use std::error::Error as StdError;
1112
use std::io;
1213
use tokio_codec::Framed;
14+
use tokio_io::io::{read_exact, write_all, ReadExact, WriteAll};
1315

1416
use error::{self, Error};
15-
use params::{ConnectParams, User};
17+
use params::{ConnectParams, Host, User};
1618
use proto::client::Client;
1719
use proto::codec::PostgresCodec;
1820
use proto::connection::Connection;
1921
use proto::socket::{ConnectFuture, Socket};
20-
use {bad_response, disconnected, CancelData};
22+
use tls::{self, TlsConnect, TlsStream};
23+
use {bad_response, disconnected, CancelData, TlsMode};
2124

2225
#[derive(StateMachineFuture)]
2326
pub enum Handshake {
24-
#[state_machine_future(start, transitions(SendingStartup))]
27+
#[state_machine_future(start, transitions(BuildingStartup, SendingSsl))]
2528
Start {
2629
future: ConnectFuture,
2730
params: ConnectParams,
31+
tls: TlsMode,
32+
},
33+
#[state_machine_future(transitions(ReadingSsl))]
34+
SendingSsl {
35+
future: WriteAll<Socket, Vec<u8>>,
36+
params: ConnectParams,
37+
connector: Box<TlsConnect>,
38+
required: bool,
39+
},
40+
#[state_machine_future(transitions(ConnectingTls, BuildingStartup))]
41+
ReadingSsl {
42+
future: ReadExact<Socket, [u8; 1]>,
43+
params: ConnectParams,
44+
connector: Box<TlsConnect>,
45+
required: bool,
46+
},
47+
#[state_machine_future(transitions(BuildingStartup))]
48+
ConnectingTls {
49+
future:
50+
Box<Future<Item = Box<TlsStream>, Error = Box<StdError + Sync + Send>> + Sync + Send>,
51+
params: ConnectParams,
52+
},
53+
#[state_machine_future(transitions(SendingStartup))]
54+
BuildingStartup {
55+
stream: Framed<Box<TlsStream>, PostgresCodec>,
56+
params: ConnectParams,
2857
},
2958
#[state_machine_future(transitions(ReadingAuth))]
3059
SendingStartup {
31-
future: sink::Send<Framed<Socket, PostgresCodec>>,
60+
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
3261
user: User,
3362
},
3463
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
3564
ReadingAuth {
36-
stream: Framed<Socket, PostgresCodec>,
65+
stream: Framed<Box<TlsStream>, PostgresCodec>,
3766
user: User,
3867
},
3968
#[state_machine_future(transitions(ReadingAuthCompletion))]
4069
SendingPassword {
41-
future: sink::Send<Framed<Socket, PostgresCodec>>,
70+
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
4271
},
4372
#[state_machine_future(transitions(ReadingSasl))]
4473
SendingSasl {
45-
future: sink::Send<Framed<Socket, PostgresCodec>>,
74+
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
4675
scram: ScramSha256,
4776
},
4877
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
4978
ReadingSasl {
50-
stream: Framed<Socket, PostgresCodec>,
79+
stream: Framed<Box<TlsStream>, PostgresCodec>,
5180
scram: ScramSha256,
5281
},
5382
#[state_machine_future(transitions(ReadingInfo))]
5483
ReadingAuthCompletion {
55-
stream: Framed<Socket, PostgresCodec>,
84+
stream: Framed<Box<TlsStream>, PostgresCodec>,
5685
},
5786
#[state_machine_future(transitions(Finished))]
5887
ReadingInfo {
59-
stream: Framed<Socket, PostgresCodec>,
88+
stream: Framed<Box<TlsStream>, PostgresCodec>,
6089
cancel_data: Option<CancelData>,
6190
parameters: HashMap<String, String>,
6291
},
@@ -71,6 +100,84 @@ impl PollHandshake for Handshake {
71100
let stream = try_ready!(state.future.poll());
72101
let state = state.take();
73102

103+
let (connector, required) = match state.tls {
104+
TlsMode::None => {
105+
transition!(BuildingStartup {
106+
stream: Framed::new(Box::new(stream), PostgresCodec),
107+
params: state.params,
108+
});
109+
}
110+
TlsMode::Prefer(connector) => (connector, false),
111+
TlsMode::Require(connector) => (connector, true),
112+
};
113+
114+
let mut buf = vec![];
115+
frontend::ssl_request(&mut buf);
116+
transition!(SendingSsl {
117+
future: write_all(stream, buf),
118+
params: state.params,
119+
connector,
120+
required,
121+
})
122+
}
123+
124+
fn poll_sending_ssl<'a>(
125+
state: &'a mut RentToOwn<'a, SendingSsl>,
126+
) -> Poll<AfterSendingSsl, Error> {
127+
let (stream, _) = try_ready!(state.future.poll());
128+
let state = state.take();
129+
transition!(ReadingSsl {
130+
future: read_exact(stream, [0]),
131+
params: state.params,
132+
connector: state.connector,
133+
required: state.required,
134+
})
135+
}
136+
137+
fn poll_reading_ssl<'a>(
138+
state: &'a mut RentToOwn<'a, ReadingSsl>,
139+
) -> Poll<AfterReadingSsl, Error> {
140+
let (stream, buf) = try_ready!(state.future.poll());
141+
let state = state.take();
142+
143+
match buf[0] {
144+
b'S' => {
145+
let future = match state.params.host() {
146+
Host::Tcp(domain) => state.connector.connect(domain, tls::Socket(stream)),
147+
Host::Unix(_) => {
148+
return Err(error::tls("TLS over unix sockets not supported".into()))
149+
}
150+
};
151+
transition!(ConnectingTls {
152+
future,
153+
params: state.params,
154+
})
155+
}
156+
b'N' if !state.required => transition!(BuildingStartup {
157+
stream: Framed::new(Box::new(stream), PostgresCodec),
158+
params: state.params,
159+
}),
160+
b'N' => Err(error::tls("TLS was required but not supported".into())),
161+
_ => Err(bad_response()),
162+
}
163+
}
164+
165+
fn poll_connecting_tls<'a>(
166+
state: &'a mut RentToOwn<'a, ConnectingTls>,
167+
) -> Poll<AfterConnectingTls, Error> {
168+
let stream = try_ready!(state.future.poll().map_err(error::tls));
169+
let state = state.take();
170+
transition!(BuildingStartup {
171+
stream: Framed::new(stream, PostgresCodec),
172+
params: state.params,
173+
})
174+
}
175+
176+
fn poll_building_startup<'a>(
177+
state: &'a mut RentToOwn<'a, BuildingStartup>,
178+
) -> Poll<AfterBuildingStartup, Error> {
179+
let state = state.take();
180+
74181
let user = match state.params.user() {
75182
Some(user) => user.clone(),
76183
None => {
@@ -102,10 +209,8 @@ impl PollHandshake for Handshake {
102209
)?;
103210
}
104211

105-
let stream = Framed::new(stream, PostgresCodec);
106-
107212
transition!(SendingStartup {
108-
future: stream.send(buf),
213+
future: state.stream.send(buf),
109214
user,
110215
})
111216
}
@@ -298,8 +403,8 @@ impl PollHandshake for Handshake {
298403
}
299404

300405
impl HandshakeFuture {
301-
pub fn new(params: ConnectParams) -> HandshakeFuture {
302-
Handshake::start(Socket::connect(&params), params)
406+
pub fn new(params: ConnectParams, tls: TlsMode) -> HandshakeFuture {
407+
Handshake::start(Socket::connect(&params), params, tls)
303408
}
304409
}
305410

tokio-postgres/src/tls.rs

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
use bytes::{Buf, BufMut};
2+
use futures::{Future, Poll};
3+
use std::error::Error;
4+
use std::io::{self, Read, Write};
5+
use tokio_io::{AsyncRead, AsyncWrite};
6+
7+
use proto;
8+
9+
pub struct Socket(pub(crate) proto::Socket);
10+
11+
impl Read for Socket {
12+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
13+
self.0.read(buf)
14+
}
15+
}
16+
17+
impl AsyncRead for Socket {
18+
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
19+
self.0.prepare_uninitialized_buffer(buf)
20+
}
21+
22+
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
23+
where
24+
B: BufMut,
25+
{
26+
self.0.read_buf(buf)
27+
}
28+
}
29+
30+
impl Write for Socket {
31+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
32+
self.0.write(buf)
33+
}
34+
35+
fn flush(&mut self) -> io::Result<()> {
36+
self.0.flush()
37+
}
38+
}
39+
40+
impl AsyncWrite for Socket {
41+
fn shutdown(&mut self) -> Poll<(), io::Error> {
42+
self.0.shutdown()
43+
}
44+
45+
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
46+
where
47+
B: Buf,
48+
{
49+
self.0.write_buf(buf)
50+
}
51+
}
52+
53+
pub trait TlsConnect {
54+
fn connect(
55+
&self,
56+
domain: &str,
57+
socket: Socket,
58+
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send>;
59+
}
60+
61+
pub trait TlsStream: 'static + Sync + Send + AsyncRead + AsyncWrite {}
62+
63+
impl TlsStream for proto::Socket {}

0 commit comments

Comments
 (0)