6
6
//! use openssl::ssl::{SslConnector, SslMethod};
7
7
//! use postgres_openssl::MakeTlsConnector;
8
8
//!
9
- //! # fn main() -> Result<(), Box<std::error::Error>> {
9
+ //! # fn main() -> Result<(), Box<dyn std::error::Error>> {
10
10
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
11
11
//! builder.set_ca_file("database_cert.pem")?;
12
12
//! let connector = MakeTlsConnector::new(builder.build());
25
25
//! use openssl::ssl::{SslConnector, SslMethod};
26
26
//! use postgres_openssl::MakeTlsConnector;
27
27
//!
28
- //! # fn main() -> Result<(), Box<std::error::Error>> {
28
+ //! # fn main() -> Result<(), Box<dyn std::error::Error>> {
29
29
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
30
30
//! builder.set_ca_file("database_cert.pem")?;
31
31
//! let connector = MakeTlsConnector::new(builder.build());
42
42
#![ doc( html_root_url = "https://docs.rs/postgres-openssl/0.3" ) ]
43
43
#![ warn( rust_2018_idioms, clippy:: all, missing_docs) ]
44
44
45
+ use futures:: task:: Context ;
46
+ use futures:: Poll ;
45
47
#[ cfg( feature = "runtime" ) ]
46
48
use openssl:: error:: ErrorStack ;
47
49
use openssl:: hash:: MessageDigest ;
@@ -51,11 +53,13 @@ use openssl::ssl::SslConnector;
51
53
use openssl:: ssl:: { ConnectConfiguration , SslRef } ;
52
54
use std:: fmt:: Debug ;
53
55
use std:: future:: Future ;
56
+ use std:: io;
54
57
use std:: pin:: Pin ;
55
58
#[ cfg( feature = "runtime" ) ]
56
59
use std:: sync:: Arc ;
57
- use tokio_io:: { AsyncRead , AsyncWrite } ;
60
+ use tokio_io:: { AsyncRead , AsyncWrite , Buf , BufMut } ;
58
61
use tokio_openssl:: { HandshakeError , SslStream } ;
62
+ use tokio_postgres:: tls;
59
63
#[ cfg( feature = "runtime" ) ]
60
64
use tokio_postgres:: tls:: MakeTlsConnect ;
61
65
use tokio_postgres:: tls:: { ChannelBinding , TlsConnect } ;
@@ -99,7 +103,7 @@ impl<S> MakeTlsConnect<S> for MakeTlsConnector
99
103
where
100
104
S : AsyncRead + AsyncWrite + Unpin + Debug + ' static + Sync + Send ,
101
105
{
102
- type Stream = SslStream < S > ;
106
+ type Stream = TlsStream < S > ;
103
107
type TlsConnect = TlsConnector ;
104
108
type Error = ErrorStack ;
105
109
@@ -130,29 +134,96 @@ impl<S> TlsConnect<S> for TlsConnector
130
134
where
131
135
S : AsyncRead + AsyncWrite + Unpin + Debug + ' static + Sync + Send ,
132
136
{
133
- type Stream = SslStream < S > ;
137
+ type Stream = TlsStream < S > ;
134
138
type Error = HandshakeError < S > ;
135
139
#[ allow( clippy:: type_complexity) ]
136
- type Future = Pin <
137
- Box < dyn Future < Output = Result < ( SslStream < S > , ChannelBinding ) , HandshakeError < S > > > + Send > ,
138
- > ;
140
+ type Future = Pin < Box < dyn Future < Output = Result < TlsStream < S > , HandshakeError < S > > > + Send > > ;
139
141
140
142
fn connect ( self , stream : S ) -> Self :: Future {
141
143
let future = async move {
142
144
let stream = tokio_openssl:: connect ( self . ssl , & self . domain , stream) . await ?;
143
-
144
- let channel_binding = match tls_server_end_point ( stream. ssl ( ) ) {
145
- Some ( buf) => ChannelBinding :: tls_server_end_point ( buf) ,
146
- None => ChannelBinding :: none ( ) ,
147
- } ;
148
-
149
- Ok ( ( stream, channel_binding) )
145
+ Ok ( TlsStream ( stream) )
150
146
} ;
151
147
152
148
Box :: pin ( future)
153
149
}
154
150
}
155
151
152
+ /// The stream returned by `TlsConnector`.
153
+ pub struct TlsStream < S > ( SslStream < S > ) ;
154
+
155
+ impl < S > AsyncRead for TlsStream < S >
156
+ where
157
+ S : AsyncRead + AsyncWrite + Unpin ,
158
+ {
159
+ unsafe fn prepare_uninitialized_buffer ( & self , buf : & mut [ u8 ] ) -> bool {
160
+ self . 0 . prepare_uninitialized_buffer ( buf)
161
+ }
162
+
163
+ fn poll_read (
164
+ mut self : Pin < & mut Self > ,
165
+ cx : & mut Context < ' _ > ,
166
+ buf : & mut [ u8 ] ,
167
+ ) -> Poll < io:: Result < usize > > {
168
+ Pin :: new ( & mut self . 0 ) . poll_read ( cx, buf)
169
+ }
170
+
171
+ fn poll_read_buf < B : BufMut > (
172
+ mut self : Pin < & mut Self > ,
173
+ cx : & mut Context < ' _ > ,
174
+ buf : & mut B ,
175
+ ) -> Poll < io:: Result < usize > >
176
+ where
177
+ Self : Sized ,
178
+ {
179
+ Pin :: new ( & mut self . 0 ) . poll_read_buf ( cx, buf)
180
+ }
181
+ }
182
+
183
+ impl < S > AsyncWrite for TlsStream < S >
184
+ where
185
+ S : AsyncRead + AsyncWrite + Unpin ,
186
+ {
187
+ fn poll_write (
188
+ mut self : Pin < & mut Self > ,
189
+ cx : & mut Context < ' _ > ,
190
+ buf : & [ u8 ] ,
191
+ ) -> Poll < io:: Result < usize > > {
192
+ Pin :: new ( & mut self . 0 ) . poll_write ( cx, buf)
193
+ }
194
+
195
+ fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
196
+ Pin :: new ( & mut self . 0 ) . poll_flush ( cx)
197
+ }
198
+
199
+ fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
200
+ Pin :: new ( & mut self . 0 ) . poll_shutdown ( cx)
201
+ }
202
+
203
+ fn poll_write_buf < B : Buf > (
204
+ mut self : Pin < & mut Self > ,
205
+ cx : & mut Context < ' _ > ,
206
+ buf : & mut B ,
207
+ ) -> Poll < io:: Result < usize > >
208
+ where
209
+ Self : Sized ,
210
+ {
211
+ Pin :: new ( & mut self . 0 ) . poll_write_buf ( cx, buf)
212
+ }
213
+ }
214
+
215
+ impl < S > tls:: TlsStream for TlsStream < S >
216
+ where
217
+ S : AsyncRead + AsyncWrite + Unpin ,
218
+ {
219
+ fn channel_binding ( & self ) -> ChannelBinding {
220
+ match tls_server_end_point ( self . 0 . ssl ( ) ) {
221
+ Some ( buf) => ChannelBinding :: tls_server_end_point ( buf) ,
222
+ None => ChannelBinding :: none ( ) ,
223
+ }
224
+ }
225
+ }
226
+
156
227
fn tls_server_end_point ( ssl : & SslRef ) -> Option < Vec < u8 > > {
157
228
let cert = ssl. peer_certificate ( ) ?;
158
229
let algo_nid = cert. signature_algorithm ( ) . object ( ) . nid ( ) ;
0 commit comments