@@ -7,8 +7,10 @@ use std::io;
7
7
use std:: net:: { SocketAddr , ToSocketAddrs } ;
8
8
#[ cfg( unix) ]
9
9
use std:: path:: Path ;
10
+ use std:: time:: { Duration , Instant } ;
10
11
use std:: vec;
11
12
use tokio_tcp:: TcpStream ;
13
+ use tokio_timer:: Delay ;
12
14
#[ cfg( unix) ]
13
15
use tokio_uds:: UnixStream ;
14
16
@@ -40,19 +42,25 @@ where
40
42
#[ state_machine_future( transitions( Handshaking ) ) ]
41
43
ConnectingUnix {
42
44
future : tokio_uds:: ConnectFuture ,
45
+ connect_timeout : Option < Duration > ,
46
+ timeout : Option < Delay > ,
43
47
tls_mode : T ,
44
48
params : HashMap < String , String > ,
45
49
} ,
46
50
#[ state_machine_future( transitions( ConnectingTcp ) ) ]
47
51
ResolvingDns {
48
52
future : CpuFuture < vec:: IntoIter < SocketAddr > , io:: Error > ,
53
+ connect_timeout : Option < Duration > ,
54
+ timeout : Option < Delay > ,
49
55
tls_mode : T ,
50
56
params : HashMap < String , String > ,
51
57
} ,
52
58
#[ state_machine_future( transitions( Handshaking ) ) ]
53
59
ConnectingTcp {
54
60
future : tokio_tcp:: ConnectFuture ,
55
61
addrs : vec:: IntoIter < SocketAddr > ,
62
+ connect_timeout : Option < Duration > ,
63
+ timeout : Option < Delay > ,
56
64
tls_mode : T ,
57
65
params : HashMap < String , String > ,
58
66
} ,
@@ -69,14 +77,29 @@ where
69
77
T : TlsMode < Socket > ,
70
78
{
71
79
fn poll_start < ' a > ( state : & ' a mut RentToOwn < ' a , Start < T > > ) -> Poll < AfterStart < T > , Error > {
72
- let state = state. take ( ) ;
80
+ let mut state = state. take ( ) ;
81
+
82
+ let connect_timeout = match state. params . remove ( "connect_timeout" ) {
83
+ Some ( s) => {
84
+ let seconds = s. parse :: < i64 > ( ) . map_err ( Error :: invalid_connect_timeout) ?;
85
+ if seconds <= 0 {
86
+ None
87
+ } else {
88
+ Some ( Duration :: from_secs ( seconds as u64 ) )
89
+ }
90
+ }
91
+ None => None ,
92
+ } ;
93
+ let timeout = connect_timeout. map ( |d| Delay :: new ( Instant :: now ( ) + d) ) ;
73
94
74
95
#[ cfg( unix) ]
75
96
{
76
97
if state. host . starts_with ( '/' ) {
77
98
let path = Path :: new ( & state. host ) . join ( format ! ( ".s.PGSQL.{}" , state. port) ) ;
78
99
transition ! ( ConnectingUnix {
79
100
future: UnixStream :: connect( path) ,
101
+ connect_timeout,
102
+ timeout,
80
103
tls_mode: state. tls_mode,
81
104
params: state. params,
82
105
} )
87
110
let port = state. port ;
88
111
transition ! ( ResolvingDns {
89
112
future: DNS_POOL . spawn_fn( move || ( & * host, port) . to_socket_addrs( ) ) ,
113
+ connect_timeout,
114
+ timeout,
90
115
tls_mode: state. tls_mode,
91
116
params: state. params,
92
117
} )
@@ -96,6 +121,14 @@ where
96
121
fn poll_connecting_unix < ' a > (
97
122
state : & ' a mut RentToOwn < ' a , ConnectingUnix < T > > ,
98
123
) -> Poll < AfterConnectingUnix < T > , Error > {
124
+ if let Some ( timeout) = & mut state. timeout {
125
+ match timeout. poll ( ) {
126
+ Ok ( Async :: Ready ( ( ) ) ) => return Err ( Error :: connect_timeout ( ) ) ,
127
+ Ok ( Async :: NotReady ) => { }
128
+ Err ( e) => return Err ( Error :: timer ( e) ) ,
129
+ }
130
+ }
131
+
99
132
let stream = try_ready ! ( state. future. poll( ) . map_err( Error :: connect) ) ;
100
133
let stream = Socket :: new_unix ( stream) ;
101
134
let state = state. take ( ) ;
@@ -108,6 +141,14 @@ where
108
141
fn poll_resolving_dns < ' a > (
109
142
state : & ' a mut RentToOwn < ' a , ResolvingDns < T > > ,
110
143
) -> Poll < AfterResolvingDns < T > , Error > {
144
+ if let Some ( timeout) = & mut state. timeout {
145
+ match timeout. poll ( ) {
146
+ Ok ( Async :: Ready ( ( ) ) ) => return Err ( Error :: connect_timeout ( ) ) ,
147
+ Ok ( Async :: NotReady ) => { }
148
+ Err ( e) => return Err ( Error :: timer ( e) ) ,
149
+ }
150
+ }
151
+
111
152
let mut addrs = try_ready ! ( state. future. poll( ) . map_err( Error :: connect) ) ;
112
153
let state = state. take ( ) ;
113
154
@@ -124,6 +165,8 @@ where
124
165
transition ! ( ConnectingTcp {
125
166
future: TcpStream :: connect( & addr) ,
126
167
addrs,
168
+ connect_timeout: state. connect_timeout,
169
+ timeout: state. timeout,
127
170
tls_mode: state. tls_mode,
128
171
params: state. params,
129
172
} )
@@ -132,6 +175,14 @@ where
132
175
fn poll_connecting_tcp < ' a > (
133
176
state : & ' a mut RentToOwn < ' a , ConnectingTcp < T > > ,
134
177
) -> Poll < AfterConnectingTcp < T > , Error > {
178
+ if let Some ( timeout) = & mut state. timeout {
179
+ match timeout. poll ( ) {
180
+ Ok ( Async :: Ready ( ( ) ) ) => return Err ( Error :: connect_timeout ( ) ) ,
181
+ Ok ( Async :: NotReady ) => { }
182
+ Err ( e) => return Err ( Error :: timer ( e) ) ,
183
+ }
184
+ }
185
+
135
186
let stream = loop {
136
187
match state. future . poll ( ) {
137
188
Ok ( Async :: Ready ( stream) ) => break stream,
0 commit comments