Skip to content

Commit 919012d

Browse files
committed
Finish convenience API
1 parent 7df7fc7 commit 919012d

File tree

8 files changed

+280
-2
lines changed

8 files changed

+280
-2
lines changed

tokio-postgres/Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ circle-ci = { repository = "sfackler/rust-postgres" }
2828

2929
[features]
3030
default = ["runtime"]
31-
runtime = ["tokio-tcp", "tokio-uds"]
31+
runtime = ["tokio-tcp", "tokio-uds", "futures-cpupool", "lazy_static"]
3232

3333
"with-bit-vec-0.5" = ["bit-vec-05"]
3434
"with-chrono-0.4" = ["chrono-04"]
@@ -42,7 +42,6 @@ antidote = "1.0"
4242
bytes = "0.4"
4343
fallible-iterator = "0.1.6"
4444
futures = "0.1.7"
45-
futures-cpupool = "0.1"
4645
log = "0.4"
4746
phf = "0.7.23"
4847
postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
@@ -52,6 +51,8 @@ tokio-io = "0.1"
5251
void = "1.0"
5352

5453
tokio-tcp = { version = "0.1", optional = true }
54+
futures-cpupool = { version = "0.1", optional = true }
55+
lazy_static = { version = "1.0", optional = true }
5556

5657
bit-vec-05 = { version = "0.5", package = "bit-vec", optional = true }
5758
chrono-04 = { version = "0.4", package = "chrono", optional = true }

tokio-postgres/src/builder.rs

+12
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ use std::iter;
33
use std::str::{self, FromStr};
44
use tokio_io::{AsyncRead, AsyncWrite};
55

6+
#[cfg(feature = "runtime")]
7+
use crate::proto::ConnectFuture;
68
use crate::proto::HandshakeFuture;
9+
#[cfg(feature = "runtime")]
10+
use crate::{Connect, Socket};
711
use crate::{Error, Handshake, TlsMode};
812

913
#[derive(Clone)]
@@ -55,6 +59,14 @@ impl Builder {
5559
{
5660
Handshake(HandshakeFuture::new(stream, tls_mode, self.params.clone()))
5761
}
62+
63+
#[cfg(feature = "runtime")]
64+
pub fn connect<T>(&self, tls_mode: T) -> Connect<T>
65+
where
66+
T: TlsMode<Socket>,
67+
{
68+
Connect(ConnectFuture::new(tls_mode, self.params.clone()))
69+
}
5870
}
5971

6072
impl FromStr for Builder {

tokio-postgres/src/error/mod.rs

+27
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
55
use std::error::{self, Error as _Error};
66
use std::fmt;
77
use std::io;
8+
#[cfg(feature = "runtime")]
9+
use std::num::ParseIntError;
810

911
pub use self::sqlstate::*;
1012

@@ -346,6 +348,11 @@ enum Kind {
346348
UnsupportedAuthentication,
347349
Authentication,
348350
ConnectionSyntax,
351+
Connect,
352+
#[cfg(feature = "runtime")]
353+
MissingHost,
354+
#[cfg(feature = "runtime")]
355+
InvalidPort,
349356
}
350357

351358
struct ErrorInner {
@@ -383,6 +390,11 @@ impl fmt::Display for Error {
383390
Kind::UnsupportedAuthentication => "unsupported authentication method requested",
384391
Kind::Authentication => "authentication error",
385392
Kind::ConnectionSyntax => "invalid connection string",
393+
Kind::Connect => "error connecting to server",
394+
#[cfg(feature = "runtime")]
395+
Kind::MissingHost => "host not provided",
396+
#[cfg(feature = "runtime")]
397+
Kind::InvalidPort => "invalid port",
386398
};
387399
fmt.write_str(s)?;
388400
if let Some(ref cause) = self.0.cause {
@@ -485,4 +497,19 @@ impl Error {
485497
pub(crate) fn connection_syntax(e: Box<dyn error::Error + Sync + Send>) -> Error {
486498
Error::new(Kind::ConnectionSyntax, Some(e))
487499
}
500+
501+
#[cfg(feature = "runtime")]
502+
pub(crate) fn connect(e: io::Error) -> Error {
503+
Error::new(Kind::Connect, Some(Box::new(e)))
504+
}
505+
506+
#[cfg(feature = "runtime")]
507+
pub(crate) fn missing_host() -> Error {
508+
Error::new(Kind::MissingHost, None)
509+
}
510+
511+
#[cfg(feature = "runtime")]
512+
pub(crate) fn invalid_port(e: ParseIntError) -> Error {
513+
Error::new(Kind::InvalidPort, Some(Box::new(e)))
514+
}
488515
}

tokio-postgres/src/lib.rs

+21
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,27 @@ where
180180
}
181181
}
182182

183+
#[cfg(feature = "runtime")]
184+
#[must_use = "futures do nothing unless polled"]
185+
pub struct Connect<T>(proto::ConnectFuture<T>)
186+
where
187+
T: TlsMode<Socket>;
188+
189+
#[cfg(feature = "runtime")]
190+
impl<T> Future for Connect<T>
191+
where
192+
T: TlsMode<Socket>,
193+
{
194+
type Item = (Client, Connection<T::Stream>);
195+
type Error = Error;
196+
197+
fn poll(&mut self) -> Poll<(Client, Connection<T::Stream>), Error> {
198+
let (client, connection) = try_ready!(self.0.poll());
199+
200+
Ok(Async::Ready((Client(client), Connection(connection))))
201+
}
202+
}
203+
183204
#[must_use = "futures do nothing unless polled"]
184205
pub struct Prepare(proto::PrepareFuture);
185206

tokio-postgres/src/proto/connect.rs

+177
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
use futures::{try_ready, Async, Future, Poll};
2+
use futures_cpupool::{CpuFuture, CpuPool};
3+
use lazy_static::lazy_static;
4+
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
5+
use std::collections::HashMap;
6+
use std::io;
7+
use std::net::{SocketAddr, ToSocketAddrs};
8+
#[cfg(unix)]
9+
use std::path::Path;
10+
use std::vec;
11+
use tokio_tcp::TcpStream;
12+
#[cfg(unix)]
13+
use tokio_uds::UnixStream;
14+
15+
use crate::proto::{Client, Connection, HandshakeFuture};
16+
use crate::{Error, Socket, TlsMode};
17+
18+
lazy_static! {
19+
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
20+
.name_prefix("postgres-dns-")
21+
.pool_size(2)
22+
.create();
23+
}
24+
25+
#[derive(StateMachineFuture)]
26+
pub enum Connect<T>
27+
where
28+
T: TlsMode<Socket>,
29+
{
30+
#[state_machine_future(start)]
31+
#[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))]
32+
#[cfg_attr(not(unix), state_machine_future(transitions(ConnectingTcp)))]
33+
Start {
34+
tls_mode: T,
35+
params: HashMap<String, String>,
36+
},
37+
#[cfg(unix)]
38+
#[state_machine_future(transitions(Handshaking))]
39+
ConnectingUnix {
40+
future: tokio_uds::ConnectFuture,
41+
tls_mode: T,
42+
params: HashMap<String, String>,
43+
},
44+
#[state_machine_future(transitions(ConnectingTcp))]
45+
ResolvingDns {
46+
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
47+
tls_mode: T,
48+
params: HashMap<String, String>,
49+
},
50+
#[state_machine_future(transitions(Handshaking))]
51+
ConnectingTcp {
52+
future: tokio_tcp::ConnectFuture,
53+
addrs: vec::IntoIter<SocketAddr>,
54+
tls_mode: T,
55+
params: HashMap<String, String>,
56+
},
57+
#[state_machine_future(transitions(Finished))]
58+
Handshaking { future: HandshakeFuture<Socket, T> },
59+
#[state_machine_future(ready)]
60+
Finished((Client, Connection<T::Stream>)),
61+
#[state_machine_future(error)]
62+
Failed(Error),
63+
}
64+
65+
impl<T> PollConnect<T> for Connect<T>
66+
where
67+
T: TlsMode<Socket>,
68+
{
69+
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
70+
let mut state = state.take();
71+
72+
let host = match state.params.remove("host") {
73+
Some(host) => host,
74+
None => return Err(Error::missing_host()),
75+
};
76+
77+
let port = match state.params.remove("port") {
78+
Some(port) => port.parse::<u16>().map_err(Error::invalid_port)?,
79+
None => 5432,
80+
};
81+
82+
#[cfg(unix)]
83+
{
84+
if host.starts_with('/') {
85+
let path = Path::new(&host).join(format!(".s.PGSQL.{}", port));
86+
transition!(ConnectingUnix {
87+
future: UnixStream::connect(path),
88+
tls_mode: state.tls_mode,
89+
params: state.params,
90+
})
91+
}
92+
}
93+
94+
transition!(ResolvingDns {
95+
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
96+
tls_mode: state.tls_mode,
97+
params: state.params,
98+
})
99+
}
100+
101+
#[cfg(unix)]
102+
fn poll_connecting_unix<'a>(
103+
state: &'a mut RentToOwn<'a, ConnectingUnix<T>>,
104+
) -> Poll<AfterConnectingUnix<T>, Error> {
105+
let stream = try_ready!(state.future.poll().map_err(Error::connect));
106+
let stream = Socket::new_unix(stream);
107+
let state = state.take();
108+
109+
transition!(Handshaking {
110+
future: HandshakeFuture::new(stream, state.tls_mode, state.params)
111+
})
112+
}
113+
114+
fn poll_resolving_dns<'a>(
115+
state: &'a mut RentToOwn<'a, ResolvingDns<T>>,
116+
) -> Poll<AfterResolvingDns<T>, Error> {
117+
let mut addrs = try_ready!(state.future.poll().map_err(Error::connect));
118+
let state = state.take();
119+
120+
let addr = match addrs.next() {
121+
Some(addr) => addr,
122+
None => {
123+
return Err(Error::connect(io::Error::new(
124+
io::ErrorKind::InvalidData,
125+
"resolved 0 addresses",
126+
)))
127+
}
128+
};
129+
130+
transition!(ConnectingTcp {
131+
future: TcpStream::connect(&addr),
132+
addrs,
133+
tls_mode: state.tls_mode,
134+
params: state.params,
135+
})
136+
}
137+
138+
fn poll_connecting_tcp<'a>(
139+
state: &'a mut RentToOwn<'a, ConnectingTcp<T>>,
140+
) -> Poll<AfterConnectingTcp<T>, Error> {
141+
let stream = loop {
142+
match state.future.poll() {
143+
Ok(Async::Ready(stream)) => break Socket::new_tcp(stream),
144+
Ok(Async::NotReady) => return Ok(Async::NotReady),
145+
Err(e) => {
146+
let addr = match state.addrs.next() {
147+
Some(addr) => addr,
148+
None => return Err(Error::connect(e)),
149+
};
150+
state.future = TcpStream::connect(&addr);
151+
}
152+
}
153+
};
154+
let state = state.take();
155+
156+
transition!(Handshaking {
157+
future: HandshakeFuture::new(stream, state.tls_mode, state.params),
158+
})
159+
}
160+
161+
fn poll_handshaking<'a>(
162+
state: &'a mut RentToOwn<'a, Handshaking<T>>,
163+
) -> Poll<AfterHandshaking<T>, Error> {
164+
let r = try_ready!(state.future.poll());
165+
166+
transition!(Finished(r))
167+
}
168+
}
169+
170+
impl<T> ConnectFuture<T>
171+
where
172+
T: TlsMode<Socket>,
173+
{
174+
pub fn new(tls_mode: T, params: HashMap<String, String>) -> ConnectFuture<T> {
175+
Connect::start(tls_mode, params)
176+
}
177+
}

tokio-postgres/src/proto/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ mod bind;
2222
mod cancel;
2323
mod client;
2424
mod codec;
25+
#[cfg(feature = "runtime")]
26+
mod connect;
2527
mod connection;
2628
mod copy_in;
2729
mod copy_out;
@@ -42,6 +44,8 @@ pub use crate::proto::bind::BindFuture;
4244
pub use crate::proto::cancel::CancelFuture;
4345
pub use crate::proto::client::Client;
4446
pub use crate::proto::codec::PostgresCodec;
47+
#[cfg(feature = "runtime")]
48+
pub use crate::proto::connect::ConnectFuture;
4549
pub use crate::proto::connection::Connection;
4650
pub use crate::proto::copy_in::CopyInFuture;
4751
pub use crate::proto::copy_out::CopyOutStream;

tokio-postgres/tests/test/main.rs

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ use tokio_postgres::types::{Kind, Type};
1414
use tokio_postgres::{AsyncMessage, Client, Connection, NoTls};
1515

1616
mod parse;
17+
#[cfg(feature = "runtime")]
18+
mod runtime;
1719
mod types;
1820

1921
fn connect(

tokio-postgres/tests/test/runtime.rs

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
use futures::Future;
2+
use tokio::runtime::current_thread::Runtime;
3+
use tokio_postgres::{Client, Connection, Error, NoTls, Socket};
4+
5+
fn connect(s: &str) -> impl Future<Item = (Client, Connection<Socket>), Error = Error> {
6+
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
7+
}
8+
9+
#[test]
10+
#[ignore] // FIXME doesn't work with our docker-based tests :(
11+
fn unix_socket() {
12+
let mut runtime = Runtime::new().unwrap();
13+
14+
let connect = connect("host=/var/run/postgresql port=5433 user=postgres");
15+
let (mut client, connection) = runtime.block_on(connect).unwrap();
16+
let connection = connection.map_err(|e| panic!("{}", e));
17+
runtime.spawn(connection);
18+
19+
let execute = client.batch_execute("SELECT 1");
20+
runtime.block_on(execute).unwrap();
21+
}
22+
23+
#[test]
24+
fn tcp() {
25+
let mut runtime = Runtime::new().unwrap();
26+
27+
let connect = connect("host=localhost port=5433 user=postgres");
28+
let (mut client, connection) = runtime.block_on(connect).unwrap();
29+
let connection = connection.map_err(|e| panic!("{}", e));
30+
runtime.spawn(connection);
31+
32+
let execute = client.batch_execute("SELECT 1");
33+
runtime.block_on(execute).unwrap();
34+
}

0 commit comments

Comments
 (0)