Skip to content

Commit 2c6d6e5

Browse files
author
Gleb Pomykalov
committed
Improve ergonomics of recvmmsg/sendmmsg
1 parent d908dd1 commit 2c6d6e5

File tree

2 files changed

+83
-36
lines changed

2 files changed

+83
-36
lines changed

src/sys/socket/mod.rs

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage],
772772
// because subsequent code will not clear the padding bytes.
773773
let mut cmsg_buffer = vec![0u8; capacity];
774774

775-
unsafe { send_pack_mhdr(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], iov, cmsgs, addr) };
775+
unsafe { send_pack_mhdr(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], &iov, &cmsgs, addr) };
776776

777777
let mhdr = unsafe { mhdr.assume_init() };
778778

@@ -783,10 +783,15 @@ pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage],
783783

784784
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
785785
#[derive(Debug)]
786-
pub struct SendMmsgData<'a> {
787-
pub iov: &'a [IoVec<&'a [u8]>],
788-
pub cmsgs: &'a [ControlMessage<'a>],
789-
pub addr: Option<&'a SockAddr>,
786+
pub struct SendMmsgData<'a, I, C>
787+
where
788+
I: AsRef<[IoVec<&'a [u8]>]>,
789+
C: AsRef<[ControlMessage<'a>]>
790+
{
791+
pub iov: I,
792+
pub cmsgs: C,
793+
pub addr: Option<SockAddr>,
794+
pub _lt: std::marker::PhantomData<&'a I>,
790795
}
791796

792797
/// An extension of `sendmsg`` that allows the caller to transmit multiple
@@ -804,7 +809,11 @@ pub struct SendMmsgData<'a> {
804809
/// # References
805810
/// [sendmmsg(2)](http://man7.org/linux/man-pages/man2/sendmmsg.2.html)
806811
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
807-
pub fn sendmmsg<'a>(fd: RawFd, data: impl std::iter::IntoIterator<Item=&'a SendMmsgData<'a>>, flags: MsgFlags) -> Result<(usize, Vec<usize>)>
812+
pub fn sendmmsg<'a, I, C>(fd: RawFd, data: impl std::iter::IntoIterator<Item=&'a SendMmsgData<'a, I, C>>, flags: MsgFlags)
813+
-> Result<(usize, Vec<usize>)>
814+
where
815+
I: AsRef<[IoVec<&'a [u8]>]> + 'a,
816+
C: AsRef<[ControlMessage<'a>]> + 'a,
808817
{
809818
let iter = data.into_iter();
810819

@@ -823,11 +832,19 @@ pub fn sendmmsg<'a>(fd: RawFd, data: impl std::iter::IntoIterator<Item=&'a SendM
823832
let element = &mut output[i];
824833

825834
let cmsgs_start = cmsgs_buffer.len();
826-
let cmsgs_required_capacity: usize = d.cmsgs.iter().map(|c| c.space()).sum();
835+
let cmsgs_required_capacity: usize = d.cmsgs.as_ref().iter().map(|c| c.space()).sum();
827836
let cmsgs_buffer_need_capacity = cmsgs_start + cmsgs_required_capacity;
828837
cmsgs_buffer.resize(cmsgs_buffer_need_capacity, 0);
829838

830-
unsafe { send_pack_mhdr(&mut (*element.as_mut_ptr()).msg_hdr, &mut cmsgs_buffer[cmsgs_start..], d.iov, d.cmsgs, d.addr) };
839+
unsafe {
840+
send_pack_mhdr(
841+
&mut (*element.as_mut_ptr()).msg_hdr,
842+
&mut cmsgs_buffer[cmsgs_start..],
843+
&d.iov,
844+
&d.cmsgs,
845+
d.addr.as_ref()
846+
)
847+
};
831848
});
832849

833850
let mut initialized_data = unsafe { mem::transmute::<_, Vec<libc::mmsghdr>>(output) };
@@ -848,8 +865,11 @@ pub fn sendmmsg<'a>(fd: RawFd, data: impl std::iter::IntoIterator<Item=&'a SendM
848865

849866
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
850867
#[derive(Debug)]
851-
pub struct RecvMmsgData<'a> {
852-
pub iov: &'a [IoVec<&'a mut [u8]>],
868+
pub struct RecvMmsgData<'a, I>
869+
where
870+
I: AsRef<[IoVec<&'a mut [u8]>]> + 'a,
871+
{
872+
pub iov: I,
853873
pub cmsg_buffer: Option<&'a mut Vec<u8>>,
854874
}
855875

@@ -876,9 +896,13 @@ pub struct RecvMmsgData<'a> {
876896
/// # References
877897
/// [recvmmsg(2)](http://man7.org/linux/man-pages/man2/recvmmsg.2.html)
878898
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
879-
pub fn recvmmsg<'a>(fd: RawFd,
880-
data: impl std::iter::IntoIterator<Item=&'a mut RecvMmsgData<'a>>,
881-
flags: MsgFlags, timeout: Option<crate::sys::time::TimeSpec>) -> Result<Vec<RecvMsg<'a>>>
899+
pub fn recvmmsg<'a, I>(
900+
fd: RawFd,
901+
data: impl std::iter::IntoIterator<Item=&'a mut RecvMmsgData<'a, I>>,
902+
flags: MsgFlags, timeout: Option<crate::sys::time::TimeSpec>
903+
) -> Result<Vec<RecvMsg<'a>>>
904+
where
905+
I: AsRef<[IoVec<&'a mut [u8]>]> + 'a,
882906
{
883907
let iter = data.into_iter();
884908

@@ -899,7 +923,7 @@ pub fn recvmmsg<'a>(fd: RawFd,
899923
let msg_controllen = unsafe {
900924
recv_pack_mhdr(
901925
&mut (*element.as_mut_ptr()).msg_hdr,
902-
d.iov,
926+
d.iov.as_ref(),
903927
&mut d.cmsg_buffer,
904928
&mut address[i]
905929
)
@@ -974,15 +998,23 @@ unsafe fn recv_read_mhdr<'a, 'b>(
974998
}
975999
}
9761000

977-
unsafe fn recv_pack_mhdr(out: *mut msghdr, iov: &[IoVec<&mut [u8]>], cmsg_buffer: &mut Option<&mut Vec<u8>>, address: &mut mem::MaybeUninit<sockaddr_storage>) -> usize {
1001+
unsafe fn recv_pack_mhdr<'a, I>(
1002+
out: *mut msghdr,
1003+
iov: I,
1004+
cmsg_buffer: &mut Option<&mut Vec<u8>>,
1005+
address: &mut mem::MaybeUninit<sockaddr_storage>
1006+
) -> usize
1007+
where
1008+
I: AsRef<[IoVec<&'a mut [u8]>]> + 'a,
1009+
{
9781010
let (msg_control, msg_controllen) = cmsg_buffer.as_mut()
9791011
.map(|v| (v.as_mut_ptr(), v.capacity()))
9801012
.unwrap_or((ptr::null_mut(), 0));
9811013

9821014
(*out).msg_name = address.as_mut_ptr() as *mut c_void;
9831015
(*out).msg_namelen = mem::size_of::<sockaddr_storage>() as socklen_t;
984-
(*out).msg_iov = iov.as_ptr() as *mut iovec;
985-
(*out).msg_iovlen = iov.len() as _;
1016+
(*out).msg_iov = iov.as_ref().as_ptr() as *mut iovec;
1017+
(*out).msg_iovlen = iov.as_ref().len() as _;
9861018
(*out).msg_control = msg_control as *mut c_void;
9871019
(*out).msg_controllen = msg_controllen as _;
9881020
(*out).msg_flags = 0;
@@ -991,7 +1023,17 @@ unsafe fn recv_pack_mhdr(out: *mut msghdr, iov: &[IoVec<&mut [u8]>], cmsg_buffer
9911023
}
9921024

9931025

994-
unsafe fn send_pack_mhdr(out: *mut msghdr, cmsg_buffer: &mut [u8], iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage], addr: Option<&SockAddr>) {
1026+
unsafe fn send_pack_mhdr<'a, I, C>(
1027+
out: *mut msghdr,
1028+
cmsg_buffer: &mut [u8],
1029+
iov: I,
1030+
cmsgs: C,
1031+
addr: Option<&SockAddr>
1032+
)
1033+
where
1034+
I: AsRef<[IoVec<&'a [u8]>]>,
1035+
C: AsRef<[ControlMessage<'a>]>
1036+
{
9951037
let cmsg_capacity = cmsg_buffer.len();
9961038

9971039
// Next encode the sending address, if provided
@@ -1016,8 +1058,8 @@ unsafe fn send_pack_mhdr(out: *mut msghdr, cmsg_buffer: &mut [u8], iov: &[IoVec<
10161058
(*out).msg_namelen = namelen;
10171059
// transmute iov into a mutable pointer. sendmsg doesn't really mutate
10181060
// the buffer, but the standard says that it takes a mutable pointer
1019-
(*out).msg_iov = iov.as_ptr() as *mut _;
1020-
(*out).msg_iovlen = iov.len() as _;
1061+
(*out).msg_iov = iov.as_ref().as_ptr() as *mut _;
1062+
(*out).msg_iovlen = iov.as_ref().len() as _;
10211063
(*out).msg_control = cmsg_ptr;
10221064
(*out).msg_controllen = cmsg_capacity as _;
10231065
(*out).msg_flags = 0;
@@ -1026,7 +1068,7 @@ unsafe fn send_pack_mhdr(out: *mut msghdr, cmsg_buffer: &mut [u8], iov: &[IoVec<
10261068
// CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields.
10271069
// CMSG_FIRSTHDR is always safe
10281070
let mut pmhdr: *mut cmsghdr = CMSG_FIRSTHDR(out);
1029-
for cmsg in cmsgs {
1071+
for cmsg in cmsgs.as_ref() {
10301072
assert_ne!(pmhdr, ptr::null_mut());
10311073
// Safe because we know that pmhdr is valid, and we initialized it with
10321074
// sufficient space
@@ -1057,7 +1099,9 @@ pub fn recvmsg<'a>(fd: RawFd, iov: &'a [IoVec<&'a mut [u8]>],
10571099
let mut out = mem::MaybeUninit::<msghdr>::zeroed();
10581100
let mut address = mem::MaybeUninit::uninit();
10591101

1060-
let msg_controllen = unsafe { recv_pack_mhdr(out.as_mut_ptr(), iov, &mut cmsg_buffer, &mut address) };
1102+
let msg_controllen = unsafe {
1103+
recv_pack_mhdr(out.as_mut_ptr(), &iov, &mut cmsg_buffer, &mut address)
1104+
};
10611105

10621106
let mut mhdr = unsafe { out.assume_init() };
10631107

test/sys/test_socket.rs

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -254,30 +254,33 @@ mod recvfrom {
254254

255255
let from = sendrecv(rsock, ssock, move |s, m, flags| {
256256
let iov = [IoVec::from_slice(m)];
257-
let mut msgs = std::collections::LinkedList::new();
258-
msgs.push_back(
257+
let mut msgs = Vec::new();
258+
msgs.push(
259259
SendMmsgData {
260260
iov: &iov,
261261
cmsgs: &[],
262-
addr: Some(&sock_addr),
262+
addr: Some(sock_addr),
263+
_lt: Default::default(),
263264
});
264265
for _ in 0..15 {
265-
msgs.push_back(
266+
msgs.push(
266267
SendMmsgData {
267268
iov: &iov,
268269
cmsgs: &[],
269-
addr: Some(&sock_addr2),
270+
addr: Some(sock_addr2),
271+
_lt: Default::default(),
270272
}
271273
);
272274
}
273-
sendmmsg(s, &msgs, flags).map(move |(sent_messages, sent_bytes)| {
274-
assert!(sent_messages >= 1);
275-
assert_eq!(sent_bytes.len(), sent_messages);
276-
for sent in &sent_bytes {
277-
assert_eq!(*sent, m.len());
278-
}
279-
sent_messages
280-
})
275+
sendmmsg(s, msgs.iter(), flags)
276+
.map(move |(sent_messages, sent_bytes)| {
277+
assert!(sent_messages >= 1);
278+
assert_eq!(sent_bytes.len(), sent_messages);
279+
for sent in &sent_bytes {
280+
assert_eq!(*sent, m.len());
281+
}
282+
sent_messages
283+
})
281284
});
282285
// UDP sockets should set the from address
283286
assert_eq!(AddressFamily::Inet, from.unwrap().family());
@@ -332,7 +335,7 @@ mod recvfrom {
332335
let res = recvmmsg(rsock, &mut msgs, MsgFlags::empty(), None).expect("recvmmsg");
333336
assert_eq!(res.len(), 2);
334337

335-
for RecvMsg { address, bytes, ..} in res.into_iter() {
338+
for RecvMsg { address, bytes, .. } in res.into_iter() {
336339
assert_eq!(AddressFamily::Inet, address.unwrap().family());
337340
assert_eq!(2, bytes);
338341
}

0 commit comments

Comments
 (0)