|
1 |
| -extern crate bytes; |
2 |
| -extern crate futures; |
3 | 1 | extern crate native_tls;
|
4 | 2 | extern crate tokio_io;
|
5 | 3 | extern crate tokio_postgres;
|
6 | 4 | extern crate tokio_tls;
|
7 | 5 |
|
| 6 | +#[macro_use] |
| 7 | +extern crate futures; |
| 8 | + |
8 | 9 | #[cfg(test)]
|
9 | 10 | extern crate tokio;
|
10 | 11 |
|
11 |
| -use bytes::{Buf, BufMut}; |
12 |
| -use futures::{Future, Poll}; |
13 |
| -use std::error::Error; |
14 |
| -use std::io::{self, Read, Write}; |
| 12 | +use futures::{Async, Future, Poll}; |
15 | 13 | use tokio_io::{AsyncRead, AsyncWrite};
|
16 |
| -use tokio_postgres::tls::{Socket, TlsConnect, TlsStream}; |
| 14 | +use tokio_postgres::{ChannelBinding, TlsConnect}; |
| 15 | +use tokio_tls::{Connect, TlsStream}; |
17 | 16 |
|
18 | 17 | #[cfg(test)]
|
19 | 18 | mod test;
|
20 | 19 |
|
21 | 20 | pub struct TlsConnector {
|
22 | 21 | connector: tokio_tls::TlsConnector,
|
| 22 | + domain: String, |
23 | 23 | }
|
24 | 24 |
|
25 | 25 | impl TlsConnector {
|
26 |
| - pub fn new() -> Result<TlsConnector, native_tls::Error> { |
| 26 | + pub fn new(domain: &str) -> Result<TlsConnector, native_tls::Error> { |
27 | 27 | let connector = native_tls::TlsConnector::new()?;
|
28 |
| - Ok(TlsConnector::with_connector(connector)) |
| 28 | + Ok(TlsConnector::with_connector(connector, domain)) |
29 | 29 | }
|
30 | 30 |
|
31 |
| - pub fn with_connector(connector: native_tls::TlsConnector) -> TlsConnector { |
| 31 | + pub fn with_connector(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector { |
32 | 32 | TlsConnector {
|
33 | 33 | connector: tokio_tls::TlsConnector::from(connector),
|
| 34 | + domain: domain.to_string(), |
34 | 35 | }
|
35 | 36 | }
|
36 | 37 | }
|
37 | 38 |
|
38 |
| -impl TlsConnect for TlsConnector { |
39 |
| - fn connect( |
40 |
| - &self, |
41 |
| - domain: &str, |
42 |
| - socket: Socket, |
43 |
| - ) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send> { |
44 |
| - let f = self |
45 |
| - .connector |
46 |
| - .connect(domain, socket) |
47 |
| - .map(|s| { |
48 |
| - let s: Box<TlsStream> = Box::new(SslStream(s)); |
49 |
| - s |
50 |
| - }).map_err(|e| { |
51 |
| - let e: Box<Error + Sync + Send> = Box::new(e); |
52 |
| - e |
53 |
| - }); |
54 |
| - Box::new(f) |
55 |
| - } |
56 |
| -} |
57 |
| - |
58 |
| -struct SslStream(tokio_tls::TlsStream<Socket>); |
59 |
| - |
60 |
| -impl Read for SslStream { |
61 |
| - fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
62 |
| - self.0.read(buf) |
63 |
| - } |
64 |
| -} |
65 |
| - |
66 |
| -impl AsyncRead for SslStream { |
67 |
| - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { |
68 |
| - self.0.prepare_uninitialized_buffer(buf) |
69 |
| - } |
| 39 | +impl<S> TlsConnect<S> for TlsConnector |
| 40 | +where |
| 41 | + S: AsyncRead + AsyncWrite, |
| 42 | +{ |
| 43 | + type Stream = TlsStream<S>; |
| 44 | + type Error = native_tls::Error; |
| 45 | + type Future = TlsConnectFuture<S>; |
70 | 46 |
|
71 |
| - fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error> |
72 |
| - where |
73 |
| - B: BufMut, |
74 |
| - { |
75 |
| - self.0.read_buf(buf) |
| 47 | + fn connect(self, stream: S) -> TlsConnectFuture<S> { |
| 48 | + TlsConnectFuture(self.connector.connect(&self.domain, stream)) |
76 | 49 | }
|
77 | 50 | }
|
78 | 51 |
|
79 |
| -impl Write for SslStream { |
80 |
| - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
81 |
| - self.0.write(buf) |
82 |
| - } |
| 52 | +pub struct TlsConnectFuture<S>(Connect<S>); |
83 | 53 |
|
84 |
| - fn flush(&mut self) -> io::Result<()> { |
85 |
| - self.0.flush() |
86 |
| - } |
87 |
| -} |
| 54 | +impl<S> Future for TlsConnectFuture<S> |
| 55 | +where |
| 56 | + S: AsyncRead + AsyncWrite, |
| 57 | +{ |
| 58 | + type Item = (TlsStream<S>, ChannelBinding); |
| 59 | + type Error = native_tls::Error; |
88 | 60 |
|
89 |
| -impl AsyncWrite for SslStream { |
90 |
| - fn shutdown(&mut self) -> Poll<(), io::Error> { |
91 |
| - self.0.shutdown() |
92 |
| - } |
| 61 | + fn poll(&mut self) -> Poll<(TlsStream<S>, ChannelBinding), native_tls::Error> { |
| 62 | + let stream = try_ready!(self.0.poll()); |
| 63 | + let mut channel_binding = ChannelBinding::new(); |
93 | 64 |
|
94 |
| - fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error> |
95 |
| - where |
96 |
| - B: Buf, |
97 |
| - { |
98 |
| - self.0.write_buf(buf) |
99 |
| - } |
100 |
| -} |
| 65 | + if let Some(buf) = stream.get_ref().tls_server_end_point().unwrap_or(None) { |
| 66 | + channel_binding = channel_binding.tls_server_end_point(buf); |
| 67 | + } |
101 | 68 |
|
102 |
| -impl TlsStream for SslStream { |
103 |
| - fn tls_server_end_point(&self) -> Option<Vec<u8>> { |
104 |
| - self.0.get_ref().tls_server_end_point().unwrap_or(None) |
| 69 | + Ok(Async::Ready((stream, channel_binding))) |
105 | 70 | }
|
106 | 71 | }
|
0 commit comments