@@ -8,55 +8,84 @@ use postgres_protocol::message::backend::Message;
8
8
use postgres_protocol:: message:: frontend;
9
9
use state_machine_future:: RentToOwn ;
10
10
use std:: collections:: HashMap ;
11
+ use std:: error:: Error as StdError ;
11
12
use std:: io;
12
13
use tokio_codec:: Framed ;
14
+ use tokio_io:: io:: { read_exact, write_all, ReadExact , WriteAll } ;
13
15
14
16
use error:: { self , Error } ;
15
- use params:: { ConnectParams , User } ;
17
+ use params:: { ConnectParams , Host , User } ;
16
18
use proto:: client:: Client ;
17
19
use proto:: codec:: PostgresCodec ;
18
20
use proto:: connection:: Connection ;
19
21
use proto:: socket:: { ConnectFuture , Socket } ;
20
- use { bad_response, disconnected, CancelData } ;
22
+ use tls:: { self , TlsConnect , TlsStream } ;
23
+ use { bad_response, disconnected, CancelData , TlsMode } ;
21
24
22
25
#[ derive( StateMachineFuture ) ]
23
26
pub enum Handshake {
24
- #[ state_machine_future( start, transitions( SendingStartup ) ) ]
27
+ #[ state_machine_future( start, transitions( BuildingStartup , SendingSsl ) ) ]
25
28
Start {
26
29
future : ConnectFuture ,
27
30
params : ConnectParams ,
31
+ tls : TlsMode ,
32
+ } ,
33
+ #[ state_machine_future( transitions( ReadingSsl ) ) ]
34
+ SendingSsl {
35
+ future : WriteAll < Socket , Vec < u8 > > ,
36
+ params : ConnectParams ,
37
+ connector : Box < TlsConnect > ,
38
+ required : bool ,
39
+ } ,
40
+ #[ state_machine_future( transitions( ConnectingTls , BuildingStartup ) ) ]
41
+ ReadingSsl {
42
+ future : ReadExact < Socket , [ u8 ; 1 ] > ,
43
+ params : ConnectParams ,
44
+ connector : Box < TlsConnect > ,
45
+ required : bool ,
46
+ } ,
47
+ #[ state_machine_future( transitions( BuildingStartup ) ) ]
48
+ ConnectingTls {
49
+ future :
50
+ Box < Future < Item = Box < TlsStream > , Error = Box < StdError + Sync + Send > > + Sync + Send > ,
51
+ params : ConnectParams ,
52
+ } ,
53
+ #[ state_machine_future( transitions( SendingStartup ) ) ]
54
+ BuildingStartup {
55
+ stream : Framed < Box < TlsStream > , PostgresCodec > ,
56
+ params : ConnectParams ,
28
57
} ,
29
58
#[ state_machine_future( transitions( ReadingAuth ) ) ]
30
59
SendingStartup {
31
- future : sink:: Send < Framed < Socket , PostgresCodec > > ,
60
+ future : sink:: Send < Framed < Box < TlsStream > , PostgresCodec > > ,
32
61
user : User ,
33
62
} ,
34
63
#[ state_machine_future( transitions( ReadingInfo , SendingPassword , SendingSasl ) ) ]
35
64
ReadingAuth {
36
- stream : Framed < Socket , PostgresCodec > ,
65
+ stream : Framed < Box < TlsStream > , PostgresCodec > ,
37
66
user : User ,
38
67
} ,
39
68
#[ state_machine_future( transitions( ReadingAuthCompletion ) ) ]
40
69
SendingPassword {
41
- future : sink:: Send < Framed < Socket , PostgresCodec > > ,
70
+ future : sink:: Send < Framed < Box < TlsStream > , PostgresCodec > > ,
42
71
} ,
43
72
#[ state_machine_future( transitions( ReadingSasl ) ) ]
44
73
SendingSasl {
45
- future : sink:: Send < Framed < Socket , PostgresCodec > > ,
74
+ future : sink:: Send < Framed < Box < TlsStream > , PostgresCodec > > ,
46
75
scram : ScramSha256 ,
47
76
} ,
48
77
#[ state_machine_future( transitions( SendingSasl , ReadingAuthCompletion ) ) ]
49
78
ReadingSasl {
50
- stream : Framed < Socket , PostgresCodec > ,
79
+ stream : Framed < Box < TlsStream > , PostgresCodec > ,
51
80
scram : ScramSha256 ,
52
81
} ,
53
82
#[ state_machine_future( transitions( ReadingInfo ) ) ]
54
83
ReadingAuthCompletion {
55
- stream : Framed < Socket , PostgresCodec > ,
84
+ stream : Framed < Box < TlsStream > , PostgresCodec > ,
56
85
} ,
57
86
#[ state_machine_future( transitions( Finished ) ) ]
58
87
ReadingInfo {
59
- stream : Framed < Socket , PostgresCodec > ,
88
+ stream : Framed < Box < TlsStream > , PostgresCodec > ,
60
89
cancel_data : Option < CancelData > ,
61
90
parameters : HashMap < String , String > ,
62
91
} ,
@@ -71,6 +100,84 @@ impl PollHandshake for Handshake {
71
100
let stream = try_ready ! ( state. future. poll( ) ) ;
72
101
let state = state. take ( ) ;
73
102
103
+ let ( connector, required) = match state. tls {
104
+ TlsMode :: None => {
105
+ transition ! ( BuildingStartup {
106
+ stream: Framed :: new( Box :: new( stream) , PostgresCodec ) ,
107
+ params: state. params,
108
+ } ) ;
109
+ }
110
+ TlsMode :: Prefer ( connector) => ( connector, false ) ,
111
+ TlsMode :: Require ( connector) => ( connector, true ) ,
112
+ } ;
113
+
114
+ let mut buf = vec ! [ ] ;
115
+ frontend:: ssl_request ( & mut buf) ;
116
+ transition ! ( SendingSsl {
117
+ future: write_all( stream, buf) ,
118
+ params: state. params,
119
+ connector,
120
+ required,
121
+ } )
122
+ }
123
+
124
+ fn poll_sending_ssl < ' a > (
125
+ state : & ' a mut RentToOwn < ' a , SendingSsl > ,
126
+ ) -> Poll < AfterSendingSsl , Error > {
127
+ let ( stream, _) = try_ready ! ( state. future. poll( ) ) ;
128
+ let state = state. take ( ) ;
129
+ transition ! ( ReadingSsl {
130
+ future: read_exact( stream, [ 0 ] ) ,
131
+ params: state. params,
132
+ connector: state. connector,
133
+ required: state. required,
134
+ } )
135
+ }
136
+
137
+ fn poll_reading_ssl < ' a > (
138
+ state : & ' a mut RentToOwn < ' a , ReadingSsl > ,
139
+ ) -> Poll < AfterReadingSsl , Error > {
140
+ let ( stream, buf) = try_ready ! ( state. future. poll( ) ) ;
141
+ let state = state. take ( ) ;
142
+
143
+ match buf[ 0 ] {
144
+ b'S' => {
145
+ let future = match state. params . host ( ) {
146
+ Host :: Tcp ( domain) => state. connector . connect ( domain, tls:: Socket ( stream) ) ,
147
+ Host :: Unix ( _) => {
148
+ return Err ( error:: tls ( "TLS over unix sockets not supported" . into ( ) ) )
149
+ }
150
+ } ;
151
+ transition ! ( ConnectingTls {
152
+ future,
153
+ params: state. params,
154
+ } )
155
+ }
156
+ b'N' if !state. required => transition ! ( BuildingStartup {
157
+ stream: Framed :: new( Box :: new( stream) , PostgresCodec ) ,
158
+ params: state. params,
159
+ } ) ,
160
+ b'N' => Err ( error:: tls ( "TLS was required but not supported" . into ( ) ) ) ,
161
+ _ => Err ( bad_response ( ) ) ,
162
+ }
163
+ }
164
+
165
+ fn poll_connecting_tls < ' a > (
166
+ state : & ' a mut RentToOwn < ' a , ConnectingTls > ,
167
+ ) -> Poll < AfterConnectingTls , Error > {
168
+ let stream = try_ready ! ( state. future. poll( ) . map_err( error:: tls) ) ;
169
+ let state = state. take ( ) ;
170
+ transition ! ( BuildingStartup {
171
+ stream: Framed :: new( stream, PostgresCodec ) ,
172
+ params: state. params,
173
+ } )
174
+ }
175
+
176
+ fn poll_building_startup < ' a > (
177
+ state : & ' a mut RentToOwn < ' a , BuildingStartup > ,
178
+ ) -> Poll < AfterBuildingStartup , Error > {
179
+ let state = state. take ( ) ;
180
+
74
181
let user = match state. params . user ( ) {
75
182
Some ( user) => user. clone ( ) ,
76
183
None => {
@@ -102,10 +209,8 @@ impl PollHandshake for Handshake {
102
209
) ?;
103
210
}
104
211
105
- let stream = Framed :: new ( stream, PostgresCodec ) ;
106
-
107
212
transition ! ( SendingStartup {
108
- future: stream. send( buf) ,
213
+ future: state . stream. send( buf) ,
109
214
user,
110
215
} )
111
216
}
@@ -298,8 +403,8 @@ impl PollHandshake for Handshake {
298
403
}
299
404
300
405
impl HandshakeFuture {
301
- pub fn new ( params : ConnectParams ) -> HandshakeFuture {
302
- Handshake :: start ( Socket :: connect ( & params) , params)
406
+ pub fn new ( params : ConnectParams , tls : TlsMode ) -> HandshakeFuture {
407
+ Handshake :: start ( Socket :: connect ( & params) , params, tls )
303
408
}
304
409
}
305
410
0 commit comments