forked from sfackler/rust-postgres
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconnect_once.rs
121 lines (109 loc) · 3.9 KB
/
connect_once.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#![allow(clippy::large_enum_variant)]
use futures::{try_ready, Async, Future, Poll, Stream};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::io;
use crate::config::TargetSessionAttrs;
use crate::proto::{
Client, ConnectRawFuture, ConnectSocketFuture, Connection, MaybeTlsStream, SimpleQueryStream,
};
use crate::{Config, Error, SimpleQueryMessage, Socket, TlsConnect};
/// State machine for a single connection attempt. It opens a socket, runs the
/// raw connection handshake, and, when `target_session_attrs` is `ReadWrite`,
/// rejects a server whose `SHOW transaction_read_only` reports `on`.
///
/// `#[derive(StateMachineFuture)]` generates the items used below but not defined
/// in this file: the `ConnectOnceFuture` future, the `PollConnectOnce` trait, the
/// per-state structs and `After*` transition types, and `ConnectOnce::start`.
/// `start`'s arguments (`idx, tls, config`) follow the field order of `Start`, so
/// reordering those fields changes that constructor's signature.
#[derive(StateMachineFuture)]
pub enum ConnectOnce<T>
where
    T: TlsConnect<Socket>,
{
    // Initial state: the index forwarded to the socket/raw futures, the TLS
    // connector, and the connection configuration.
    #[state_machine_future(start, transitions(ConnectingSocket))]
    Start { idx: usize, tls: T, config: Config },
    // Waiting for the socket to open; `idx`, `tls` and `config` are carried along
    // for the raw handshake that follows.
    #[state_machine_future(transitions(ConnectingRaw))]
    ConnectingSocket {
        future: ConnectSocketFuture,
        idx: usize,
        tls: T,
        config: Config,
    },
    // Running the raw handshake over the opened socket. `target_session_attrs` is
    // copied out of the config so the next step can decide whether a read-only
    // check is needed once the handshake completes.
    #[state_machine_future(transitions(CheckingSessionAttrs, Finished))]
    ConnectingRaw {
        future: ConnectRawFuture<Socket, T>,
        target_session_attrs: TargetSessionAttrs,
    },
    // Waiting for the result of `SHOW transaction_read_only`. The `connection`
    // future is held here and polled alongside `stream` while the query runs.
    #[state_machine_future(transitions(Finished))]
    CheckingSessionAttrs {
        stream: SimpleQueryStream,
        client: Client,
        connection: Connection<MaybeTlsStream<Socket, T::Stream>>,
    },
    // Success: the client handle paired with its connection future.
    #[state_machine_future(ready)]
    Finished((Client, Connection<MaybeTlsStream<Socket, T::Stream>>)),
    // Terminal error state.
    #[state_machine_future(error)]
    Failed(Error),
}
impl<T> PollConnectOnce<T> for ConnectOnce<T>
where
    T: TlsConnect<Socket>,
{
    /// Starts opening the socket selected by `idx`.
    ///
    /// This always transitions immediately to `ConnectingSocket`.
    fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
        let start = state.take();
        // The socket future takes its own copy of the configuration; the original
        // travels on to the raw handshake.
        let future = ConnectSocketFuture::new(start.config.clone(), start.idx);
        transition!(ConnectingSocket {
            future,
            idx: start.idx,
            tls: start.tls,
            config: start.config,
        })
    }

    /// Waits for the socket, then begins the raw connection handshake on it.
    ///
    /// Socket errors are propagated through `try_ready!`.
    fn poll_connecting_socket<'a>(
        state: &'a mut RentToOwn<'a, ConnectingSocket<T>>,
    ) -> Poll<AfterConnectingSocket<T>, Error> {
        let socket = try_ready!(state.future.poll());
        let pending = state.take();
        // Copied out before `config` is moved into the handshake future.
        let target_session_attrs = pending.config.0.target_session_attrs;
        let future = ConnectRawFuture::new(socket, pending.tls, pending.config, Some(pending.idx));
        transition!(ConnectingRaw {
            future,
            target_session_attrs,
        })
    }

    /// Waits for the handshake to complete.
    ///
    /// For `ReadWrite` targets it issues `SHOW transaction_read_only` to verify
    /// that the server accepts writes; any other target finishes immediately.
    fn poll_connecting_raw<'a>(
        state: &'a mut RentToOwn<'a, ConnectingRaw<T>>,
    ) -> Poll<AfterConnectingRaw<T>, Error> {
        let (client, connection) = try_ready!(state.future.poll());
        match state.target_session_attrs {
            TargetSessionAttrs::ReadWrite => {
                let stream = client.simple_query("SHOW transaction_read_only");
                transition!(CheckingSessionAttrs {
                    stream,
                    client,
                    connection,
                })
            }
            _ => transition!(Finished((client, connection))),
        }
    }

    /// Drives the read-only check to completion.
    ///
    /// # Errors
    ///
    /// * `Error::closed()` if the connection future resolves, or the result stream
    ///   ends, before a row arrives.
    /// * A `PermissionDenied` connect error if the first row's first column is
    ///   `"on"`.
    /// * Errors from polling the connection or stream, or from reading the column.
    fn poll_checking_session_attrs<'a>(
        state: &'a mut RentToOwn<'a, CheckingSessionAttrs<T>>,
    ) -> Poll<AfterCheckingSessionAttrs<T>, Error> {
        loop {
            // The connection future is polled before the query stream on every
            // iteration; if it has resolved, the attempt fails as closed.
            if let Async::Ready(()) = state.connection.poll()? {
                return Err(Error::closed());
            }

            // Skip non-row messages; only the first data row is inspected.
            let row = match try_ready!(state.stream.poll()) {
                Some(SimpleQueryMessage::Row(row)) => row,
                Some(_) => continue,
                None => return Err(Error::closed()),
            };

            if row.try_get(0)? == Some("on") {
                return Err(Error::connect(io::Error::new(
                    io::ErrorKind::PermissionDenied,
                    "database does not allow writes",
                )));
            }

            let checked = state.take();
            transition!(Finished((checked.client, checked.connection)))
        }
    }
}
impl<T> ConnectOnceFuture<T>
where
    T: TlsConnect<Socket>,
{
    /// Creates a future that performs one connection attempt.
    ///
    /// `idx` is forwarded to `ConnectSocketFuture::new` and `ConnectRawFuture::new`;
    /// presumably it selects which configured host to dial (verify against
    /// `Config`). `tls` is the connector handed to the raw handshake.
    pub fn new(idx: usize, tls: T, config: Config) -> ConnectOnceFuture<T> {
        ConnectOnce::<T>::start(idx, tls, config)
    }
}