Skip to content

Commit 93558f6

Browse files
author
Gleb Pomykalov
committed
Support sendmmsg/recvmmsg
1 parent b5ee610 commit 93558f6

File tree

3 files changed

+377
-74
lines changed

3 files changed

+377
-74
lines changed

src/sys/socket/mod.rs

Lines changed: 250 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub use libc::{
5050
// Needed by the cmsg_space macro
5151
#[doc(hidden)]
5252
pub use libc::{c_uint, CMSG_SPACE};
53+
use std::mem::MaybeUninit;
5354

5455
/// These constants are used to specify the communication semantics
5556
/// when creating a socket with [`socket()`](fn.socket.html)
@@ -763,61 +764,276 @@ impl<'a> ControlMessage<'a> {
763764
pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage],
764765
flags: MsgFlags, addr: Option<&SockAddr>) -> Result<usize>
765766
{
767+
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
768+
766769
let capacity = cmsgs.iter().map(|c| c.space()).sum();
767770

768771
// First size the buffer needed to hold the cmsgs. It must be zeroed,
769772
// because subsequent code will not clear the padding bytes.
770-
let cmsg_buffer = vec![0u8; capacity];
773+
let mut cmsg_buffer = vec![0u8; capacity];
774+
775+
unsafe { send_pack_mhdr(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], iov, cmsgs, addr) };
776+
777+
let mhdr = unsafe { mhdr.assume_init() };
778+
779+
let ret = unsafe { libc::sendmsg(fd, &mhdr, flags.bits()) };
780+
781+
Errno::result(ret).map(|r| r as usize)
782+
}
783+
784+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
785+
#[derive(Debug)]
786+
pub struct SendMmsgData<'a> {
787+
pub iov: &'a [IoVec<&'a [u8]>],
788+
pub cmsgs: &'a [ControlMessage<'a>],
789+
pub addr: Option<&'a SockAddr>,
790+
}
791+
792+
/// An extension of `sendmsg`` that allows the caller to transmit multiple
793+
/// messages on a socket using a single system call. (This has performance
794+
/// benefits for some applications.). Supported on Linux and FreeBSD
795+
///
796+
/// Allocations are performed for cmsgs and to build `msghdr` buffer
797+
///
798+
/// # Arguments
799+
///
800+
/// * `fd`: Socket file descriptor
801+
/// * `data`: Struct that implements `IntoIterator` with `SendMmsgData` items
802+
/// * `flags`: Optional flags passed directly to the operating system.
803+
///
804+
/// # References
805+
/// [sendmmsg(2)](http://man7.org/linux/man-pages/man2/sendmmsg.2.html)
806+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
807+
pub fn sendmmsg<'a>(fd: RawFd, data: impl std::iter::IntoIterator<Item=&'a SendMmsgData<'a>>, flags: MsgFlags) -> Result<(usize, Vec<usize>)>
808+
{
809+
let iter = data.into_iter();
810+
811+
let (min_size, max_size) = iter.size_hint();
812+
let reserve_items = max_size.unwrap_or(min_size);
813+
814+
let mut output: Vec<MaybeUninit<libc::mmsghdr>> = vec![MaybeUninit::zeroed(); reserve_items];
815+
816+
let mut cmsgs_buffer = vec![0u8; 0];
817+
818+
iter.enumerate().for_each(|(i, d)| {
819+
if output.len() < i {
820+
output.resize(i, MaybeUninit::zeroed());
821+
}
822+
823+
let element = &mut output[i];
824+
825+
let cmsgs_start = cmsgs_buffer.len();
826+
let cmsgs_required_capacity: usize = d.cmsgs.iter().map(|c| c.space()).sum();
827+
let cmsgs_buffer_need_capacity = cmsgs_start + cmsgs_required_capacity;
828+
cmsgs_buffer.resize(cmsgs_buffer_need_capacity, 0);
829+
830+
unsafe { send_pack_mhdr(&mut (*element.as_mut_ptr()).msg_hdr, &mut cmsgs_buffer[cmsgs_start..], d.iov, d.cmsgs, d.addr) };
831+
});
832+
833+
let mut initialized_data = unsafe { mem::transmute::<_, Vec<libc::mmsghdr>>(output) };
834+
835+
let ret = unsafe { libc::sendmmsg(fd, initialized_data.as_mut_ptr(), initialized_data.len() as u32, flags.bits()) };
836+
837+
let sent_messages = Errno::result(ret)? as usize;
838+
let mut sent_bytes = Vec::with_capacity(sent_messages);
839+
unsafe { sent_bytes.set_len(sent_messages) };
840+
841+
for i in 0..sent_messages {
842+
sent_bytes[i] = initialized_data[i].msg_len as usize;
843+
}
844+
845+
Ok((sent_messages, sent_bytes))
846+
}
847+
848+
849+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
850+
#[derive(Debug)]
851+
pub struct RecvMmsgData<'a> {
852+
pub iov: &'a [IoVec<&'a mut [u8]>],
853+
pub cmsg_buffer: Option<&'a mut Vec<u8>>,
854+
}
855+
856+
/// An extension of recvmsg(2) that allows the caller to receive multiple
857+
/// messages from a socket using a single system call. (This has
858+
/// performance benefits for some applications.)
859+
///
860+
/// `iov` and `cmsg_buffer` should be constucted similarly to recvmsg
861+
///
862+
/// Multiple allocations are performed
863+
///
864+
/// # Arguments
865+
///
866+
/// * `fd`: Socket file descriptor
867+
/// * `data`: Struct that implements `IntoIterator` with `RecvMmsgData` items
868+
/// * `flags`: Optional flags passed directly to the operating system.
869+
///
870+
/// # RecvMmsgData
871+
///
872+
/// * `iov`: Scatter-gather list of buffers to receive the message
873+
/// * `cmsg_buffer`: Space to receive ancillary data. Should be created by
874+
/// [`cmsg_space!`](macro.cmsg_space.html)
875+
///
876+
/// # References
877+
/// [recvmmsg(2)](http://man7.org/linux/man-pages/man2/recvmmsg.2.html)
878+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
879+
pub fn recvmmsg<'a>(fd: RawFd,
880+
data: impl std::iter::IntoIterator<Item=&'a mut RecvMmsgData<'a>>,
881+
flags: MsgFlags, timeout: Option<crate::sys::time::TimeSpec>) -> Result<Vec<RecvMsg<'a>>>
882+
{
883+
let iter = data.into_iter();
884+
885+
let (min_size, max_size) = iter.size_hint();
886+
let reserve_items = max_size.unwrap_or(min_size);
887+
888+
let mut output: Vec<MaybeUninit<libc::mmsghdr>> = vec![MaybeUninit::zeroed(); reserve_items];
889+
let mut address: Vec<MaybeUninit<sockaddr_storage>> = vec![MaybeUninit::uninit(); reserve_items];
890+
891+
let results: Vec<_> = iter.enumerate().map(|(i, d)| {
892+
if output.len() < i {
893+
output.resize(i, MaybeUninit::zeroed());
894+
address.resize(i, MaybeUninit::uninit());
895+
}
896+
897+
let element = &mut output[i];
898+
899+
let msg_controllen = unsafe {
900+
recv_pack_mhdr(
901+
&mut (*element.as_mut_ptr()).msg_hdr,
902+
d.iov,
903+
&mut d.cmsg_buffer,
904+
&mut address[i]
905+
)
906+
};
907+
908+
(msg_controllen as usize, &mut d.cmsg_buffer)
909+
}).collect();
910+
911+
let mut initialized_data = unsafe { mem::transmute::<_, Vec<libc::mmsghdr>>(output) };
912+
913+
let timeout = if let Some(mut t) = timeout {
914+
t.as_mut() as *mut libc::timespec
915+
} else {
916+
ptr::null_mut()
917+
};
918+
919+
let ret = unsafe { libc::recvmmsg(fd, initialized_data.as_mut_ptr(), initialized_data.len() as u32, flags.bits(), timeout) };
920+
921+
let r = Errno::result(ret)?;
922+
923+
Ok(initialized_data
924+
.into_iter()
925+
.zip(address.into_iter())
926+
.zip(results.into_iter())
927+
.map(|((mmsghdr, address), (msg_controllen, cmsg_buffer))| {
928+
unsafe {
929+
recv_read_mhdr(
930+
mmsghdr.msg_hdr,
931+
r as isize,
932+
msg_controllen,
933+
address,
934+
cmsg_buffer
935+
)
936+
}
937+
})
938+
.collect())
939+
}
940+
941+
unsafe fn recv_read_mhdr<'a, 'b>(
942+
mhdr: msghdr,
943+
r: isize,
944+
msg_controllen: usize,
945+
address: MaybeUninit<sockaddr_storage>,
946+
cmsg_buffer: &'a mut Option<&'b mut Vec<u8>>
947+
) -> RecvMsg<'b> {
948+
let cmsghdr = {
949+
if mhdr.msg_controllen > 0 {
950+
// got control message(s)
951+
cmsg_buffer
952+
.as_mut()
953+
.unwrap()
954+
.set_len(mhdr.msg_controllen as usize);
955+
debug_assert!(!mhdr.msg_control.is_null());
956+
debug_assert!(msg_controllen >= mhdr.msg_controllen as usize);
957+
CMSG_FIRSTHDR(&mhdr as *const msghdr)
958+
} else {
959+
ptr::null()
960+
}.as_ref()
961+
};
962+
963+
let address = sockaddr_storage_to_addr(
964+
&address.assume_init(),
965+
mhdr.msg_namelen as usize
966+
).ok();
967+
968+
RecvMsg {
969+
bytes: r as usize,
970+
cmsghdr,
971+
address,
972+
flags: MsgFlags::from_bits_truncate(mhdr.msg_flags),
973+
mhdr,
974+
}
975+
}
976+
977+
unsafe fn recv_pack_mhdr(out: *mut msghdr, iov: &[IoVec<&mut [u8]>], cmsg_buffer: &mut Option<&mut Vec<u8>>, address: &mut mem::MaybeUninit<sockaddr_storage>) -> usize {
978+
let (msg_control, msg_controllen) = cmsg_buffer.as_mut()
979+
.map(|v| (v.as_mut_ptr(), v.capacity()))
980+
.unwrap_or((ptr::null_mut(), 0));
981+
982+
(*out).msg_name = address.as_mut_ptr() as *mut c_void;
983+
(*out).msg_namelen = mem::size_of::<sockaddr_storage>() as socklen_t;
984+
(*out).msg_iov = iov.as_ptr() as *mut iovec;
985+
(*out).msg_iovlen = iov.len() as _;
986+
(*out).msg_control = msg_control as *mut c_void;
987+
(*out).msg_controllen = msg_controllen as _;
988+
(*out).msg_flags = 0;
989+
990+
msg_controllen
991+
}
992+
993+
994+
unsafe fn send_pack_mhdr(out: *mut msghdr, cmsg_buffer: &mut [u8], iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage], addr: Option<&SockAddr>) {
995+
let cmsg_capacity = cmsg_buffer.len();
771996

772997
// Next encode the sending address, if provided
773998
let (name, namelen) = match addr {
774999
Some(addr) => {
775-
let (x, y) = unsafe { addr.as_ffi_pair() };
1000+
let (x, y) = addr.as_ffi_pair();
7761001
(x as *const _, y)
7771002
},
7781003
None => (ptr::null(), 0),
7791004
};
7801005

7811006
// The message header must be initialized before the individual cmsgs.
782-
let cmsg_ptr = if capacity > 0 {
1007+
let cmsg_ptr = if cmsg_capacity > 0 {
7831008
cmsg_buffer.as_ptr() as *mut c_void
7841009
} else {
7851010
ptr::null_mut()
7861011
};
7871012

788-
let mhdr = unsafe {
789-
// Musl's msghdr has private fields, so this is the only way to
790-
// initialize it.
791-
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
792-
let p = mhdr.as_mut_ptr();
793-
(*p).msg_name = name as *mut _;
794-
(*p).msg_namelen = namelen;
795-
// transmute iov into a mutable pointer. sendmsg doesn't really mutate
796-
// the buffer, but the standard says that it takes a mutable pointer
797-
(*p).msg_iov = iov.as_ptr() as *mut _;
798-
(*p).msg_iovlen = iov.len() as _;
799-
(*p).msg_control = cmsg_ptr;
800-
(*p).msg_controllen = capacity as _;
801-
(*p).msg_flags = 0;
802-
mhdr.assume_init()
803-
};
1013+
// Musl's msghdr has private fields, so this is the only way to
1014+
// initialize it.
1015+
(*out).msg_name = name as *mut _;
1016+
(*out).msg_namelen = namelen;
1017+
// transmute iov into a mutable pointer. sendmsg doesn't really mutate
1018+
// the buffer, but the standard says that it takes a mutable pointer
1019+
(*out).msg_iov = iov.as_ptr() as *mut _;
1020+
(*out).msg_iovlen = iov.len() as _;
1021+
(*out).msg_control = cmsg_ptr;
1022+
(*out).msg_controllen = cmsg_capacity as _;
1023+
(*out).msg_flags = 0;
8041024

8051025
// Encode each cmsg. This must happen after initializing the header because
8061026
// CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields.
8071027
// CMSG_FIRSTHDR is always safe
808-
let mut pmhdr: *mut cmsghdr = unsafe{CMSG_FIRSTHDR(&mhdr as *const msghdr)};
1028+
let mut pmhdr: *mut cmsghdr = CMSG_FIRSTHDR(out);
8091029
for cmsg in cmsgs {
8101030
assert_ne!(pmhdr, ptr::null_mut());
8111031
// Safe because we know that pmhdr is valid, and we initialized it with
8121032
// sufficient space
813-
unsafe { cmsg.encode_into(pmhdr) };
1033+
cmsg.encode_into(pmhdr);
8141034
// Safe because mhdr is valid
815-
pmhdr = unsafe{CMSG_NXTHDR(&mhdr as *const msghdr, pmhdr)};
1035+
pmhdr = CMSG_NXTHDR(out, pmhdr);
8161036
}
817-
818-
let ret = unsafe { libc::sendmsg(fd, &mhdr, flags.bits()) };
819-
820-
Errno::result(ret).map(|r| r as usize)
8211037
}
8221038

8231039
/// Receive message in scatter-gather vectors from a socket, and
@@ -834,62 +1050,22 @@ pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage],
8341050
///
8351051
/// # References
8361052
/// [recvmsg(2)](http://pubs.opengroup.org/onlinepubs/9699919799/functions/recvmsg.html)
837-
pub fn recvmsg<'a>(fd: RawFd, iov: &[IoVec<&mut [u8]>],
1053+
pub fn recvmsg<'a>(fd: RawFd, iov: &'a [IoVec<&'a mut [u8]>],
8381054
mut cmsg_buffer: Option<&'a mut Vec<u8>>,
8391055
flags: MsgFlags) -> Result<RecvMsg<'a>>
8401056
{
1057+
let mut out = mem::MaybeUninit::<msghdr>::zeroed();
8411058
let mut address = mem::MaybeUninit::uninit();
842-
let (msg_control, msg_controllen) = cmsg_buffer.as_mut()
843-
.map(|v| (v.as_mut_ptr(), v.capacity()))
844-
.unwrap_or((ptr::null_mut(), 0));
845-
let mut mhdr = {
846-
unsafe {
847-
// Musl's msghdr has private fields, so this is the only way to
848-
// initialize it.
849-
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
850-
let p = mhdr.as_mut_ptr();
851-
(*p).msg_name = address.as_mut_ptr() as *mut c_void;
852-
(*p).msg_namelen = mem::size_of::<sockaddr_storage>() as socklen_t;
853-
(*p).msg_iov = iov.as_ptr() as *mut iovec;
854-
(*p).msg_iovlen = iov.len() as _;
855-
(*p).msg_control = msg_control as *mut c_void;
856-
(*p).msg_controllen = msg_controllen as _;
857-
(*p).msg_flags = 0;
858-
mhdr.assume_init()
859-
}
860-
};
1059+
1060+
let msg_controllen = unsafe { recv_pack_mhdr(out.as_mut_ptr(), iov, &mut cmsg_buffer, &mut address) };
1061+
1062+
let mut mhdr = unsafe { out.assume_init() };
8611063

8621064
let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) };
8631065

864-
Errno::result(ret).map(|r| {
865-
let cmsghdr = unsafe {
866-
if mhdr.msg_controllen > 0 {
867-
// got control message(s)
868-
cmsg_buffer
869-
.as_mut()
870-
.unwrap()
871-
.set_len(mhdr.msg_controllen as usize);
872-
debug_assert!(!mhdr.msg_control.is_null());
873-
debug_assert!(msg_controllen >= mhdr.msg_controllen as usize);
874-
CMSG_FIRSTHDR(&mhdr as *const msghdr)
875-
} else {
876-
ptr::null()
877-
}.as_ref()
878-
};
1066+
let r = Errno::result(ret)?;
8791067

880-
let address = unsafe {
881-
sockaddr_storage_to_addr(&address.assume_init(),
882-
mhdr.msg_namelen as usize
883-
).ok()
884-
};
885-
RecvMsg {
886-
bytes: r as usize,
887-
cmsghdr,
888-
address,
889-
flags: MsgFlags::from_bits_truncate(mhdr.msg_flags),
890-
mhdr,
891-
}
892-
})
1068+
Ok(unsafe { recv_read_mhdr(mhdr, r, msg_controllen, address, &mut cmsg_buffer) })
8931069
}
8941070

8951071

0 commit comments

Comments
 (0)