Skip to content

Commit 3955d26

Browse files
committed
Don't hold strong references in statements
There's no need for the connection to stay open until statements drop - they'll be cleaned up anyway once the connection dies.
1 parent 1788a03 commit 3955d26

File tree

9 files changed

+159
-127
lines changed

9 files changed

+159
-127
lines changed

tokio-postgres/src/proto/client.rs

+85-35
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use postgres_protocol;
44
use postgres_protocol::message::backend::Message;
55
use postgres_protocol::message::frontend;
66
use std::collections::HashMap;
7-
use std::sync::Arc;
7+
use std::sync::{Arc, Weak};
88

99
use disconnected;
1010
use error::{self, Error};
@@ -15,67 +15,120 @@ use proto::query::QueryStream;
1515
use proto::statement::Statement;
1616
use types::{IsNull, Oid, ToSql, Type};
1717

18-
pub struct PendingRequest {
19-
sender: mpsc::UnboundedSender<Request>,
20-
messages: Result<Vec<u8>, Error>,
21-
}
18+
pub struct PendingRequest(Result<Vec<u8>, Error>);
2219

23-
impl PendingRequest {
24-
pub fn send(self) -> Result<mpsc::Receiver<Message>, Error> {
25-
let messages = self.messages?;
26-
let (sender, receiver) = mpsc::channel(0);
27-
self.sender
28-
.unbounded_send(Request { messages, sender })
29-
.map(|_| receiver)
30-
.map_err(|_| disconnected())
20+
pub struct WeakClient(Weak<Inner>);
21+
22+
impl WeakClient {
23+
pub fn upgrade(&self) -> Option<Client> {
24+
self.0.upgrade().map(Client)
3125
}
3226
}
3327

34-
pub struct State {
35-
pub types: HashMap<Oid, Type>,
36-
pub typeinfo_query: Option<Statement>,
37-
pub typeinfo_enum_query: Option<Statement>,
38-
pub typeinfo_composite_query: Option<Statement>,
28+
struct State {
29+
types: HashMap<Oid, Type>,
30+
typeinfo_query: Option<Statement>,
31+
typeinfo_enum_query: Option<Statement>,
32+
typeinfo_composite_query: Option<Statement>,
3933
}
4034

41-
#[derive(Clone)]
42-
pub struct Client {
43-
pub state: Arc<Mutex<State>>,
35+
struct Inner {
36+
state: Mutex<State>,
4437
sender: mpsc::UnboundedSender<Request>,
4538
}
4639

40+
#[derive(Clone)]
41+
pub struct Client(Arc<Inner>);
42+
4743
impl Client {
4844
pub fn new(sender: mpsc::UnboundedSender<Request>) -> Client {
49-
Client {
50-
state: Arc::new(Mutex::new(State {
45+
Client(Arc::new(Inner {
46+
state: Mutex::new(State {
5147
types: HashMap::new(),
5248
typeinfo_query: None,
5349
typeinfo_enum_query: None,
5450
typeinfo_composite_query: None,
55-
})),
51+
}),
5652
sender,
57-
}
53+
}))
54+
}
55+
56+
pub fn downgrade(&self) -> WeakClient {
57+
WeakClient(Arc::downgrade(&self.0))
58+
}
59+
60+
pub fn cached_type(&self, oid: Oid) -> Option<Type> {
61+
self.0.state.lock().types.get(&oid).cloned()
62+
}
63+
64+
pub fn cache_type(&self, ty: &Type) {
65+
self.0.state.lock().types.insert(ty.oid(), ty.clone());
66+
}
67+
68+
pub fn typeinfo_query(&self) -> Option<Statement> {
69+
self.0.state.lock().typeinfo_query.clone()
70+
}
71+
72+
pub fn set_typeinfo_query(&self, statement: &Statement) {
73+
self.0.state.lock().typeinfo_query = Some(statement.clone());
5874
}
5975

60-
pub fn prepare(&mut self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture {
76+
pub fn typeinfo_enum_query(&self) -> Option<Statement> {
77+
self.0.state.lock().typeinfo_enum_query.clone()
78+
}
79+
80+
pub fn set_typeinfo_enum_query(&self, statement: &Statement) {
81+
self.0.state.lock().typeinfo_enum_query = Some(statement.clone());
82+
}
83+
84+
pub fn typeinfo_composite_query(&self) -> Option<Statement> {
85+
self.0.state.lock().typeinfo_composite_query.clone()
86+
}
87+
88+
pub fn set_typeinfo_composite_query(&self, statement: &Statement) {
89+
self.0.state.lock().typeinfo_composite_query = Some(statement.clone());
90+
}
91+
92+
pub fn send(&self, request: PendingRequest) -> Result<mpsc::Receiver<Message>, Error> {
93+
let messages = request.0?;
94+
let (sender, receiver) = mpsc::channel(0);
95+
self.0
96+
.sender
97+
.unbounded_send(Request { messages, sender })
98+
.map(|_| receiver)
99+
.map_err(|_| disconnected())
100+
}
101+
102+
pub fn prepare(&self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture {
61103
let pending = self.pending(|buf| {
62104
frontend::parse(&name, query, param_types.iter().map(|t| t.oid()), buf)?;
63105
frontend::describe(b'S', &name, buf)?;
64106
frontend::sync(buf);
65107
Ok(())
66108
});
67109

68-
PrepareFuture::new(pending, self.sender.clone(), name, self.clone())
110+
PrepareFuture::new(self.clone(), pending, name)
69111
}
70112

71-
pub fn execute(&mut self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture {
113+
pub fn execute(&self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture {
72114
let pending = self.pending_execute(statement, params);
73-
ExecuteFuture::new(pending, statement.clone())
115+
ExecuteFuture::new(self.clone(), pending, statement.clone())
74116
}
75117

76-
pub fn query(&mut self, statement: &Statement, params: &[&ToSql]) -> QueryStream {
118+
pub fn query(&self, statement: &Statement, params: &[&ToSql]) -> QueryStream {
77119
let pending = self.pending_execute(statement, params);
78-
QueryStream::new(pending, statement.clone())
120+
QueryStream::new(self.clone(), pending, statement.clone())
121+
}
122+
123+
pub fn close_statement(&self, name: &str) {
124+
let mut buf = vec![];
125+
frontend::close(b'S', name, &mut buf).expect("statement name not valid");
126+
frontend::sync(&mut buf);
127+
let (sender, _) = mpsc::channel(0);
128+
let _ = self.0.sender.unbounded_send(Request {
129+
messages: buf,
130+
sender,
131+
});
79132
}
80133

81134
fn pending_execute(&self, statement: &Statement, params: &[&ToSql]) -> PendingRequest {
@@ -109,9 +162,6 @@ impl Client {
109162
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
110163
{
111164
let mut buf = vec![];
112-
PendingRequest {
113-
sender: self.sender.clone(),
114-
messages: messages(&mut buf).map(|()| buf),
115-
}
165+
PendingRequest(messages(&mut buf).map(|()| buf))
116166
}
117167
}

tokio-postgres/src/proto/execute.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@ use postgres_protocol::message::backend::Message;
44
use state_machine_future::RentToOwn;
55

66
use error::{self, Error};
7-
use proto::client::PendingRequest;
7+
use proto::client::{Client, PendingRequest};
88
use proto::statement::Statement;
99
use {bad_response, disconnected};
1010

1111
#[derive(StateMachineFuture)]
1212
pub enum Execute {
1313
#[state_machine_future(start, transitions(ReadResponse))]
1414
Start {
15+
client: Client,
1516
request: PendingRequest,
1617
statement: Statement,
1718
},
@@ -31,7 +32,7 @@ pub enum Execute {
3132
impl PollExecute for Execute {
3233
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
3334
let state = state.take();
34-
let receiver = state.request.send()?;
35+
let receiver = state.client.send(state.request)?;
3536

3637
// the statement can drop after this point, since its close will queue up after the execution
3738
transition!(ReadResponse { receiver })
@@ -82,7 +83,7 @@ impl PollExecute for Execute {
8283
}
8384

8485
impl ExecuteFuture {
85-
pub fn new(request: PendingRequest, statement: Statement) -> ExecuteFuture {
86-
Execute::start(request, statement)
86+
pub fn new(client: Client, request: PendingRequest, statement: Statement) -> ExecuteFuture {
87+
Execute::start(client, request, statement)
8788
}
8889
}

tokio-postgres/src/proto/prepare.rs

+14-34
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use std::vec;
88

99
use error::{self, Error};
1010
use proto::client::{Client, PendingRequest};
11-
use proto::connection::Request;
1211
use proto::statement::Statement;
1312
use proto::typeinfo::TypeinfoFuture;
1413
use types::{Oid, Type};
@@ -19,47 +18,41 @@ use {bad_response, disconnected};
1918
pub enum Prepare {
2019
#[state_machine_future(start, transitions(ReadParseComplete))]
2120
Start {
21+
client: Client,
2222
request: PendingRequest,
23-
sender: mpsc::UnboundedSender<Request>,
2423
name: String,
25-
client: Client,
2624
},
2725
#[state_machine_future(transitions(ReadParameterDescription))]
2826
ReadParseComplete {
29-
sender: mpsc::UnboundedSender<Request>,
27+
client: Client,
3028
receiver: mpsc::Receiver<Message>,
3129
name: String,
32-
client: Client,
3330
},
3431
#[state_machine_future(transitions(ReadRowDescription))]
3532
ReadParameterDescription {
36-
sender: mpsc::UnboundedSender<Request>,
33+
client: Client,
3734
receiver: mpsc::Receiver<Message>,
3835
name: String,
39-
client: Client,
4036
},
4137
#[state_machine_future(transitions(ReadReadyForQuery))]
4238
ReadRowDescription {
43-
sender: mpsc::UnboundedSender<Request>,
39+
client: Client,
4440
receiver: mpsc::Receiver<Message>,
4541
name: String,
4642
parameters: Vec<Oid>,
47-
client: Client,
4843
},
4944
#[state_machine_future(transitions(GetParameterTypes, GetColumnTypes, Finished))]
5045
ReadReadyForQuery {
51-
sender: mpsc::UnboundedSender<Request>,
46+
client: Client,
5247
receiver: mpsc::Receiver<Message>,
5348
name: String,
5449
parameters: Vec<Oid>,
5550
columns: Vec<(String, Oid)>,
56-
client: Client,
5751
},
5852
#[state_machine_future(transitions(GetColumnTypes, Finished))]
5953
GetParameterTypes {
6054
future: TypeinfoFuture,
6155
remaining_parameters: vec::IntoIter<Oid>,
62-
sender: mpsc::UnboundedSender<Request>,
6356
name: String,
6457
parameters: Vec<Type>,
6558
columns: Vec<(String, Oid)>,
@@ -69,7 +62,6 @@ pub enum Prepare {
6962
future: TypeinfoFuture,
7063
cur_column_name: String,
7164
remaining_columns: vec::IntoIter<(String, Oid)>,
72-
sender: mpsc::UnboundedSender<Request>,
7365
name: String,
7466
parameters: Vec<Type>,
7567
columns: Vec<Column>,
@@ -83,10 +75,9 @@ pub enum Prepare {
8375
impl PollPrepare for Prepare {
8476
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
8577
let state = state.take();
86-
let receiver = state.request.send()?;
78+
let receiver = state.client.send(state.request)?;
8779

8880
transition!(ReadParseComplete {
89-
sender: state.sender,
9081
receiver,
9182
name: state.name,
9283
client: state.client,
@@ -101,7 +92,6 @@ impl PollPrepare for Prepare {
10192

10293
match message {
10394
Some(Message::ParseComplete) => transition!(ReadParameterDescription {
104-
sender: state.sender,
10595
receiver: state.receiver,
10696
name: state.name,
10797
client: state.client,
@@ -120,7 +110,6 @@ impl PollPrepare for Prepare {
120110

121111
match message {
122112
Some(Message::ParameterDescription(body)) => transition!(ReadRowDescription {
123-
sender: state.sender,
124113
receiver: state.receiver,
125114
name: state.name,
126115
parameters: body.parameters().collect()?,
@@ -148,7 +137,6 @@ impl PollPrepare for Prepare {
148137
};
149138

150139
transition!(ReadReadyForQuery {
151-
sender: state.sender,
152140
receiver: state.receiver,
153141
name: state.name,
154142
parameters: state.parameters,
@@ -174,7 +162,6 @@ impl PollPrepare for Prepare {
174162
transition!(GetParameterTypes {
175163
future: TypeinfoFuture::new(oid, state.client),
176164
remaining_parameters: parameters,
177-
sender: state.sender,
178165
name: state.name,
179166
parameters: vec![],
180167
columns: state.columns,
@@ -187,15 +174,14 @@ impl PollPrepare for Prepare {
187174
future: TypeinfoFuture::new(oid, state.client),
188175
cur_column_name: name,
189176
remaining_columns: columns,
190-
sender: state.sender,
191177
name: state.name,
192178
parameters: vec![],
193179
columns: vec![],
194180
});
195181
}
196182

197183
transition!(Finished(Statement::new(
198-
state.sender,
184+
state.client.downgrade(),
199185
state.name,
200186
vec![],
201187
vec![]
@@ -222,15 +208,14 @@ impl PollPrepare for Prepare {
222208
future: TypeinfoFuture::new(oid, client),
223209
cur_column_name: name,
224210
remaining_columns: columns,
225-
sender: state.sender,
226211
name: state.name,
227212
parameters: state.parameters,
228213
columns: vec![],
229214
})
230215
}
231216

232217
transition!(Finished(Statement::new(
233-
state.sender,
218+
client.downgrade(),
234219
state.name,
235220
state.parameters,
236221
vec![],
@@ -240,7 +225,7 @@ impl PollPrepare for Prepare {
240225
fn poll_get_column_types<'a>(
241226
state: &'a mut RentToOwn<'a, GetColumnTypes>,
242227
) -> Poll<AfterGetColumnTypes, Error> {
243-
loop {
228+
let client = loop {
244229
let (ty, client) = try_ready!(state.future.poll());
245230
let name = mem::replace(&mut state.cur_column_name, String::new());
246231
state.columns.push(Column::new(name, ty));
@@ -250,13 +235,13 @@ impl PollPrepare for Prepare {
250235
state.cur_column_name = name;
251236
state.future = TypeinfoFuture::new(oid, client);
252237
}
253-
None => break,
238+
None => break client,
254239
}
255-
}
240+
};
256241
let state = state.take();
257242

258243
transition!(Finished(Statement::new(
259-
state.sender,
244+
client.downgrade(),
260245
state.name,
261246
state.parameters,
262247
state.columns,
@@ -265,12 +250,7 @@ impl PollPrepare for Prepare {
265250
}
266251

267252
impl PrepareFuture {
268-
pub fn new(
269-
request: PendingRequest,
270-
sender: mpsc::UnboundedSender<Request>,
271-
name: String,
272-
client: Client,
273-
) -> PrepareFuture {
274-
Prepare::start(request, sender, name, client)
253+
pub fn new(client: Client, request: PendingRequest, name: String) -> PrepareFuture {
254+
Prepare::start(client, request, name)
275255
}
276256
}

0 commit comments

Comments
 (0)