Skip to content

Commit d451465

Browse files
authored
Merge pull request sfackler#1052 from sfackler/load-balance
Implement load balancing
2 parents b575745 + 84aed63 commit d451465

File tree

6 files changed

+151
-69
lines changed

6 files changed

+151
-69
lines changed

tokio-postgres/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" }
5858
postgres-types = { version = "0.2.4", path = "../postgres-types" }
5959
tokio = { version = "1.27", features = ["io-util"] }
6060
tokio-util = { version = "0.7", features = ["codec"] }
61+
rand = "0.8.5"
6162

6263
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
6364
socket2 = { version = "0.5", features = ["all"] }

tokio-postgres/src/cancel_query.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ where
3030
let has_hostname = config.hostname.is_some();
3131

3232
let socket = connect_socket::connect_socket(
33-
&config.host,
33+
&config.addr,
3434
config.port,
3535
config.connect_timeout,
3636
config.tcp_user_timeout,

tokio-postgres/src/client.rs

+13-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
use crate::codec::{BackendMessages, FrontendMessage};
2-
#[cfg(feature = "runtime")]
3-
use crate::config::Host;
42
use crate::config::SslMode;
53
use crate::connection::{Request, RequestMessages};
64
use crate::copy_out::CopyOutStream;
@@ -27,6 +25,10 @@ use postgres_protocol::message::{backend::Message, frontend};
2725
use postgres_types::BorrowToSql;
2826
use std::collections::HashMap;
2927
use std::fmt;
28+
#[cfg(feature = "runtime")]
29+
use std::net::IpAddr;
30+
#[cfg(feature = "runtime")]
31+
use std::path::PathBuf;
3032
use std::sync::Arc;
3133
use std::task::{Context, Poll};
3234
#[cfg(feature = "runtime")]
@@ -153,14 +155,22 @@ impl InnerClient {
153155
#[cfg(feature = "runtime")]
154156
#[derive(Clone)]
155157
pub(crate) struct SocketConfig {
156-
pub host: Host,
158+
pub addr: Addr,
157159
pub hostname: Option<String>,
158160
pub port: u16,
159161
pub connect_timeout: Option<Duration>,
160162
pub tcp_user_timeout: Option<Duration>,
161163
pub keepalive: Option<KeepaliveConfig>,
162164
}
163165

166+
#[cfg(feature = "runtime")]
167+
#[derive(Clone)]
168+
pub(crate) enum Addr {
169+
Tcp(IpAddr),
170+
#[cfg(unix)]
171+
Unix(PathBuf),
172+
}
173+
164174
/// An asynchronous PostgreSQL client.
165175
///
166176
/// The client is one half of what is returned when a connection is established. Users interact with the database

tokio-postgres/src/config.rs

+43
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ pub enum ChannelBinding {
6060
Require,
6161
}
6262

63+
/// Load balancing configuration.
64+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
65+
#[non_exhaustive]
66+
pub enum LoadBalanceHosts {
67+
/// Make connection attempts to hosts in the order provided.
68+
Disable,
69+
/// Make connection attempts to hosts in a random order.
70+
Random,
71+
}
72+
6373
/// A host specification.
6474
#[derive(Debug, Clone, PartialEq, Eq)]
6575
pub enum Host {
@@ -129,6 +139,12 @@ pub enum Host {
129139
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
130140
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
131141
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
142+
/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and
143+
/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter
144+
/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to
145+
/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried
146+
/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults
147+
/// to `disable`.
132148
///
133149
/// ## Examples
134150
///
@@ -190,6 +206,7 @@ pub struct Config {
190206
pub(crate) keepalive_config: KeepaliveConfig,
191207
pub(crate) target_session_attrs: TargetSessionAttrs,
192208
pub(crate) channel_binding: ChannelBinding,
209+
pub(crate) load_balance_hosts: LoadBalanceHosts,
193210
}
194211

195212
impl Default for Config {
@@ -222,6 +239,7 @@ impl Config {
222239
},
223240
target_session_attrs: TargetSessionAttrs::Any,
224241
channel_binding: ChannelBinding::Prefer,
242+
load_balance_hosts: LoadBalanceHosts::Disable,
225243
}
226244
}
227245

@@ -489,6 +507,19 @@ impl Config {
489507
self.channel_binding
490508
}
491509

510+
/// Sets the host load balancing behavior.
511+
///
512+
/// Defaults to `disable`.
513+
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
514+
self.load_balance_hosts = load_balance_hosts;
515+
self
516+
}
517+
518+
/// Gets the host load balancing behavior.
519+
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
520+
self.load_balance_hosts
521+
}
522+
492523
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
493524
match key {
494525
"user" => {
@@ -612,6 +643,18 @@ impl Config {
612643
};
613644
self.channel_binding(channel_binding);
614645
}
646+
"load_balance_hosts" => {
647+
let load_balance_hosts = match value {
648+
"disable" => LoadBalanceHosts::Disable,
649+
"random" => LoadBalanceHosts::Random,
650+
_ => {
651+
return Err(Error::config_parse(Box::new(InvalidValue(
652+
"load_balance_hosts",
653+
))))
654+
}
655+
};
656+
self.load_balance_hosts(load_balance_hosts);
657+
}
615658
key => {
616659
return Err(Error::config_parse(Box::new(UnknownOption(
617660
key.to_string(),

tokio-postgres/src/connect.rs

+71-22
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
use crate::client::SocketConfig;
2-
use crate::config::{Host, TargetSessionAttrs};
1+
use crate::client::{Addr, SocketConfig};
2+
use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs};
33
use crate::connect_raw::connect_raw;
44
use crate::connect_socket::connect_socket;
5-
use crate::tls::{MakeTlsConnect, TlsConnect};
5+
use crate::tls::MakeTlsConnect;
66
use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
77
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
8+
use rand::seq::SliceRandom;
89
use std::task::Poll;
910
use std::{cmp, io};
11+
use tokio::net;
1012

1113
pub async fn connect<T>(
1214
mut tls: T,
@@ -40,8 +42,13 @@ where
4042
return Err(Error::config("invalid number of ports".into()));
4143
}
4244

45+
let mut indices = (0..num_hosts).collect::<Vec<_>>();
46+
if config.load_balance_hosts == LoadBalanceHosts::Random {
47+
indices.shuffle(&mut rand::thread_rng());
48+
}
49+
4350
let mut error = None;
44-
for i in 0..num_hosts {
51+
for i in indices {
4552
let host = config.host.get(i);
4653
let hostaddr = config.hostaddr.get(i);
4754
let port = config
@@ -59,25 +66,15 @@ where
5966
Some(Host::Unix(_)) => None,
6067
None => None,
6168
};
62-
let tls = tls
63-
.make_tls_connect(hostname.as_deref().unwrap_or(""))
64-
.map_err(|e| Error::tls(e.into()))?;
6569

6670
// Try to use the value of hostaddr to establish the TCP connection,
6771
// fallback to host if hostaddr is not present.
6872
let addr = match hostaddr {
6973
Some(ipaddr) => Host::Tcp(ipaddr.to_string()),
70-
None => {
71-
if let Some(host) = host {
72-
host.clone()
73-
} else {
74-
// This is unreachable.
75-
return Err(Error::config("both host and hostaddr are empty".into()));
76-
}
77-
}
74+
None => host.cloned().unwrap(),
7875
};
7976

80-
match connect_once(addr, hostname, port, tls, config).await {
77+
match connect_host(addr, hostname, port, &mut tls, config).await {
8178
Ok((client, connection)) => return Ok((client, connection)),
8279
Err(e) => error = Some(e),
8380
}
@@ -86,18 +83,66 @@ where
8683
Err(error.unwrap())
8784
}
8885

89-
async fn connect_once<T>(
86+
async fn connect_host<T>(
9087
host: Host,
9188
hostname: Option<String>,
9289
port: u16,
93-
tls: T,
90+
tls: &mut T,
91+
config: &Config,
92+
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
93+
where
94+
T: MakeTlsConnect<Socket>,
95+
{
96+
match host {
97+
Host::Tcp(host) => {
98+
let mut addrs = net::lookup_host((&*host, port))
99+
.await
100+
.map_err(Error::connect)?
101+
.collect::<Vec<_>>();
102+
103+
if config.load_balance_hosts == LoadBalanceHosts::Random {
104+
addrs.shuffle(&mut rand::thread_rng());
105+
}
106+
107+
let mut last_err = None;
108+
for addr in addrs {
109+
match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config)
110+
.await
111+
{
112+
Ok(stream) => return Ok(stream),
113+
Err(e) => {
114+
last_err = Some(e);
115+
continue;
116+
}
117+
};
118+
}
119+
120+
Err(last_err.unwrap_or_else(|| {
121+
Error::connect(io::Error::new(
122+
io::ErrorKind::InvalidInput,
123+
"could not resolve any addresses",
124+
))
125+
}))
126+
}
127+
#[cfg(unix)]
128+
Host::Unix(path) => {
129+
connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await
130+
}
131+
}
132+
}
133+
134+
async fn connect_once<T>(
135+
addr: Addr,
136+
hostname: Option<&str>,
137+
port: u16,
138+
tls: &mut T,
94139
config: &Config,
95140
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
96141
where
97-
T: TlsConnect<Socket>,
142+
T: MakeTlsConnect<Socket>,
98143
{
99144
let socket = connect_socket(
100-
&host,
145+
&addr,
101146
port,
102147
config.connect_timeout,
103148
config.tcp_user_timeout,
@@ -108,6 +153,10 @@ where
108153
},
109154
)
110155
.await?;
156+
157+
let tls = tls
158+
.make_tls_connect(hostname.unwrap_or(""))
159+
.map_err(|e| Error::tls(e.into()))?;
111160
let has_hostname = hostname.is_some();
112161
let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?;
113162

@@ -152,8 +201,8 @@ where
152201
}
153202

154203
client.set_socket_config(SocketConfig {
155-
host,
156-
hostname,
204+
addr,
205+
hostname: hostname.map(|s| s.to_string()),
157206
port,
158207
connect_timeout: config.connect_timeout,
159208
tcp_user_timeout: config.tcp_user_timeout,

tokio-postgres/src/connect_socket.rs

+22-43
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,50 @@
1-
use crate::config::Host;
1+
use crate::client::Addr;
22
use crate::keepalive::KeepaliveConfig;
33
use crate::{Error, Socket};
44
use socket2::{SockRef, TcpKeepalive};
55
use std::future::Future;
66
use std::io;
77
use std::time::Duration;
8+
use tokio::net::TcpStream;
89
#[cfg(unix)]
910
use tokio::net::UnixStream;
10-
use tokio::net::{self, TcpStream};
1111
use tokio::time;
1212

1313
pub(crate) async fn connect_socket(
14-
host: &Host,
14+
addr: &Addr,
1515
port: u16,
1616
connect_timeout: Option<Duration>,
1717
#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option<
1818
Duration,
1919
>,
2020
keepalive_config: Option<&KeepaliveConfig>,
2121
) -> Result<Socket, Error> {
22-
match host {
23-
Host::Tcp(host) => {
24-
let addrs = net::lookup_host((&**host, port))
25-
.await
26-
.map_err(Error::connect)?;
22+
match addr {
23+
Addr::Tcp(ip) => {
24+
let stream =
25+
connect_with_timeout(TcpStream::connect((*ip, port)), connect_timeout).await?;
2726

28-
let mut last_err = None;
27+
stream.set_nodelay(true).map_err(Error::connect)?;
2928

30-
for addr in addrs {
31-
let stream =
32-
match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await {
33-
Ok(stream) => stream,
34-
Err(e) => {
35-
last_err = Some(e);
36-
continue;
37-
}
38-
};
39-
40-
stream.set_nodelay(true).map_err(Error::connect)?;
41-
42-
let sock_ref = SockRef::from(&stream);
43-
#[cfg(target_os = "linux")]
44-
{
45-
sock_ref
46-
.set_tcp_user_timeout(tcp_user_timeout)
47-
.map_err(Error::connect)?;
48-
}
49-
50-
if let Some(keepalive_config) = keepalive_config {
51-
sock_ref
52-
.set_tcp_keepalive(&TcpKeepalive::from(keepalive_config))
53-
.map_err(Error::connect)?;
54-
}
29+
let sock_ref = SockRef::from(&stream);
30+
#[cfg(target_os = "linux")]
31+
{
32+
sock_ref
33+
.set_tcp_user_timeout(tcp_user_timeout)
34+
.map_err(Error::connect)?;
35+
}
5536

56-
return Ok(Socket::new_tcp(stream));
37+
if let Some(keepalive_config) = keepalive_config {
38+
sock_ref
39+
.set_tcp_keepalive(&TcpKeepalive::from(keepalive_config))
40+
.map_err(Error::connect)?;
5741
}
5842

59-
Err(last_err.unwrap_or_else(|| {
60-
Error::connect(io::Error::new(
61-
io::ErrorKind::InvalidInput,
62-
"could not resolve any addresses",
63-
))
64-
}))
43+
Ok(Socket::new_tcp(stream))
6544
}
6645
#[cfg(unix)]
67-
Host::Unix(path) => {
68-
let path = path.join(format!(".s.PGSQL.{}", port));
46+
Addr::Unix(dir) => {
47+
let path = dir.join(format!(".s.PGSQL.{}", port));
6948
let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?;
7049
Ok(Socket::new_unix(socket))
7150
}

0 commit comments

Comments
 (0)