1
- use futures:: { try_ready, Future , Poll } ;
1
+ use futures:: { try_ready, Async , Future , Poll } ;
2
2
use state_machine_future:: { transition, RentToOwn , StateMachineFuture } ;
3
3
use std:: collections:: HashMap ;
4
+ use std:: vec;
4
5
5
6
use crate :: proto:: { Client , ConnectOnceFuture , Connection } ;
6
7
use crate :: { Error , MakeTlsMode , Socket } ;
@@ -20,11 +21,16 @@ where
20
21
future : T :: Future ,
21
22
host : String ,
22
23
port : u16 ,
24
+ addrs : vec:: IntoIter < ( String , u16 ) > ,
25
+ make_tls_mode : T ,
23
26
params : HashMap < String , String > ,
24
27
} ,
25
- #[ state_machine_future( transitions( Finished ) ) ]
28
+ #[ state_machine_future( transitions( MakingTlsMode , Finished ) ) ]
26
29
Connecting {
27
30
future : ConnectOnceFuture < T :: TlsMode > ,
31
+ addrs : vec:: IntoIter < ( String , u16 ) > ,
32
+ make_tls_mode : T ,
33
+ params : HashMap < String , String > ,
28
34
} ,
29
35
#[ state_machine_future( ready) ]
30
36
Finished ( ( Client , Connection < T :: Stream > ) ) ,
@@ -43,16 +49,42 @@ where
43
49
Some ( host) => host,
44
50
None => return Err ( Error :: missing_host ( ) ) ,
45
51
} ;
52
+ let mut addrs = host
53
+ . split ( ',' )
54
+ . map ( |s| ( s. to_string ( ) , 0u16 ) )
55
+ . collect :: < Vec < _ > > ( ) ;
46
56
47
- let port = match state. params . remove ( "port" ) {
48
- Some ( port) => port. parse :: < u16 > ( ) . map_err ( Error :: invalid_port) ?,
49
- None => 5432 ,
50
- } ;
57
+ let port = state. params . remove ( "port" ) . unwrap_or_else ( String :: new) ;
58
+ let mut ports = port
59
+ . split ( ',' )
60
+ . map ( |s| {
61
+ if s. is_empty ( ) {
62
+ Ok ( 5432 )
63
+ } else {
64
+ s. parse :: < u16 > ( ) . map_err ( Error :: invalid_port)
65
+ }
66
+ } )
67
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
68
+ if ports. len ( ) == 1 {
69
+ ports. resize ( addrs. len ( ) , ports[ 0 ] ) ;
70
+ }
71
+ if addrs. len ( ) != ports. len ( ) {
72
+ return Err ( Error :: invalid_port_count ( ) ) ;
73
+ }
74
+
75
+ for ( addr, port) in addrs. iter_mut ( ) . zip ( ports) {
76
+ addr. 1 = port;
77
+ }
78
+
79
+ let mut addrs = addrs. into_iter ( ) ;
80
+ let ( host, port) = addrs. next ( ) . expect ( "addrs cannot be empty" ) ;
51
81
52
82
transition ! ( MakingTlsMode {
53
83
future: state. make_tls_mode. make_tls_mode( & host) ,
54
84
host,
55
85
port,
86
+ addrs,
87
+ make_tls_mode: state. make_tls_mode,
56
88
params: state. params,
57
89
} )
58
90
}
@@ -64,15 +96,36 @@ where
64
96
let state = state. take ( ) ;
65
97
66
98
transition ! ( Connecting {
67
- future: ConnectOnceFuture :: new( state. host, state. port, tls_mode, state. params) ,
99
+ future: ConnectOnceFuture :: new( state. host, state. port, tls_mode, state. params. clone( ) ) ,
100
+ addrs: state. addrs,
101
+ make_tls_mode: state. make_tls_mode,
102
+ params: state. params,
68
103
} )
69
104
}
70
105
71
106
fn poll_connecting < ' a > (
72
107
state : & ' a mut RentToOwn < ' a , Connecting < T > > ,
73
108
) -> Poll < AfterConnecting < T > , Error > {
74
- let r = try_ready ! ( state. future. poll( ) ) ;
75
- transition ! ( Finished ( r) )
109
+ match state. future . poll ( ) {
110
+ Ok ( Async :: Ready ( r) ) => transition ! ( Finished ( r) ) ,
111
+ Ok ( Async :: NotReady ) => Ok ( Async :: NotReady ) ,
112
+ Err ( e) => {
113
+ let mut state = state. take ( ) ;
114
+ let ( host, port) = match state. addrs . next ( ) {
115
+ Some ( addr) => addr,
116
+ None => return Err ( e) ,
117
+ } ;
118
+
119
+ transition ! ( MakingTlsMode {
120
+ future: state. make_tls_mode. make_tls_mode( & host) ,
121
+ host,
122
+ port,
123
+ addrs: state. addrs,
124
+ make_tls_mode: state. make_tls_mode,
125
+ params: state. params,
126
+ } )
127
+ }
128
+ }
76
129
}
77
130
}
78
131
0 commit comments