1
- use crate :: client:: SocketConfig ;
2
- use crate :: config:: { Host , TargetSessionAttrs } ;
1
+ use crate :: client:: { Addr , SocketConfig } ;
2
+ use crate :: config:: { Host , LoadBalanceHosts , TargetSessionAttrs } ;
3
3
use crate :: connect_raw:: connect_raw;
4
4
use crate :: connect_socket:: connect_socket;
5
- use crate :: tls:: { MakeTlsConnect , TlsConnect } ;
5
+ use crate :: tls:: MakeTlsConnect ;
6
6
use crate :: { Client , Config , Connection , Error , SimpleQueryMessage , Socket } ;
7
7
use futures_util:: { future, pin_mut, Future , FutureExt , Stream } ;
8
+ use rand:: seq:: SliceRandom ;
8
9
use std:: task:: Poll ;
9
10
use std:: { cmp, io} ;
11
+ use tokio:: net;
10
12
11
13
pub async fn connect < T > (
12
14
mut tls : T ,
40
42
return Err ( Error :: config ( "invalid number of ports" . into ( ) ) ) ;
41
43
}
42
44
45
+ let mut indices = ( 0 ..num_hosts) . collect :: < Vec < _ > > ( ) ;
46
+ if config. load_balance_hosts == LoadBalanceHosts :: Random {
47
+ indices. shuffle ( & mut rand:: thread_rng ( ) ) ;
48
+ }
49
+
43
50
let mut error = None ;
44
- for i in 0 ..num_hosts {
51
+ for i in indices {
45
52
let host = config. host . get ( i) ;
46
53
let hostaddr = config. hostaddr . get ( i) ;
47
54
let port = config
@@ -59,25 +66,15 @@ where
59
66
Some ( Host :: Unix ( _) ) => None ,
60
67
None => None ,
61
68
} ;
62
- let tls = tls
63
- . make_tls_connect ( hostname. as_deref ( ) . unwrap_or ( "" ) )
64
- . map_err ( |e| Error :: tls ( e. into ( ) ) ) ?;
65
69
66
70
// Try to use the value of hostaddr to establish the TCP connection,
67
71
// fallback to host if hostaddr is not present.
68
72
let addr = match hostaddr {
69
73
Some ( ipaddr) => Host :: Tcp ( ipaddr. to_string ( ) ) ,
70
- None => {
71
- if let Some ( host) = host {
72
- host. clone ( )
73
- } else {
74
- // This is unreachable.
75
- return Err ( Error :: config ( "both host and hostaddr are empty" . into ( ) ) ) ;
76
- }
77
- }
74
+ None => host. cloned ( ) . unwrap ( ) ,
78
75
} ;
79
76
80
- match connect_once ( addr, hostname, port, tls, config) . await {
77
+ match connect_host ( addr, hostname, port, & mut tls, config) . await {
81
78
Ok ( ( client, connection) ) => return Ok ( ( client, connection) ) ,
82
79
Err ( e) => error = Some ( e) ,
83
80
}
@@ -86,18 +83,66 @@ where
86
83
Err ( error. unwrap ( ) )
87
84
}
88
85
89
- async fn connect_once < T > (
86
+ async fn connect_host < T > (
90
87
host : Host ,
91
88
hostname : Option < String > ,
92
89
port : u16 ,
93
- tls : T ,
90
+ tls : & mut T ,
91
+ config : & Config ,
92
+ ) -> Result < ( Client , Connection < Socket , T :: Stream > ) , Error >
93
+ where
94
+ T : MakeTlsConnect < Socket > ,
95
+ {
96
+ match host {
97
+ Host :: Tcp ( host) => {
98
+ let mut addrs = net:: lookup_host ( ( & * host, port) )
99
+ . await
100
+ . map_err ( Error :: connect) ?
101
+ . collect :: < Vec < _ > > ( ) ;
102
+
103
+ if config. load_balance_hosts == LoadBalanceHosts :: Random {
104
+ addrs. shuffle ( & mut rand:: thread_rng ( ) ) ;
105
+ }
106
+
107
+ let mut last_err = None ;
108
+ for addr in addrs {
109
+ match connect_once ( Addr :: Tcp ( addr. ip ( ) ) , hostname. as_deref ( ) , port, tls, config)
110
+ . await
111
+ {
112
+ Ok ( stream) => return Ok ( stream) ,
113
+ Err ( e) => {
114
+ last_err = Some ( e) ;
115
+ continue ;
116
+ }
117
+ } ;
118
+ }
119
+
120
+ Err ( last_err. unwrap_or_else ( || {
121
+ Error :: connect ( io:: Error :: new (
122
+ io:: ErrorKind :: InvalidInput ,
123
+ "could not resolve any addresses" ,
124
+ ) )
125
+ } ) )
126
+ }
127
+ #[ cfg( unix) ]
128
+ Host :: Unix ( path) => {
129
+ connect_once ( Addr :: Unix ( path) , hostname. as_deref ( ) , port, tls, config) . await
130
+ }
131
+ }
132
+ }
133
+
134
+ async fn connect_once < T > (
135
+ addr : Addr ,
136
+ hostname : Option < & str > ,
137
+ port : u16 ,
138
+ tls : & mut T ,
94
139
config : & Config ,
95
140
) -> Result < ( Client , Connection < Socket , T :: Stream > ) , Error >
96
141
where
97
- T : TlsConnect < Socket > ,
142
+ T : MakeTlsConnect < Socket > ,
98
143
{
99
144
let socket = connect_socket (
100
- & host ,
145
+ & addr ,
101
146
port,
102
147
config. connect_timeout ,
103
148
config. tcp_user_timeout ,
@@ -108,6 +153,10 @@ where
108
153
} ,
109
154
)
110
155
. await ?;
156
+
157
+ let tls = tls
158
+ . make_tls_connect ( hostname. unwrap_or ( "" ) )
159
+ . map_err ( |e| Error :: tls ( e. into ( ) ) ) ?;
111
160
let has_hostname = hostname. is_some ( ) ;
112
161
let ( mut client, mut connection) = connect_raw ( socket, tls, has_hostname, config) . await ?;
113
162
@@ -152,8 +201,8 @@ where
152
201
}
153
202
154
203
client. set_socket_config ( SocketConfig {
155
- host ,
156
- hostname,
204
+ addr ,
205
+ hostname : hostname . map ( |s| s . to_string ( ) ) ,
157
206
port,
158
207
connect_timeout : config. connect_timeout ,
159
208
tcp_user_timeout : config. tcp_user_timeout ,
0 commit comments