Skip to content

Commit 08b4020

Browse files
committed
Overhaul connection APIs
* `Connection` is now parameterized over the stream type, which can be any `AsyncRead + AsyncWrite`. * The `TlsMode` enum is now a trait, and `NoTls`, `PreferTls`, and `RequireTls` are types implementing that trait. * The `TlsConnect` trait no longer involves trait objects, and returns channel binding info alongside the stream type rather than requiring the stream to implement an additional trait. * The `connect` free function and `ConnectParams` type is gone in favor of a `Builder` type. It takes a pre-connected stream rather than automatically opening a TCP or Unix socket connection. Notably, we no longer have any dependency on the Tokio runtime. We do use the `tokio-codec` and `tokio-io` crates, but those don't actually depend on mio/tokio-reactor/etc. This means we can work with other futures-based networking stacks. We will almost certainly add back a convenience API that offers something akin to the old logic to open a TCP/Unix connection automatically but that will be worked out in a follow up PR.
1 parent 0e60d80 commit 08b4020

File tree

20 files changed

+954
-1061
lines changed

20 files changed

+954
-1061
lines changed

.circleci/config.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ version: 2
2222
jobs:
2323
build:
2424
docker:
25-
- image: rust:1.26.2
25+
- image: rust:1.30.1
2626
environment:
2727
RUSTFLAGS: -D warnings
2828
- image: sfackler/rust-postgres-test:4

docker/Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
FROM postgres:11-beta1
1+
FROM postgres:11
22

33
COPY sql_setup.sh /docker-entrypoint-initdb.d/

tokio-postgres-native-tls/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ version = "0.1.0"
44
authors = ["Steven Fackler <[email protected]>"]
55

66
[dependencies]
7-
bytes = "0.4"
87
futures = "0.1"
98
native-tls = "0.2"
109
tokio-io = "0.1"

tokio-postgres-native-tls/src/lib.rs

+34-69
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,71 @@
1-
extern crate bytes;
2-
extern crate futures;
31
extern crate native_tls;
42
extern crate tokio_io;
53
extern crate tokio_postgres;
64
extern crate tokio_tls;
75

6+
#[macro_use]
7+
extern crate futures;
8+
89
#[cfg(test)]
910
extern crate tokio;
1011

11-
use bytes::{Buf, BufMut};
12-
use futures::{Future, Poll};
13-
use std::error::Error;
14-
use std::io::{self, Read, Write};
12+
use futures::{Async, Future, Poll};
1513
use tokio_io::{AsyncRead, AsyncWrite};
16-
use tokio_postgres::tls::{Socket, TlsConnect, TlsStream};
14+
use tokio_postgres::{ChannelBinding, TlsConnect};
15+
use tokio_tls::{Connect, TlsStream};
1716

1817
#[cfg(test)]
1918
mod test;
2019

2120
pub struct TlsConnector {
2221
connector: tokio_tls::TlsConnector,
22+
domain: String,
2323
}
2424

2525
impl TlsConnector {
26-
pub fn new() -> Result<TlsConnector, native_tls::Error> {
26+
pub fn new(domain: &str) -> Result<TlsConnector, native_tls::Error> {
2727
let connector = native_tls::TlsConnector::new()?;
28-
Ok(TlsConnector::with_connector(connector))
28+
Ok(TlsConnector::with_connector(connector, domain))
2929
}
3030

31-
pub fn with_connector(connector: native_tls::TlsConnector) -> TlsConnector {
31+
pub fn with_connector(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
3232
TlsConnector {
3333
connector: tokio_tls::TlsConnector::from(connector),
34+
domain: domain.to_string(),
3435
}
3536
}
3637
}
3738

38-
impl TlsConnect for TlsConnector {
39-
fn connect(
40-
&self,
41-
domain: &str,
42-
socket: Socket,
43-
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send> {
44-
let f = self
45-
.connector
46-
.connect(domain, socket)
47-
.map(|s| {
48-
let s: Box<TlsStream> = Box::new(SslStream(s));
49-
s
50-
}).map_err(|e| {
51-
let e: Box<Error + Sync + Send> = Box::new(e);
52-
e
53-
});
54-
Box::new(f)
55-
}
56-
}
57-
58-
struct SslStream(tokio_tls::TlsStream<Socket>);
59-
60-
impl Read for SslStream {
61-
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
62-
self.0.read(buf)
63-
}
64-
}
65-
66-
impl AsyncRead for SslStream {
67-
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
68-
self.0.prepare_uninitialized_buffer(buf)
69-
}
39+
impl<S> TlsConnect<S> for TlsConnector
40+
where
41+
S: AsyncRead + AsyncWrite,
42+
{
43+
type Stream = TlsStream<S>;
44+
type Error = native_tls::Error;
45+
type Future = TlsConnectFuture<S>;
7046

71-
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
72-
where
73-
B: BufMut,
74-
{
75-
self.0.read_buf(buf)
47+
fn connect(self, stream: S) -> TlsConnectFuture<S> {
48+
TlsConnectFuture(self.connector.connect(&self.domain, stream))
7649
}
7750
}
7851

79-
impl Write for SslStream {
80-
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
81-
self.0.write(buf)
82-
}
52+
pub struct TlsConnectFuture<S>(Connect<S>);
8353

84-
fn flush(&mut self) -> io::Result<()> {
85-
self.0.flush()
86-
}
87-
}
54+
impl<S> Future for TlsConnectFuture<S>
55+
where
56+
S: AsyncRead + AsyncWrite,
57+
{
58+
type Item = (TlsStream<S>, ChannelBinding);
59+
type Error = native_tls::Error;
8860

89-
impl AsyncWrite for SslStream {
90-
fn shutdown(&mut self) -> Poll<(), io::Error> {
91-
self.0.shutdown()
92-
}
61+
fn poll(&mut self) -> Poll<(TlsStream<S>, ChannelBinding), native_tls::Error> {
62+
let stream = try_ready!(self.0.poll());
63+
let mut channel_binding = ChannelBinding::new();
9364

94-
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
95-
where
96-
B: Buf,
97-
{
98-
self.0.write_buf(buf)
99-
}
100-
}
65+
if let Some(buf) = stream.get_ref().tls_server_end_point().unwrap_or(None) {
66+
channel_binding = channel_binding.tls_server_end_point(buf);
67+
}
10168

102-
impl TlsStream for SslStream {
103-
fn tls_server_end_point(&self) -> Option<Vec<u8>> {
104-
self.0.get_ref().tls_server_end_point().unwrap_or(None)
69+
Ok(Async::Ready((stream, channel_binding)))
10570
}
10671
}

tokio-postgres-native-tls/src/test.rs

+24-13
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
use futures::{Future, Stream};
22
use native_tls::{self, Certificate};
3+
use tokio::net::TcpStream;
34
use tokio::runtime::current_thread::Runtime;
4-
use tokio_postgres::{self, TlsMode};
5+
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
56

67
use TlsConnector;
78

8-
fn smoke_test(url: &str, tls: TlsMode) {
9+
fn smoke_test<T>(builder: &tokio_postgres::Builder, tls: T)
10+
where
11+
T: TlsMode<TcpStream>,
12+
T::Stream: 'static,
13+
{
914
let mut runtime = Runtime::new().unwrap();
1015

11-
let handshake = tokio_postgres::connect(url.parse().unwrap(), tls);
16+
let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
17+
.map_err(|e| panic!("{}", e))
18+
.and_then(|s| builder.connect(s, tls));
1219
let (mut client, connection) = runtime.block_on(handshake).unwrap();
1320
let connection = connection.map_err(|e| panic!("{}", e));
14-
runtime.handle().spawn(connection).unwrap();
21+
runtime.spawn(connection);
1522

1623
let prepare = client.prepare("SELECT 1::INT4");
1724
let statement = runtime.block_on(prepare).unwrap();
@@ -33,10 +40,11 @@ fn require() {
3340
Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(),
3441
).build()
3542
.unwrap();
36-
let connector = TlsConnector::with_connector(connector);
3743
smoke_test(
38-
"postgres://ssl_user@localhost:5433/postgres",
39-
TlsMode::Require(Box::new(connector)),
44+
tokio_postgres::Builder::new()
45+
.user("ssl_user")
46+
.database("postgres"),
47+
RequireTls(TlsConnector::with_connector(connector, "localhost")),
4048
);
4149
}
4250

@@ -47,10 +55,11 @@ fn prefer() {
4755
Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(),
4856
).build()
4957
.unwrap();
50-
let connector = TlsConnector::with_connector(connector);
5158
smoke_test(
52-
"postgres://ssl_user@localhost:5433/postgres",
53-
TlsMode::Prefer(Box::new(connector)),
59+
tokio_postgres::Builder::new()
60+
.user("ssl_user")
61+
.database("postgres"),
62+
PreferTls(TlsConnector::with_connector(connector, "localhost")),
5463
);
5564
}
5665

@@ -61,9 +70,11 @@ fn scram_user() {
6170
Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(),
6271
).build()
6372
.unwrap();
64-
let connector = TlsConnector::with_connector(connector);
6573
smoke_test(
66-
"postgres://scram_user:password@localhost:5433/postgres",
67-
TlsMode::Require(Box::new(connector)),
74+
tokio_postgres::Builder::new()
75+
.user("scram_user")
76+
.password("password")
77+
.database("postgres"),
78+
RequireTls(TlsConnector::with_connector(connector, "localhost")),
6879
);
6980
}

tokio-postgres-openssl/Cargo.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ version = "0.1.0"
44
authors = ["Steven Fackler <[email protected]>"]
55

66
[dependencies]
7-
bytes = "0.4"
87
futures = "0.1"
98
openssl = "0.10"
109
tokio-io = "0.1"
11-
tokio-openssl = "0.2"
10+
tokio-openssl = "0.3"
1211
tokio-postgres = { version = "0.3", path = "../tokio-postgres" }
1312

1413
[dev-dependencies]

0 commit comments

Comments
 (0)