Skip to content

Commit 956ba12

Browse files
committed
Conversions from INET to IpAddr
We ignore the netmask when deserializing and use /32 or /128 when serializing. Closes sfackler#430
1 parent fd3f3fe commit 956ba12

File tree

3 files changed

+146
-0
lines changed

3 files changed

+146
-0
lines changed

postgres-protocol/src/types/mod.rs

+88
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
33
use fallible_iterator::FallibleIterator;
44
use std::boxed::Box as StdBox;
55
use std::error::Error;
6+
use std::io::Read;
7+
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
68
use std::str;
79

810
use crate::{write_nullable, FromUsize, IsNull, Oid};
@@ -16,6 +18,9 @@ const RANGE_UPPER_INCLUSIVE: u8 = 0b0000_0100;
1618
const RANGE_LOWER_INCLUSIVE: u8 = 0b0000_0010;
1719
const RANGE_EMPTY: u8 = 0b0000_0001;
1820

21+
const PGSQL_AF_INET: u8 = 2;
22+
const PGSQL_AF_INET6: u8 = 3;
23+
1924
/// Serializes a `BOOL` value.
2025
#[inline]
2126
pub fn bool_to_sql(v: bool, buf: &mut Vec<u8>) {
@@ -956,3 +961,86 @@ impl<'a> FallibleIterator for PathPoints<'a> {
956961
(len, Some(len))
957962
}
958963
}
964+
965+
/// Serializes a Postgres inet.
966+
#[inline]
967+
pub fn inet_to_sql(addr: IpAddr, netmask: u8, buf: &mut Vec<u8>) {
968+
let family = match addr {
969+
IpAddr::V4(_) => PGSQL_AF_INET,
970+
IpAddr::V6(_) => PGSQL_AF_INET6,
971+
};
972+
buf.push(family);
973+
buf.push(netmask);
974+
buf.push(0); // is_cidr
975+
match addr {
976+
IpAddr::V4(addr) => {
977+
buf.push(4);
978+
buf.extend_from_slice(&addr.octets());
979+
}
980+
IpAddr::V6(addr) => {
981+
buf.push(16);
982+
buf.extend_from_slice(&addr.octets());
983+
}
984+
}
985+
}
986+
987+
/// Deserializes a Postgres inet.
988+
#[inline]
989+
pub fn inet_from_sql(mut buf: &[u8]) -> Result<Inet, StdBox<dyn Error + Sync + Send>> {
990+
let family = buf.read_u8()?;
991+
let netmask = buf.read_u8()?;
992+
buf.read_u8()?; // is_cidr
993+
let len = buf.read_u8()?;
994+
995+
let addr = match family {
996+
PGSQL_AF_INET => {
997+
if netmask > 32 {
998+
return Err("invalid IPv4 netmask".into());
999+
}
1000+
if len != 4 {
1001+
return Err("invalid IPv4 address length".into());
1002+
}
1003+
let mut addr = [0; 4];
1004+
buf.read_exact(&mut addr)?;
1005+
IpAddr::V4(Ipv4Addr::from(addr))
1006+
}
1007+
PGSQL_AF_INET6 => {
1008+
if netmask > 128 {
1009+
return Err("invalid IPv6 netmask".into());
1010+
}
1011+
if len != 16 {
1012+
return Err("invalid IPv6 address length".into());
1013+
}
1014+
let mut addr = [0; 16];
1015+
buf.read_exact(&mut addr)?;
1016+
IpAddr::V6(Ipv6Addr::from(addr))
1017+
}
1018+
_ => return Err("invalid IP family".into()),
1019+
};
1020+
1021+
if !buf.is_empty() {
1022+
return Err("invalid buffer size".into());
1023+
}
1024+
1025+
Ok(Inet { addr, netmask })
1026+
}
1027+
1028+
/// A Postgres network address.
1029+
pub struct Inet {
1030+
addr: IpAddr,
1031+
netmask: u8,
1032+
}
1033+
1034+
impl Inet {
1035+
/// Returns the IP address.
1036+
#[inline]
1037+
pub fn addr(&self) -> IpAddr {
1038+
self.addr
1039+
}
1040+
1041+
/// Returns the netmask.
1042+
#[inline]
1043+
pub fn netmask(&self) -> u8 {
1044+
self.netmask
1045+
}
1046+
}

tokio-postgres/src/types/mod.rs

+27
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::collections::HashMap;
88
use std::error::Error;
99
use std::fmt;
1010
use std::hash::BuildHasher;
11+
use std::net::IpAddr;
1112
use std::sync::Arc;
1213
use std::time::{Duration, SystemTime, UNIX_EPOCH};
1314

@@ -248,6 +249,7 @@ impl WrongType {
248249
/// | `&[u8]`/`Vec<u8>` | BYTEA |
249250
/// | `HashMap<String, Option<String>>` | HSTORE |
250251
/// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE |
252+
/// | `IpAddr` | INET |
251253
///
252254
/// In addition, some implementations are provided for types in third party
253255
/// crates. These are disabled by default; to opt into one of these
@@ -469,6 +471,15 @@ impl<'a> FromSql<'a> for SystemTime {
469471
accepts!(TIMESTAMP, TIMESTAMPTZ);
470472
}
471473

474+
impl<'a> FromSql<'a> for IpAddr {
475+
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<IpAddr, Box<dyn Error + Sync + Send>> {
476+
let inet = types::inet_from_sql(raw)?;
477+
Ok(inet.addr())
478+
}
479+
480+
accepts!(INET);
481+
}
482+
472483
/// An enum representing the nullability of a Postgres value.
473484
pub enum IsNull {
474485
/// The value is NULL.
@@ -498,6 +509,7 @@ pub enum IsNull {
498509
/// | `&[u8]`/Vec<u8>` | BYTEA |
499510
/// | `HashMap<String, Option<String>>` | HSTORE |
500511
/// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE |
512+
/// | `IpAddr` | INET |
501513
///
502514
/// In addition, some implementations are provided for types in third party
503515
/// crates. These are disabled by default; to opt into one of these
@@ -771,6 +783,21 @@ impl ToSql for SystemTime {
771783
to_sql_checked!();
772784
}
773785

786+
impl ToSql for IpAddr {
787+
fn to_sql(&self, _: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
788+
let netmask = match self {
789+
IpAddr::V4(_) => 32,
790+
IpAddr::V6(_) => 128,
791+
};
792+
types::inet_to_sql(*self, netmask, w);
793+
Ok(IsNull::No)
794+
}
795+
796+
accepts!(INET);
797+
798+
to_sql_checked!();
799+
}
800+
774801
fn downcast(len: usize) -> Result<i32, Box<dyn Error + Sync + Send>> {
775802
if len > i32::max_value() as usize {
776803
Err("value too large to transmit".into())

tokio-postgres/tests/test/types/mod.rs

+31
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::error::Error;
44
use std::f32;
55
use std::f64;
66
use std::fmt;
7+
use std::net::IpAddr;
78
use std::result;
89
use std::time::{Duration, UNIX_EPOCH};
910
use tokio::runtime::current_thread::Runtime;
@@ -624,3 +625,33 @@ fn system_time() {
624625
],
625626
);
626627
}
628+
629+
#[test]
630+
fn inet() {
631+
test_type(
632+
"INET",
633+
&[
634+
(Some("127.0.0.1".parse::<IpAddr>().unwrap()), "'127.0.0.1'"),
635+
(
636+
Some("127.0.0.1".parse::<IpAddr>().unwrap()),
637+
"'127.0.0.1/32'",
638+
),
639+
(
640+
Some(
641+
"2001:4f8:3:ba:2e0:81ff:fe22:d1f1"
642+
.parse::<IpAddr>()
643+
.unwrap(),
644+
),
645+
"'2001:4f8:3:ba:2e0:81ff:fe22:d1f1'",
646+
),
647+
(
648+
Some(
649+
"2001:4f8:3:ba:2e0:81ff:fe22:d1f1"
650+
.parse::<IpAddr>()
651+
.unwrap(),
652+
),
653+
"'2001:4f8:3:ba:2e0:81ff:fe22:d1f1/128'",
654+
),
655+
],
656+
);
657+
}

0 commit comments

Comments
 (0)