Skip to content

Commit dc9d07e

Browse files
committed
Return a custom TlsStream rather than a ChannelBinding up front
1 parent 6c77baa commit dc9d07e

File tree

9 files changed

+220
-53
lines changed

9 files changed

+220
-53
lines changed

codegen/src/sqlstate.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -92,5 +92,6 @@ fn make_map(codes: &LinkedHashMap<String, Vec<String>>, file: &mut BufWriter<Fil
9292
#[rustfmt::skip]
9393
static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = \n{};\n",
9494
builder.build()
95-
).unwrap();
95+
)
96+
.unwrap();
9697
}

postgres-native-tls/src/lib.rs

+82-11
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//! use postgres_native_tls::MakeTlsConnector;
88
//! use std::fs;
99
//!
10-
//! # fn main() -> Result<(), Box<std::error::Error>> {
10+
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
1111
//! let cert = fs::read("database_cert.pem")?;
1212
//! let cert = Certificate::from_pem(&cert)?;
1313
//! let connector = TlsConnector::builder()
@@ -30,7 +30,7 @@
3030
//! use postgres_native_tls::MakeTlsConnector;
3131
//! use std::fs;
3232
//!
33-
//! # fn main() -> Result<(), Box<std::error::Error>> {
33+
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
3434
//! let cert = fs::read("database_cert.pem")?;
3535
//! let cert = Certificate::from_pem(&cert)?;
3636
//! let connector = TlsConnector::builder()
@@ -48,13 +48,16 @@
4848
#![doc(html_root_url = "https://docs.rs/postgres-native-tls/0.3")]
4949
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
5050

51+
use futures::task::Context;
52+
use futures::Poll;
5153
use std::future::Future;
54+
use std::io;
5255
use std::pin::Pin;
53-
use tokio_io::{AsyncRead, AsyncWrite};
56+
use tokio_io::{AsyncRead, AsyncWrite, Buf, BufMut};
57+
use tokio_postgres::tls;
5458
#[cfg(feature = "runtime")]
5559
use tokio_postgres::tls::MakeTlsConnect;
5660
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
57-
use tokio_tls::TlsStream;
5861

5962
#[cfg(test)]
6063
mod test;
@@ -111,20 +114,88 @@ where
111114
type Stream = TlsStream<S>;
112115
type Error = native_tls::Error;
113116
#[allow(clippy::type_complexity)]
114-
type Future = Pin<
115-
Box<dyn Future<Output = Result<(TlsStream<S>, ChannelBinding), native_tls::Error>> + Send>,
116-
>;
117+
type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
117118

118119
fn connect(self, stream: S) -> Self::Future {
119120
let future = async move {
120121
let stream = self.connector.connect(&self.domain, stream).await?;
121122

122-
// FIXME https://github.com/tokio-rs/tokio/issues/1383
123-
let channel_binding = ChannelBinding::none();
124-
125-
Ok((stream, channel_binding))
123+
Ok(TlsStream(stream))
126124
};
127125

128126
Box::pin(future)
129127
}
130128
}
129+
130+
/// The stream returned by `TlsConnector`.
131+
pub struct TlsStream<S>(tokio_tls::TlsStream<S>);
132+
133+
impl<S> AsyncRead for TlsStream<S>
134+
where
135+
S: AsyncRead + AsyncWrite + Unpin,
136+
{
137+
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
138+
self.0.prepare_uninitialized_buffer(buf)
139+
}
140+
141+
fn poll_read(
142+
mut self: Pin<&mut Self>,
143+
cx: &mut Context<'_>,
144+
buf: &mut [u8],
145+
) -> Poll<io::Result<usize>> {
146+
Pin::new(&mut self.0).poll_read(cx, buf)
147+
}
148+
149+
fn poll_read_buf<B: BufMut>(
150+
mut self: Pin<&mut Self>,
151+
cx: &mut Context<'_>,
152+
buf: &mut B,
153+
) -> Poll<io::Result<usize>>
154+
where
155+
Self: Sized,
156+
{
157+
Pin::new(&mut self.0).poll_read_buf(cx, buf)
158+
}
159+
}
160+
161+
impl<S> AsyncWrite for TlsStream<S>
162+
where
163+
S: AsyncRead + AsyncWrite + Unpin,
164+
{
165+
fn poll_write(
166+
mut self: Pin<&mut Self>,
167+
cx: &mut Context<'_>,
168+
buf: &[u8],
169+
) -> Poll<io::Result<usize>> {
170+
Pin::new(&mut self.0).poll_write(cx, buf)
171+
}
172+
173+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
174+
Pin::new(&mut self.0).poll_flush(cx)
175+
}
176+
177+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178+
Pin::new(&mut self.0).poll_shutdown(cx)
179+
}
180+
181+
fn poll_write_buf<B: Buf>(
182+
mut self: Pin<&mut Self>,
183+
cx: &mut Context<'_>,
184+
buf: &mut B,
185+
) -> Poll<io::Result<usize>>
186+
where
187+
Self: Sized,
188+
{
189+
Pin::new(&mut self.0).poll_write_buf(cx, buf)
190+
}
191+
}
192+
193+
impl<S> tls::TlsStream for TlsStream<S>
194+
where
195+
S: AsyncRead + AsyncWrite + Unpin,
196+
{
197+
fn channel_binding(&self) -> ChannelBinding {
198+
// FIXME https://github.com/tokio-rs/tokio/issues/1383
199+
ChannelBinding::none()
200+
}
201+
}

postgres-openssl/src/lib.rs

+86-15
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//! use openssl::ssl::{SslConnector, SslMethod};
77
//! use postgres_openssl::MakeTlsConnector;
88
//!
9-
//! # fn main() -> Result<(), Box<std::error::Error>> {
9+
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
1010
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
1111
//! builder.set_ca_file("database_cert.pem")?;
1212
//! let connector = MakeTlsConnector::new(builder.build());
@@ -25,7 +25,7 @@
2525
//! use openssl::ssl::{SslConnector, SslMethod};
2626
//! use postgres_openssl::MakeTlsConnector;
2727
//!
28-
//! # fn main() -> Result<(), Box<std::error::Error>> {
28+
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
2929
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
3030
//! builder.set_ca_file("database_cert.pem")?;
3131
//! let connector = MakeTlsConnector::new(builder.build());
@@ -42,6 +42,8 @@
4242
#![doc(html_root_url = "https://docs.rs/postgres-openssl/0.3")]
4343
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
4444

45+
use futures::task::Context;
46+
use futures::Poll;
4547
#[cfg(feature = "runtime")]
4648
use openssl::error::ErrorStack;
4749
use openssl::hash::MessageDigest;
@@ -51,11 +53,13 @@ use openssl::ssl::SslConnector;
5153
use openssl::ssl::{ConnectConfiguration, SslRef};
5254
use std::fmt::Debug;
5355
use std::future::Future;
56+
use std::io;
5457
use std::pin::Pin;
5558
#[cfg(feature = "runtime")]
5659
use std::sync::Arc;
57-
use tokio_io::{AsyncRead, AsyncWrite};
60+
use tokio_io::{AsyncRead, AsyncWrite, Buf, BufMut};
5861
use tokio_openssl::{HandshakeError, SslStream};
62+
use tokio_postgres::tls;
5963
#[cfg(feature = "runtime")]
6064
use tokio_postgres::tls::MakeTlsConnect;
6165
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
@@ -99,7 +103,7 @@ impl<S> MakeTlsConnect<S> for MakeTlsConnector
99103
where
100104
S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
101105
{
102-
type Stream = SslStream<S>;
106+
type Stream = TlsStream<S>;
103107
type TlsConnect = TlsConnector;
104108
type Error = ErrorStack;
105109

@@ -130,29 +134,96 @@ impl<S> TlsConnect<S> for TlsConnector
130134
where
131135
S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
132136
{
133-
type Stream = SslStream<S>;
137+
type Stream = TlsStream<S>;
134138
type Error = HandshakeError<S>;
135139
#[allow(clippy::type_complexity)]
136-
type Future = Pin<
137-
Box<dyn Future<Output = Result<(SslStream<S>, ChannelBinding), HandshakeError<S>>> + Send>,
138-
>;
140+
type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, HandshakeError<S>>> + Send>>;
139141

140142
fn connect(self, stream: S) -> Self::Future {
141143
let future = async move {
142144
let stream = tokio_openssl::connect(self.ssl, &self.domain, stream).await?;
143-
144-
let channel_binding = match tls_server_end_point(stream.ssl()) {
145-
Some(buf) => ChannelBinding::tls_server_end_point(buf),
146-
None => ChannelBinding::none(),
147-
};
148-
149-
Ok((stream, channel_binding))
145+
Ok(TlsStream(stream))
150146
};
151147

152148
Box::pin(future)
153149
}
154150
}
155151

152+
/// The stream returned by `TlsConnector`.
153+
pub struct TlsStream<S>(SslStream<S>);
154+
155+
impl<S> AsyncRead for TlsStream<S>
156+
where
157+
S: AsyncRead + AsyncWrite + Unpin,
158+
{
159+
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
160+
self.0.prepare_uninitialized_buffer(buf)
161+
}
162+
163+
fn poll_read(
164+
mut self: Pin<&mut Self>,
165+
cx: &mut Context<'_>,
166+
buf: &mut [u8],
167+
) -> Poll<io::Result<usize>> {
168+
Pin::new(&mut self.0).poll_read(cx, buf)
169+
}
170+
171+
fn poll_read_buf<B: BufMut>(
172+
mut self: Pin<&mut Self>,
173+
cx: &mut Context<'_>,
174+
buf: &mut B,
175+
) -> Poll<io::Result<usize>>
176+
where
177+
Self: Sized,
178+
{
179+
Pin::new(&mut self.0).poll_read_buf(cx, buf)
180+
}
181+
}
182+
183+
impl<S> AsyncWrite for TlsStream<S>
184+
where
185+
S: AsyncRead + AsyncWrite + Unpin,
186+
{
187+
fn poll_write(
188+
mut self: Pin<&mut Self>,
189+
cx: &mut Context<'_>,
190+
buf: &[u8],
191+
) -> Poll<io::Result<usize>> {
192+
Pin::new(&mut self.0).poll_write(cx, buf)
193+
}
194+
195+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
196+
Pin::new(&mut self.0).poll_flush(cx)
197+
}
198+
199+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
200+
Pin::new(&mut self.0).poll_shutdown(cx)
201+
}
202+
203+
fn poll_write_buf<B: Buf>(
204+
mut self: Pin<&mut Self>,
205+
cx: &mut Context<'_>,
206+
buf: &mut B,
207+
) -> Poll<io::Result<usize>>
208+
where
209+
Self: Sized,
210+
{
211+
Pin::new(&mut self.0).poll_write_buf(cx, buf)
212+
}
213+
}
214+
215+
impl<S> tls::TlsStream for TlsStream<S>
216+
where
217+
S: AsyncRead + AsyncWrite + Unpin,
218+
{
219+
fn channel_binding(&self) -> ChannelBinding {
220+
match tls_server_end_point(self.0.ssl()) {
221+
Some(buf) => ChannelBinding::tls_server_end_point(buf),
222+
None => ChannelBinding::none(),
223+
}
224+
}
225+
}
226+
156227
fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
157228
let cert = ssl.peer_certificate()?;
158229
let algo_nid = cert.signature_algorithm().object().nid();

tokio-postgres/src/cancel_query_raw.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ where
1616
S: AsyncRead + AsyncWrite + Unpin,
1717
T: TlsConnect<S>,
1818
{
19-
let (mut stream, _) = connect_tls::connect_tls(stream, mode, tls).await?;
19+
let mut stream = connect_tls::connect_tls(stream, mode, tls).await?;
2020

2121
let mut buf = BytesMut::new();
2222
frontend::cancel_request(process_id, secret_key, &mut buf);

tokio-postgres/src/connect_raw.rs

+11-13
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCod
22
use crate::config::{self, Config};
33
use crate::connect_tls::connect_tls;
44
use crate::maybe_tls_stream::MaybeTlsStream;
5-
use crate::tls::{ChannelBinding, TlsConnect};
5+
use crate::tls::{TlsConnect, TlsStream};
66
use crate::{Client, Connection, Error};
77
use bytes::BytesMut;
88
use fallible_iterator::FallibleIterator;
@@ -86,15 +86,15 @@ where
8686
S: AsyncRead + AsyncWrite + Unpin,
8787
T: TlsConnect<S>,
8888
{
89-
let (stream, channel_binding) = connect_tls(stream, config.ssl_mode, tls).await?;
89+
let stream = connect_tls(stream, config.ssl_mode, tls).await?;
9090

9191
let mut stream = StartupStream {
9292
inner: Framed::new(stream, PostgresCodec),
9393
buf: BackendMessages::empty(),
9494
};
9595

9696
startup(&mut stream, config).await?;
97-
authenticate(&mut stream, channel_binding, config).await?;
97+
authenticate(&mut stream, config).await?;
9898
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
9999

100100
let (sender, receiver) = mpsc::unbounded();
@@ -132,14 +132,10 @@ where
132132
.map_err(Error::io)
133133
}
134134

135-
async fn authenticate<S, T>(
136-
stream: &mut StartupStream<S, T>,
137-
channel_binding: ChannelBinding,
138-
config: &Config,
139-
) -> Result<(), Error>
135+
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
140136
where
141137
S: AsyncRead + AsyncWrite + Unpin,
142-
T: AsyncRead + AsyncWrite + Unpin,
138+
T: TlsStream + Unpin,
143139
{
144140
match stream.try_next().await.map_err(Error::io)? {
145141
Some(Message::AuthenticationOk) => {
@@ -172,7 +168,7 @@ where
172168
authenticate_password(stream, output.as_bytes()).await?;
173169
}
174170
Some(Message::AuthenticationSasl(body)) => {
175-
authenticate_sasl(stream, body, channel_binding, config).await?;
171+
authenticate_sasl(stream, body, config).await?;
176172
}
177173
Some(Message::AuthenticationKerberosV5)
178174
| Some(Message::AuthenticationScmCredential)
@@ -225,12 +221,11 @@ where
225221
async fn authenticate_sasl<S, T>(
226222
stream: &mut StartupStream<S, T>,
227223
body: AuthenticationSaslBody,
228-
channel_binding: ChannelBinding,
229224
config: &Config,
230225
) -> Result<(), Error>
231226
where
232227
S: AsyncRead + AsyncWrite + Unpin,
233-
T: AsyncRead + AsyncWrite + Unpin,
228+
T: TlsStream + Unpin,
234229
{
235230
let password = config
236231
.password
@@ -248,7 +243,10 @@ where
248243
}
249244
}
250245

251-
let channel_binding = channel_binding
246+
let channel_binding = stream
247+
.inner
248+
.get_ref()
249+
.channel_binding()
252250
.tls_server_end_point
253251
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
254252
.map(sasl::ChannelBinding::tls_server_end_point);

0 commit comments

Comments
 (0)