Skip to content

Commit a237a47

Browse files
committed
Support custom types
1 parent be2ca03 commit a237a47

File tree

9 files changed

+986
-43
lines changed

9 files changed

+986
-43
lines changed

tokio-postgres/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ circle-ci = { repository = "sfackler/rust-postgres" }
3232
"with-uuid-0.6" = ["postgres-shared/with-uuid-0.6"]
3333

3434
[dependencies]
35+
antidote = "1.0"
3536
bytes = "0.4"
3637
fallible-iterator = "0.1.3"
3738
futures = "0.1.7"

tokio-postgres/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
extern crate antidote;
12
extern crate bytes;
23
extern crate fallible_iterator;
34
extern crate futures_cpupool;

tokio-postgres/src/proto/client.rs

+23-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
use antidote::Mutex;
12
use futures::sync::mpsc;
23
use postgres_protocol;
34
use postgres_protocol::message::backend::Message;
45
use postgres_protocol::message::frontend;
6+
use std::collections::HashMap;
7+
use std::sync::Arc;
58

69
use disconnected;
710
use error::{self, Error};
@@ -10,7 +13,7 @@ use proto::execute::ExecuteFuture;
1013
use proto::prepare::PrepareFuture;
1114
use proto::query::QueryStream;
1215
use proto::statement::Statement;
13-
use types::{IsNull, ToSql, Type};
16+
use types::{IsNull, Oid, ToSql, Type};
1417

1518
pub struct PendingRequest {
1619
sender: mpsc::UnboundedSender<Request>,
@@ -28,13 +31,30 @@ impl PendingRequest {
2831
}
2932
}
3033

34+
pub struct State {
35+
pub types: HashMap<Oid, Type>,
36+
pub typeinfo_query: Option<Statement>,
37+
pub typeinfo_enum_query: Option<Statement>,
38+
pub typeinfo_composite_query: Option<Statement>,
39+
}
40+
41+
#[derive(Clone)]
3142
pub struct Client {
43+
pub state: Arc<Mutex<State>>,
3244
sender: mpsc::UnboundedSender<Request>,
3345
}
3446

3547
impl Client {
3648
pub fn new(sender: mpsc::UnboundedSender<Request>) -> Client {
37-
Client { sender }
49+
Client {
50+
state: Arc::new(Mutex::new(State {
51+
types: HashMap::new(),
52+
typeinfo_query: None,
53+
typeinfo_enum_query: None,
54+
typeinfo_composite_query: None,
55+
})),
56+
sender,
57+
}
3858
}
3959

4060
pub fn prepare(&mut self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture {
@@ -45,7 +65,7 @@ impl Client {
4565
Ok(())
4666
});
4767

48-
PrepareFuture::new(pending, self.sender.clone(), name)
68+
PrepareFuture::new(pending, self.sender.clone(), name, self.clone())
4969
}
5070

5171
pub fn execute(&mut self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture {

tokio-postgres/src/proto/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ mod query;
2020
mod row;
2121
mod socket;
2222
mod statement;
23+
mod typeinfo;
24+
mod typeinfo_composite;
25+
mod typeinfo_enum;
2326

2427
pub use proto::cancel::CancelFuture;
2528
pub use proto::client::Client;

tokio-postgres/src/proto/prepare.rs

+144-39
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
use fallible_iterator::FallibleIterator;
22
use futures::sync::mpsc;
3-
use futures::{Poll, Stream};
4-
use postgres_protocol::message::backend::{Message, ParameterDescriptionBody, RowDescriptionBody};
3+
use futures::{Future, Poll, Stream};
4+
use postgres_protocol::message::backend::Message;
55
use state_machine_future::RentToOwn;
6+
use std::mem;
7+
use std::vec;
68

79
use error::{self, Error};
8-
use proto::client::PendingRequest;
10+
use proto::client::{Client, PendingRequest};
911
use proto::connection::Request;
1012
use proto::statement::Statement;
11-
use types::Type;
13+
use proto::typeinfo::TypeinfoFuture;
14+
use types::{Oid, Type};
1215
use Column;
1316
use {bad_response, disconnected};
1417

@@ -19,33 +22,57 @@ pub enum Prepare {
1922
request: PendingRequest,
2023
sender: mpsc::UnboundedSender<Request>,
2124
name: String,
25+
client: Client,
2226
},
2327
#[state_machine_future(transitions(ReadParameterDescription))]
2428
ReadParseComplete {
2529
sender: mpsc::UnboundedSender<Request>,
2630
receiver: mpsc::Receiver<Message>,
2731
name: String,
32+
client: Client,
2833
},
2934
#[state_machine_future(transitions(ReadRowDescription))]
3035
ReadParameterDescription {
3136
sender: mpsc::UnboundedSender<Request>,
3237
receiver: mpsc::Receiver<Message>,
3338
name: String,
39+
client: Client,
3440
},
3541
#[state_machine_future(transitions(ReadReadyForQuery))]
3642
ReadRowDescription {
3743
sender: mpsc::UnboundedSender<Request>,
3844
receiver: mpsc::Receiver<Message>,
3945
name: String,
40-
parameters: ParameterDescriptionBody,
46+
parameters: Vec<Oid>,
47+
client: Client,
4148
},
42-
#[state_machine_future(transitions(Finished))]
49+
#[state_machine_future(transitions(GetParameterTypes, GetColumnTypes, Finished))]
4350
ReadReadyForQuery {
4451
sender: mpsc::UnboundedSender<Request>,
4552
receiver: mpsc::Receiver<Message>,
4653
name: String,
47-
parameters: ParameterDescriptionBody,
48-
columns: Option<RowDescriptionBody>,
54+
parameters: Vec<Oid>,
55+
columns: Vec<(String, Oid)>,
56+
client: Client,
57+
},
58+
#[state_machine_future(transitions(GetColumnTypes, Finished))]
59+
GetParameterTypes {
60+
future: TypeinfoFuture,
61+
remaining_parameters: vec::IntoIter<Oid>,
62+
sender: mpsc::UnboundedSender<Request>,
63+
name: String,
64+
parameters: Vec<Type>,
65+
columns: Vec<(String, Oid)>,
66+
},
67+
#[state_machine_future(transitions(Finished))]
68+
GetColumnTypes {
69+
future: TypeinfoFuture,
70+
cur_column_name: String,
71+
remaining_columns: vec::IntoIter<(String, Oid)>,
72+
sender: mpsc::UnboundedSender<Request>,
73+
name: String,
74+
parameters: Vec<Type>,
75+
columns: Vec<Column>,
4976
},
5077
#[state_machine_future(ready)]
5178
Finished(Statement),
@@ -62,6 +89,7 @@ impl PollPrepare for Prepare {
6289
sender: state.sender,
6390
receiver,
6491
name: state.name,
92+
client: state.client,
6593
})
6694
}
6795

@@ -76,6 +104,7 @@ impl PollPrepare for Prepare {
76104
sender: state.sender,
77105
receiver: state.receiver,
78106
name: state.name,
107+
client: state.client,
79108
}),
80109
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
81110
Some(_) => Err(bad_response()),
@@ -94,7 +123,8 @@ impl PollPrepare for Prepare {
94123
sender: state.sender,
95124
receiver: state.receiver,
96125
name: state.name,
97-
parameters: body,
126+
parameters: body.parameters().collect()?,
127+
client: state.client,
98128
}),
99129
Some(_) => Err(bad_response()),
100130
None => Err(disconnected()),
@@ -107,9 +137,12 @@ impl PollPrepare for Prepare {
107137
let message = try_receive!(state.receiver.poll());
108138
let state = state.take();
109139

110-
let body = match message {
111-
Some(Message::RowDescription(body)) => Some(body),
112-
Some(Message::NoData) => None,
140+
let columns = match message {
141+
Some(Message::RowDescription(body)) => body
142+
.fields()
143+
.map(|f| (f.name().to_string(), f.type_oid()))
144+
.collect()?,
145+
Some(Message::NoData) => vec![],
113146
Some(_) => return Err(bad_response()),
114147
None => return Err(disconnected()),
115148
};
@@ -119,7 +152,8 @@ impl PollPrepare for Prepare {
119152
receiver: state.receiver,
120153
name: state.name,
121154
parameters: state.parameters,
122-
columns: body,
155+
columns,
156+
client: state.client,
123157
})
124158
}
125159

@@ -130,33 +164,103 @@ impl PollPrepare for Prepare {
130164
let state = state.take();
131165

132166
match message {
133-
Some(Message::ReadyForQuery(_)) => {
134-
// FIXME handle custom types
135-
let parameters = state
136-
.parameters
137-
.parameters()
138-
.map(|oid| Type::from_oid(oid).unwrap())
139-
.collect()?;
140-
let columns = match state.columns {
141-
Some(body) => body
142-
.fields()
143-
.map(|f| {
144-
Column::new(f.name().to_string(), Type::from_oid(f.type_oid()).unwrap())
145-
})
146-
.collect()?,
147-
None => vec![],
148-
};
149-
150-
transition!(Finished(Statement::new(
151-
state.sender,
152-
state.name,
153-
parameters,
154-
columns
155-
)))
167+
Some(Message::ReadyForQuery(_)) => {}
168+
Some(_) => return Err(bad_response()),
169+
None => return Err(disconnected()),
170+
}
171+
172+
let mut parameters = state.parameters.into_iter();
173+
if let Some(oid) = parameters.next() {
174+
transition!(GetParameterTypes {
175+
future: TypeinfoFuture::new(oid, state.client),
176+
remaining_parameters: parameters,
177+
sender: state.sender,
178+
name: state.name,
179+
parameters: vec![],
180+
columns: state.columns,
181+
});
182+
}
183+
184+
let mut columns = state.columns.into_iter();
185+
if let Some((name, oid)) = columns.next() {
186+
transition!(GetColumnTypes {
187+
future: TypeinfoFuture::new(oid, state.client),
188+
cur_column_name: name,
189+
remaining_columns: columns,
190+
sender: state.sender,
191+
name: state.name,
192+
parameters: vec![],
193+
columns: vec![],
194+
});
195+
}
196+
197+
transition!(Finished(Statement::new(
198+
state.sender,
199+
state.name,
200+
vec![],
201+
vec![]
202+
)))
203+
}
204+
205+
fn poll_get_parameter_types<'a>(
206+
state: &'a mut RentToOwn<'a, GetParameterTypes>,
207+
) -> Poll<AfterGetParameterTypes, Error> {
208+
let client = loop {
209+
let (ty, client) = try_ready!(state.future.poll());
210+
state.parameters.push(ty);
211+
212+
match state.remaining_parameters.next() {
213+
Some(oid) => state.future = TypeinfoFuture::new(oid, client),
214+
None => break client,
156215
}
157-
Some(_) => Err(bad_response()),
158-
None => Err(disconnected()),
216+
};
217+
let state = state.take();
218+
219+
let mut columns = state.columns.into_iter();
220+
if let Some((name, oid)) = columns.next() {
221+
transition!(GetColumnTypes {
222+
future: TypeinfoFuture::new(oid, client),
223+
cur_column_name: name,
224+
remaining_columns: columns,
225+
sender: state.sender,
226+
name: state.name,
227+
parameters: state.parameters,
228+
columns: vec![],
229+
})
159230
}
231+
232+
transition!(Finished(Statement::new(
233+
state.sender,
234+
state.name,
235+
state.parameters,
236+
vec![],
237+
)))
238+
}
239+
240+
fn poll_get_column_types<'a>(
241+
state: &'a mut RentToOwn<'a, GetColumnTypes>,
242+
) -> Poll<AfterGetColumnTypes, Error> {
243+
loop {
244+
let (ty, client) = try_ready!(state.future.poll());
245+
let name = mem::replace(&mut state.cur_column_name, String::new());
246+
state.columns.push(Column::new(name, ty));
247+
248+
match state.remaining_columns.next() {
249+
Some((name, oid)) => {
250+
state.cur_column_name = name;
251+
state.future = TypeinfoFuture::new(oid, client);
252+
}
253+
None => break,
254+
}
255+
}
256+
let state = state.take();
257+
258+
transition!(Finished(Statement::new(
259+
state.sender,
260+
state.name,
261+
state.parameters,
262+
state.columns,
263+
)))
160264
}
161265
}
162266

@@ -165,7 +269,8 @@ impl PrepareFuture {
165269
request: PendingRequest,
166270
sender: mpsc::UnboundedSender<Request>,
167271
name: String,
272+
client: Client,
168273
) -> PrepareFuture {
169-
Prepare::start(request, sender, name)
274+
Prepare::start(request, sender, name, client)
170275
}
171276
}

0 commit comments

Comments
 (0)