diff --git a/.travis.yml b/.travis.yml index 52338e473d..8618c4ddc1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ sudo: false language: rust rust: - nightly - - 1.0.0 + - 1.1.0 os: - linux @@ -37,7 +37,7 @@ deploy: upload-dir: nix/${TRAVIS_BRANCH}/${TRAVIS_OS_NAME} acl: public_read on: - condition: "\"$TRAVIS_RUST_VERSION/$ARCH\" == \"1.0.0/x86_64\"" + condition: "\"$TRAVIS_RUST_VERSION/$ARCH\" == \"1.1.0/x86_64\"" repo: carllerche/nix-rust branch: - master diff --git a/nix-test/src/const.c b/nix-test/src/const.c index a0a6a11678..604476290d 100644 --- a/nix-test/src/const.c +++ b/nix-test/src/const.c @@ -289,6 +289,9 @@ get_int_const(const char* err) { GET_CONST(MSG_OOB); GET_CONST(MSG_PEEK); GET_CONST(MSG_DONTWAIT); + GET_CONST(MSG_EOR); + GET_CONST(MSG_TRUNC); + GET_CONST(MSG_CTRUNC); GET_CONST(SHUT_RD); GET_CONST(SHUT_WR); GET_CONST(SHUT_RDWR); @@ -312,6 +315,7 @@ get_int_const(const char* err) { // GET_CONST(SO_PEEK_OFF); GET_CONST(SO_PEERCRED); GET_CONST(SO_SNDBUFFORCE); + GET_CONST(MSG_ERRQUEUE); #endif return -1; diff --git a/src/sys/socket/consts.rs b/src/sys/socket/consts.rs index 4deb9f59c2..35b071657a 100644 --- a/src/sys/socket/consts.rs +++ b/src/sys/socket/consts.rs @@ -84,16 +84,23 @@ mod os { pub const INADDR_NONE: InAddrT = 0xffffffff; pub const INADDR_BROADCAST: InAddrT = 0xffffffff; - pub type SockMessageFlags = i32; + pub type SockMessageFlags = c_int; // Flags for send/recv and their relatives pub const MSG_OOB: SockMessageFlags = 0x1; pub const MSG_PEEK: SockMessageFlags = 0x2; + pub const MSG_CTRUNC: SockMessageFlags = 0x08; + pub const MSG_TRUNC: SockMessageFlags = 0x20; pub const MSG_DONTWAIT: SockMessageFlags = 0x40; + pub const MSG_EOR: SockMessageFlags = 0x80; + pub const MSG_ERRQUEUE: SockMessageFlags = 0x2000; // shutdown flags pub const SHUT_RD: c_int = 0; pub const SHUT_WR: c_int = 1; pub const SHUT_RDWR: c_int = 2; + + // Ancillary message types + pub const SCM_RIGHTS: c_int = 1; } // Not all of these constants exist on freebsd @@ -197,12 +204,18 @@ mod os { // Flags for send/recv and their relatives pub const MSG_OOB: SockMessageFlags = 0x1; pub const MSG_PEEK: SockMessageFlags = 0x2; + pub const MSG_EOR: SockMessageFlags = 0x8; + pub const MSG_TRUNC: SockMessageFlags = 0x10; + pub const MSG_CTRUNC: SockMessageFlags = 0x20; pub const MSG_DONTWAIT: SockMessageFlags = 0x80; // shutdown flags pub const SHUT_RD: c_int = 0; pub const SHUT_WR: c_int = 1; pub const SHUT_RDWR: c_int = 2; + + // Ancillary message types + pub const SCM_RIGHTS: c_int = 1; } #[cfg(target_os = "dragonfly")] @@ -340,6 +353,9 @@ mod test { MSG_OOB, MSG_PEEK, MSG_DONTWAIT, + MSG_EOR, + MSG_TRUNC, + MSG_CTRUNC, SHUT_RD, SHUT_WR, SHUT_RDWR @@ -370,6 +386,7 @@ mod test { SO_RCVBUFFORCE, // SO_PEEK_OFF, SO_PEERCRED, - SO_SNDBUFFORCE); + SO_SNDBUFFORCE, + MSG_ERRQUEUE); } } diff --git a/src/sys/socket/ffi.rs b/src/sys/socket/ffi.rs index 11fd1ff1c3..1351071d10 100644 --- a/src/sys/socket/ffi.rs +++ b/src/sys/socket/ffi.rs @@ -1,5 +1,9 @@ -use libc::{c_int, c_void, socklen_t}; +// Silence invalid warnings due to rust-lang/rust#16719 +#![allow(improper_ctypes)] + +use libc::{c_int, c_void, socklen_t, ssize_t}; pub use libc::{socket, listen, bind, accept, connect, setsockopt, sendto, recvfrom, getsockname, getpeername, recv, send}; +use super::msghdr; extern { pub fn getsockopt( @@ -15,4 +19,7 @@ extern { protocol: c_int, sv: *mut c_int ) -> c_int; + + pub fn sendmsg(sockfd: c_int, msg: *const msghdr, flags: c_int) -> ssize_t; + pub fn recvmsg(sockfd: c_int, msg: *mut msghdr, flags: c_int) -> ssize_t; } diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 21a93060d2..37081a1f10 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -7,8 +7,9 @@ use features; use fcntl::{fcntl, FD_CLOEXEC, O_NONBLOCK}; use fcntl::FcntlArg::{F_SETFD, F_SETFL}; use libc::{c_void, c_int, socklen_t, size_t}; -use std::{mem, ptr}; +use std::{mem, ptr, slice}; use std::os::unix::io::RawFd; +use sys::uio::IoVec; mod addr; mod consts; @@ -76,6 +77,291 @@ bitflags!( } ); +/// Copy the in-memory representation of src into the byte slice dst, +/// updating the slice to point to the remainder of dst only. Unsafe +/// because it exposes all bytes in src, which may be UB if some of them +/// are uninitialized (including padding). +unsafe fn copy_bytes<'a, 'b, T: ?Sized>(src: &T, dst: &'a mut &'b mut [u8]) { + let srclen = mem::size_of_val(src); + let mut tmpdst = &mut [][..]; + mem::swap(&mut tmpdst, dst); + let (target, mut remainder) = tmpdst.split_at_mut(srclen); + // Safe because the mutable borrow of dst guarantees that src does not alias it. + ptr::copy_nonoverlapping(src as *const T as *const u8, target.as_mut_ptr(), srclen); + mem::swap(dst, &mut remainder); +} + +// Private because we don't expose any external functions that operate +// directly on this type; we just use it internally at FFI boundaries. +// Note that in some cases we store pointers in *const fields that the +// kernel will proceed to mutate, so users should be careful about the +// actual mutability of data pointed to by this structure. +#[repr(C)] +struct msghdr<'a> { + msg_name: *const c_void, + msg_namelen: socklen_t, + msg_iov: *const IoVec<&'a [u8]>, + msg_iovlen: size_t, + msg_control: *const c_void, + msg_controllen: size_t, + msg_flags: c_int, +} + +#[cfg(target_os = "linux")] +type type_of_cmsg_len = size_t; +#[cfg(not(target_os = "linux"))] +type type_of_cmsg_len = socklen_t; + +// As above, private because we don't expose any external functions that +// operate directly on this type, or any external types with a public +// cmsghdr member. +#[repr(C)] +struct cmsghdr { + pub cmsg_len: type_of_cmsg_len, + pub cmsg_level: c_int, + pub cmsg_type: c_int, + cmsg_data: [size_t; 0] +} + +/// A structure used to make room in a cmsghdr passed to recvmsg. The +/// size and alignment match that of a cmsghdr followed by a T, but the +/// fields are not accessible, as the actual types will change on a call +/// to recvmsg. +/// +/// To make room for multiple messages, nest the type parameter with +/// tuples, e.g. +/// `let cmsg: CmsgSpace<([RawFd; 3], CmsgSpace<[RawFd; 2]>)> = CmsgSpace::new();` +pub struct CmsgSpace { + _hdr: cmsghdr, + _data: T, +} + +impl CmsgSpace { + /// Create a CmsgSpace. The structure is used only for space, so + /// the fields are uninitialized. + pub fn new() -> Self { + // Safe because the fields themselves aren't accessible. + unsafe { mem::uninitialized() } + } +} + +pub struct RecvMsg<'a> { + // The number of bytes received. + pub bytes: usize, + cmsg_buffer: &'a [u8], + pub address: Option, + pub flags: SockMessageFlags, +} + +impl<'a> RecvMsg<'a> { + /// Iterate over the valid control messages pointed to by this + /// msghdr. + pub fn cmsgs(&self) -> CmsgIterator { + CmsgIterator(self.cmsg_buffer) + } +} + +pub struct CmsgIterator<'a>(&'a [u8]); + +impl<'a> Iterator for CmsgIterator<'a> { + type Item = ControlMessage<'a>; + + // The implementation loosely follows CMSG_FIRSTHDR / CMSG_NXTHDR, + // although we handle the invariants in slightly different places to + // get a better iterator interface. + fn next(&mut self) -> Option> { + let buf = self.0; + let sizeof_cmsghdr = mem::size_of::(); + if buf.len() < sizeof_cmsghdr { + return None; + } + let cmsg: &cmsghdr = unsafe { mem::transmute(buf.as_ptr()) }; + + // This check is only in the glibc implementation of CMSG_NXTHDR + // (although it claims the kernel header checks this), but such + // a structure is clearly invalid, either way. + let cmsg_len = cmsg.cmsg_len as usize; + if cmsg_len < sizeof_cmsghdr { + return None; + } + let len = cmsg_len - sizeof_cmsghdr; + + // Advance our internal pointer. + if cmsg_align(cmsg_len) > buf.len() { + return None; + } + self.0 = &buf[cmsg_align(cmsg_len)..]; + + match (cmsg.cmsg_level, cmsg.cmsg_type) { + (SOL_SOCKET, SCM_RIGHTS) => unsafe { + Some(ControlMessage::ScmRights( + slice::from_raw_parts( + &cmsg.cmsg_data as *const _ as *const _, + len / mem::size_of::()))) + }, + (_, _) => unsafe { + Some(ControlMessage::Unknown(UnknownCmsg( + &cmsg, + slice::from_raw_parts( + &cmsg.cmsg_data as *const _ as *const _, + len)))) + } + } + } +} + +/// A type-safe wrapper around a single control message. More types may +/// be added to this enum; do not exhaustively pattern-match it. +/// [Further reading](http://man7.org/linux/man-pages/man3/cmsg.3.html) +pub enum ControlMessage<'a> { + /// A message of type SCM_RIGHTS, containing an array of file + /// descriptors passed between processes. See the description in the + /// "Ancillary messages" section of the + /// [unix(7) man page](http://man7.org/linux/man-pages/man7/unix.7.html). + ScmRights(&'a [RawFd]), + #[doc(hidden)] + Unknown(UnknownCmsg<'a>), +} + +// An opaque structure used to prevent cmsghdr from being a public type +#[doc(hidden)] +pub struct UnknownCmsg<'a>(&'a cmsghdr, &'a [u8]); + +fn cmsg_align(len: usize) -> usize { + let round_to = mem::size_of::(); + if len % round_to == 0 { + len + } else { + len + round_to - (len % round_to) + } +} + +impl<'a> ControlMessage<'a> { + /// The value of CMSG_SPACE on this message. + fn space(&self) -> usize { + cmsg_align(self.len()) + } + + /// The value of CMSG_LEN on this message. + fn len(&self) -> usize { + mem::size_of::() + match *self { + ControlMessage::ScmRights(fds) => { + mem::size_of_val(fds) + }, + ControlMessage::Unknown(UnknownCmsg(_, bytes)) => { + mem::size_of_val(bytes) + } + } + } + + // Unsafe: start and end of buffer must be size_t-aligned (that is, + // cmsg_align'd). Updates the provided slice; panics if the buffer + // is too small. + unsafe fn encode_into<'b>(&self, buf: &mut &'b mut [u8]) { + match *self { + ControlMessage::ScmRights(fds) => { + let cmsg = cmsghdr { + cmsg_len: self.len() as type_of_cmsg_len, + cmsg_level: SOL_SOCKET, + cmsg_type: SCM_RIGHTS, + cmsg_data: [], + }; + copy_bytes(&cmsg, buf); + copy_bytes(fds, buf); + }, + ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => { + copy_bytes(orig_cmsg, buf); + copy_bytes(bytes, buf); + } + } + } +} + + +/// Send data in scatter-gather vectors to a socket, possibly accompanied +/// by ancillary data. Optionally direct the message at the given address, +/// as with sendto. +/// +/// Allocates if cmsgs is nonempty. +pub fn sendmsg<'a>(fd: RawFd, iov: &[IoVec<&'a [u8]>], cmsgs: &[ControlMessage<'a>], flags: SockMessageFlags, addr: Option<&'a SockAddr>) -> Result { + let mut capacity = 0; + for cmsg in cmsgs { + capacity += cmsg.space(); + } + // Alignment hackery. Note that capacity is guaranteed to be a + // multiple of size_t. Note also that the resulting vector claims + // to have length == capacity, so it's presently uninitialized. + let mut cmsg_buffer = unsafe { + let mut vec = Vec::::with_capacity(capacity / mem::size_of::()); + let ptr = vec.as_mut_ptr(); + mem::forget(vec); + Vec::::from_raw_parts(ptr as *mut _, capacity, capacity) + }; + { + let mut ptr = &mut cmsg_buffer[..]; + for cmsg in cmsgs { + unsafe { cmsg.encode_into(&mut ptr) }; + } + } + + let (name, namelen) = match addr { + Some(addr) => { let (x, y) = unsafe { addr.as_ffi_pair() }; (x as *const _, y) } + None => (0 as *const _, 0), + }; + + let mhdr = msghdr { + msg_name: name as *const c_void, + msg_namelen: namelen, + msg_iov: iov.as_ptr(), + msg_iovlen: iov.len() as size_t, + msg_control: cmsg_buffer.as_ptr() as *const c_void, + msg_controllen: cmsg_buffer.len() as size_t, + msg_flags: 0, + }; + let ret = unsafe { ffi::sendmsg(fd, &mhdr, flags) }; + + if ret < 0 { + Err(Error::Sys(Errno::last())) + } else { + Ok(ret as usize) + } +} + +/// Receive message in scatter-gather vectors from a socket, and +/// optionally receive ancillary data into the provided buffer. +/// If no ancillary data is desired, use () as the type parameter. +pub fn recvmsg<'a, T>(fd: RawFd, iov: &[IoVec<&mut [u8]>], cmsg_buffer: Option<&'a mut CmsgSpace>, flags: SockMessageFlags) -> Result> { + let mut address: sockaddr_storage = unsafe { mem::uninitialized() }; + let (msg_control, msg_controllen) = match cmsg_buffer { + Some(cmsg_buffer) => (cmsg_buffer as *mut _, mem::size_of_val(cmsg_buffer)), + None => (0 as *mut _, 0), + }; + let mut mhdr = msghdr { + msg_name: &mut address as *const _ as *const c_void, + msg_namelen: mem::size_of::() as socklen_t, + msg_iov: iov.as_ptr() as *const IoVec<&[u8]>, // safe cast to add const-ness + msg_iovlen: iov.len() as size_t, + msg_control: msg_control as *const c_void, + msg_controllen: msg_controllen as size_t, + msg_flags: 0, + }; + let ret = unsafe { ffi::recvmsg(fd, &mut mhdr, flags) }; + + if ret < 0 { + Err(Error::Sys(Errno::last())) + } else { + Ok(unsafe { RecvMsg { + bytes: ret as usize, + cmsg_buffer: slice::from_raw_parts(mhdr.msg_control as *const u8, + mhdr.msg_controllen as usize), + address: sockaddr_storage_to_addr(&address, + mhdr.msg_namelen as usize).ok(), + flags: mhdr.msg_flags, + } }) + } +} + + /// Create an endpoint for communication /// /// [Further reading](http://man7.org/linux/man-pages/man2/socket.2.html) @@ -384,6 +670,10 @@ pub unsafe fn sockaddr_storage_to_addr( addr: &sockaddr_storage, len: usize) -> Result { + if len < mem::size_of_val(&addr.ss_family) { + return Err(Error::Sys(Errno::ENOTCONN)); + } + match addr.ss_family as c_int { consts::AF_INET => { assert!(len as usize == mem::size_of::()); diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index e8cb4bedea..7b95767e51 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -2,7 +2,7 @@ use nix::sys::socket::{InetAddr, UnixAddr, getsockname}; use std::{mem, net}; use std::path::Path; use std::str::FromStr; -use std::os::unix::io::AsRawFd; +use std::os::unix::io::{AsRawFd, RawFd}; use ports::localhost; #[test] @@ -63,3 +63,56 @@ pub fn test_socketpair() { assert_eq!(&buf[..], b"hello"); } + +#[test] +pub fn test_scm_rights() { + use nix::sys::uio::IoVec; + use nix::unistd::{pipe, read, write, close}; + use nix::sys::socket::{socketpair, sendmsg, recvmsg, + AddressFamily, SockType, SockFlag, + ControlMessage, CmsgSpace, + MSG_TRUNC, MSG_CTRUNC}; + + let (fd1, fd2) = socketpair(AddressFamily::Unix, SockType::Stream, 0, + SockFlag::empty()) + .unwrap(); + let (r, w) = pipe().unwrap(); + let mut received_r: Option = None; + + { + let iov = [IoVec::from_slice(b"hello")]; + let fds = [r]; + let cmsg = ControlMessage::ScmRights(&fds); + assert_eq!(sendmsg(fd1, &iov, &[cmsg], 0, None).unwrap(), 5); + close(r).unwrap(); + close(fd1).unwrap(); + } + + { + let mut buf = [0u8; 5]; + let iov = [IoVec::from_mut_slice(&mut buf[..])]; + let mut cmsgspace: CmsgSpace<[RawFd; 1]> = CmsgSpace::new(); + let msg = recvmsg(fd2, &iov, Some(&mut cmsgspace), 0).unwrap(); + + for cmsg in msg.cmsgs() { + if let ControlMessage::ScmRights(fd) = cmsg { + assert_eq!(received_r, None); + assert_eq!(fd.len(), 1); + received_r = Some(fd[0]); + } else { + panic!("unexpected cmsg"); + } + } + assert_eq!(msg.flags & (MSG_TRUNC | MSG_CTRUNC), 0); + close(fd2).unwrap(); + } + + let received_r = received_r.expect("Did not receive passed fd"); + // Ensure that the received file descriptor works + write(w, b"world").unwrap(); + let mut buf = [0u8; 5]; + read(received_r, &mut buf).unwrap(); + assert_eq!(&buf[..], b"world"); + close(received_r).unwrap(); + close(w).unwrap(); +}