Skip to content

Commit ad39e0b

Browse files
committed
also support sendmmsg
renames: RecvMMsg -> MultHdrs RecvMMsgItems -> MultiResults Adding a lifetime reference to RecvMsg The name is not 100% correct now, it can be useful for both sending and receiving messages: to collect hardware sending timestamps you need to use control messages as well
1 parent a56d66f commit ad39e0b

File tree

2 files changed

+107
-101
lines changed

2 files changed

+107
-101
lines changed

src/sys/socket/mod.rs

Lines changed: 83 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -574,15 +574,20 @@ macro_rules! cmsg_space {
574574
}
575575

576576
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
577-
pub struct RecvMsg<'a, S> {
577+
/// Contains outcome of sending or receiving a message
578+
///
579+
/// Use [`cmsgs`][RecvMsg::cmsgs] to access all the control messages present, and
580+
/// [`iovs`][RecvMsg::iovs`] to access underlying io slices.
581+
pub struct RecvMsg<'a, 's, S> {
578582
pub bytes: usize,
579583
cmsghdr: Option<&'a cmsghdr>,
580584
pub address: Option<S>,
581585
pub flags: MsgFlags,
586+
iobufs: std::marker::PhantomData<& 's()>,
582587
mhdr: msghdr,
583588
}
584589

585-
impl<'a, S> RecvMsg<'a, S> {
590+
impl<'a, S> RecvMsg<'a, '_, S> {
586591
/// Iterate over the valid control messages pointed to by this
587592
/// msghdr.
588593
pub fn cmsgs(&self) -> CmsgIterator {
@@ -1411,24 +1416,6 @@ pub fn sendmsg<S>(fd: RawFd, iov: &[IoSlice<'_>], cmsgs: &[ControlMessage],
14111416
Errno::result(ret).map(|r| r as usize)
14121417
}
14131418

1414-
#[cfg(any(
1415-
target_os = "linux",
1416-
target_os = "android",
1417-
target_os = "freebsd",
1418-
target_os = "netbsd",
1419-
))]
1420-
#[derive(Debug)]
1421-
pub struct SendMmsgData<'a, I, C, S>
1422-
where
1423-
I: AsRef<[IoSlice<'a>]>,
1424-
C: AsRef<[ControlMessage<'a>]>,
1425-
S: SockaddrLike + 'a
1426-
{
1427-
pub iov: I,
1428-
pub cmsgs: C,
1429-
pub addr: Option<S>,
1430-
pub _lt: std::marker::PhantomData<&'a I>,
1431-
}
14321419

14331420
/// An extension of `sendmsg` that allows the caller to transmit multiple
14341421
/// messages on a socket using a single system call. This has performance
@@ -1453,51 +1440,66 @@ pub struct SendMmsgData<'a, I, C, S>
14531440
target_os = "freebsd",
14541441
target_os = "netbsd",
14551442
))]
1456-
pub fn sendmmsg<'a, I, C, S>(
1443+
pub fn sendmmsg<'a, XS, AS, C, I, S>(
14571444
fd: RawFd,
1458-
data: impl std::iter::IntoIterator<Item=&'a SendMmsgData<'a, I, C, S>>,
1445+
data: &'a mut MultHdrs<S>,
1446+
slices: XS,
1447+
// one address per group of slices
1448+
addrs: AS,
1449+
// shared across all the messages
1450+
cmsgs: C,
14591451
flags: MsgFlags
1460-
) -> Result<Vec<usize>>
1452+
) -> crate::Result<MultiResults<'a, S>>
14611453
where
1454+
XS: IntoIterator<Item = I>,
1455+
AS: AsRef<[Option<S>]>,
14621456
I: AsRef<[IoSlice<'a>]> + 'a,
14631457
C: AsRef<[ControlMessage<'a>]> + 'a,
14641458
S: SockaddrLike + 'a
14651459
{
1466-
let iter = data.into_iter();
14671460

1468-
let size_hint = iter.size_hint();
1469-
let reserve_items = size_hint.1.unwrap_or(size_hint.0);
1461+
let mut count = 0;
14701462

1471-
let mut output = Vec::<libc::mmsghdr>::with_capacity(reserve_items);
14721463

1473-
let mut cmsgs_buffers = Vec::<Vec<u8>>::with_capacity(reserve_items);
1474-
1475-
for d in iter {
1476-
let capacity: usize = d.cmsgs.as_ref().iter().map(|c| c.space()).sum();
1477-
let mut cmsgs_buffer = vec![0u8; capacity];
1464+
for (i, ((slice, addr), mmsghdr)) in slices.into_iter().zip(addrs.as_ref()).zip(data.items.iter_mut() ).enumerate() {
1465+
let mut p = &mut mmsghdr.msg_hdr;
1466+
p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec;
1467+
p.msg_iovlen = slice.as_ref().len() as _;
14781468

1479-
output.push(libc::mmsghdr {
1480-
msg_hdr: pack_mhdr_to_send(
1481-
&mut cmsgs_buffer,
1482-
&d.iov,
1483-
&d.cmsgs,
1484-
d.addr.as_ref()
1485-
),
1486-
msg_len: 0,
1487-
});
1488-
cmsgs_buffers.push(cmsgs_buffer);
1489-
};
1469+
(*p).msg_namelen = addr.as_ref().map_or(0, S::len);
1470+
(*p).msg_name = addr.as_ref().map_or(ptr::null(), S::as_ptr) as _;
1471+
1472+
// Encode each cmsg. This must happen after initializing the header because
1473+
// CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields.
1474+
// CMSG_FIRSTHDR is always safe
1475+
let mut pmhdr: *mut cmsghdr = unsafe { CMSG_FIRSTHDR(p) };
1476+
for cmsg in cmsgs.as_ref() {
1477+
assert_ne!(pmhdr, ptr::null_mut());
1478+
// Safe because we know that pmhdr is valid, and we initialized it with
1479+
// sufficient space
1480+
unsafe { cmsg.encode_into(pmhdr) };
1481+
// Safe because mhdr is valid
1482+
pmhdr = unsafe { CMSG_NXTHDR(p, pmhdr) };
1483+
}
14901484

1491-
let ret = unsafe { libc::sendmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _) };
1485+
count = i+1;
1486+
}
14921487

1493-
let sent_messages = Errno::result(ret)? as usize;
1494-
let mut sent_bytes = Vec::with_capacity(sent_messages);
1488+
let sent = Errno::result(unsafe {
1489+
libc::sendmmsg(
1490+
fd,
1491+
data.items.as_mut_ptr(),
1492+
count as _,
1493+
flags.bits() as _
1494+
)
1495+
})? as usize;
14951496

1496-
for item in &output {
1497-
sent_bytes.push(item.msg_len as usize);
1498-
}
1497+
Ok(MultiResults {
1498+
rmm: data,
1499+
current_index: 0,
1500+
received: sent
1501+
})
14991502

1500-
Ok(sent_bytes)
15011503
}
15021504

15031505

@@ -1508,8 +1510,8 @@ pub fn sendmmsg<'a, I, C, S>(
15081510
target_os = "netbsd",
15091511
))]
15101512
#[derive(Debug)]
1511-
/// Preallocated structures needed for [`recvmmsg`] function
1512-
pub struct RecvMMsg<S> {
1513+
/// Preallocated structures needed for [`recvmmsg`] and [`sendmmsg`] functions
1514+
pub struct MultHdrs<S> {
15131515
// preallocated boxed slice of mmsghdr
15141516
items: Box<[libc::mmsghdr]>,
15151517
addresses: Box<[mem::MaybeUninit<S>]>,
@@ -1526,8 +1528,8 @@ pub struct RecvMMsg<S> {
15261528
target_os = "freebsd",
15271529
target_os = "netbsd",
15281530
))]
1529-
impl<S> RecvMMsg<S> {
1530-
/// Preallocate structure used by [`recvmmsg`], takes number of headers to preallocate
1531+
impl<S> MultHdrs<S> {
1532+
/// Preallocate structure used by [`recvmmsg`] and [`sendmmsg`] takes number of headers to preallocate
15311533
///
15321534
/// `cmsg_buffer` should be created with [`cmsg_space!`] if needed
15331535
pub fn preallocate(num_slices: usize, cmsg_buffer: Option<Vec<u8>>) -> Self
@@ -1598,21 +1600,21 @@ impl<S> RecvMMsg<S> {
15981600
))]
15991601
pub fn recvmmsg<'a, XS, S, I>(
16001602
fd: RawFd,
1601-
data: &'a mut RecvMMsg<S>,
1603+
data: &'a mut MultHdrs<S>,
16021604
slices: XS,
16031605
flags: MsgFlags,
16041606
mut timeout: Option<crate::sys::time::TimeSpec>,
1605-
) -> crate::Result<RecvMMsgItems<'a, S>>
1607+
) -> crate::Result<MultiResults<'a, S>>
16061608
where
1607-
XS: ExactSizeIterator<Item = I>,
1609+
XS: IntoIterator<Item = I>,
16081610
I: AsRef<[IoSliceMut<'a>]>,
16091611
{
1610-
let count = std::cmp::min(slices.len(), data.items.len());
1611-
1612-
for (slice, mmsghdr) in slices.zip(data.items.iter_mut()) {
1612+
let mut count = 0;
1613+
for (i, (slice, mmsghdr)) in slices.into_iter().zip(data.items.iter_mut()).enumerate() {
16131614
let mut p = &mut mmsghdr.msg_hdr;
16141615
p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec;
16151616
p.msg_iovlen = slice.as_ref().len() as _;
1617+
count = i + 1;
16161618
}
16171619

16181620
let timeout_ptr = timeout
@@ -1629,7 +1631,7 @@ where
16291631
)
16301632
})? as usize;
16311633

1632-
Ok(RecvMMsgItems {
1634+
Ok(MultiResults {
16331635
rmm: data,
16341636
current_index: 0,
16351637
received,
@@ -1643,9 +1645,12 @@ where
16431645
target_os = "netbsd",
16441646
))]
16451647
#[derive(Debug)]
1646-
pub struct RecvMMsgItems<'a, S> {
1648+
/// Iterator over results of [`recvmmsg`]/[`sendmmsg`]
1649+
///
1650+
///
1651+
pub struct MultiResults<'a, S> {
16471652
// preallocated structures
1648-
rmm: &'a RecvMMsg<S>,
1653+
rmm: &'a MultHdrs<S>,
16491654
current_index: usize,
16501655
received: usize,
16511656
}
@@ -1656,11 +1661,11 @@ pub struct RecvMMsgItems<'a, S> {
16561661
target_os = "freebsd",
16571662
target_os = "netbsd",
16581663
))]
1659-
impl<'a, S> Iterator for RecvMMsgItems<'a, S>
1664+
impl<'a, S> Iterator for MultiResults<'a, S>
16601665
where
16611666
S: Copy + SockaddrLike,
16621667
{
1663-
type Item = RecvMsg<'a, S>;
1668+
type Item = RecvMsg<'a, 'a, S>;
16641669

16651670
fn next(&mut self) -> Option<Self::Item> {
16661671
if self.current_index >= self.received {
@@ -1684,13 +1689,17 @@ where
16841689
}
16851690
}
16861691

1687-
impl<'a, S> RecvMsg<'a, S> {
1692+
impl<'a, S> RecvMsg<'_, 'a, S> {
16881693
/// Iterate over the filled io slices pointed by this msghdr
1689-
pub fn iovs(&self) -> IoSliceIterator {
1694+
pub fn iovs(&self) -> IoSliceIterator<'a> {
16901695
IoSliceIterator {
16911696
index: 0,
16921697
remaining: self.bytes,
16931698
slices: unsafe {
1699+
// safe for as long as mgdr is properly initialized and references are valid.
1700+
// for multi messages API we initialize it with an empty
1701+
// slice and replace with a concrete buffer
1702+
// for single message API we hold a lifetime reference to ioslices
16941703
std::slice::from_raw_parts(self.mhdr.msg_iov as *const _, self.mhdr.msg_iovlen as _)
16951704
},
16961705
}
@@ -1782,7 +1791,7 @@ mod test {
17821791
let cmsg = cmsg_space!(crate::sys::socket::Timestamps);
17831792
sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap();
17841793

1785-
let mut data = super::RecvMMsg::<()>::preallocate(recv_iovs.len(), Some(cmsg));
1794+
let mut data = super::MultHdrs::<()>::preallocate(recv_iovs.len(), Some(cmsg));
17861795

17871796
let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10));
17881797

@@ -1817,12 +1826,12 @@ mod test {
18171826
Ok(())
18181827
}
18191828
}
1820-
unsafe fn read_mhdr<'a, S>(
1829+
unsafe fn read_mhdr<'a, 'i, S>(
18211830
mhdr: msghdr,
18221831
r: isize,
18231832
msg_controllen: usize,
18241833
address: S,
1825-
) -> RecvMsg<'a, S>
1834+
) -> RecvMsg<'a, 'i, S>
18261835
where S: SockaddrLike
18271836
{
18281837
let cmsghdr = {
@@ -1841,6 +1850,7 @@ unsafe fn read_mhdr<'a, S>(
18411850
address: Some(address),
18421851
flags: MsgFlags::from_bits_truncate(mhdr.msg_flags),
18431852
mhdr,
1853+
iobufs: std::marker::PhantomData,
18441854
}
18451855
}
18461856

@@ -1948,8 +1958,9 @@ fn pack_mhdr_to_send<'a, I, C, S>(
19481958
/// [recvmsg(2)](https://pubs.opengroup.org/onlinepubs/9699919799/functions/recvmsg.html)
19491959
pub fn recvmsg<'a, 'outer, 'inner, S>(fd: RawFd, iov: &'outer mut [IoSliceMut<'inner>],
19501960
mut cmsg_buffer: Option<&'a mut Vec<u8>>,
1951-
flags: MsgFlags) -> Result<RecvMsg<'a, S>>
1952-
where S: SockaddrLike + 'a
1961+
flags: MsgFlags) -> Result<RecvMsg<'a, 'inner, S>>
1962+
where S: SockaddrLike + 'a,
1963+
'inner: 'outer
19531964
{
19541965
let mut address = mem::MaybeUninit::uninit();
19551966

test/sys/test_socket.rs

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -446,36 +446,31 @@ mod recvfrom {
446446
).expect("send socket failed");
447447

448448
let from = sendrecv(rsock, ssock, move |s, m, flags| {
449-
let iov = [IoSlice::new(m)];
450-
let mut msgs = vec![
451-
SendMmsgData {
452-
iov: &iov,
453-
cmsgs: &[],
454-
addr: Some(sock_addr),
455-
_lt: Default::default(),
456-
}
457-
];
458-
459449
let batch_size = 15;
450+
let mut iovs = Vec::with_capacity(1 + batch_size);
451+
let mut addrs = Vec::with_capacity(1 + batch_size);
452+
let mut data = MultHdrs::preallocate(1 + batch_size, None);
453+
let iov = IoSlice::new(m);
454+
// first chunk:
455+
iovs.push([iov]);
456+
addrs.push(Some(sock_addr));
460457

461458
for _ in 0..batch_size {
462-
msgs.push(
463-
SendMmsgData {
464-
iov: &iov,
465-
cmsgs: &[],
466-
addr: Some(sock_addr2),
467-
_lt: Default::default(),
468-
}
469-
);
459+
iovs.push([iov]);
460+
addrs.push(Some(sock_addr2));
470461
}
471-
sendmmsg(s, msgs.iter(), flags)
472-
.map(move |sent_bytes| {
473-
assert!(!sent_bytes.is_empty());
474-
for sent in &sent_bytes {
475-
assert_eq!(*sent, m.len());
476-
}
477-
sent_bytes.len()
478-
})
462+
463+
let res = sendmmsg(s, &mut data, &iovs, addrs, &[], flags)?;
464+
let mut sent_messages = 0;
465+
let mut sent_bytes = 0;
466+
for item in res {
467+
sent_messages += 1;
468+
sent_bytes += item.bytes;
469+
}
470+
//
471+
assert_eq!(sent_messages, iovs.len());
472+
assert_eq!(sent_bytes, sent_messages * m.len());
473+
Ok(sent_messages)
479474
}, |_, _ | {});
480475
// UDP sockets should set the from address
481476
assert_eq!(AddressFamily::Inet, from.unwrap().family().unwrap());
@@ -524,7 +519,7 @@ mod recvfrom {
524519
msgs.extend(receive_buffers.iter_mut().map(|buf| {
525520
[IoSliceMut::new(&mut buf[..])]
526521
}));
527-
let mut data = RecvMMsg::<SockaddrIn>::preallocate(msgs.len(), None);
522+
let mut data = MultHdrs::<SockaddrIn>::preallocate(msgs.len(), None);
528523

529524
let res: Vec<RecvMsg<SockaddrIn>> = recvmmsg(rsock, &mut data, msgs.iter(), MsgFlags::empty(), None).expect("recvmmsg").collect();
530525
assert_eq!(res.len(), DATA.len());
@@ -590,7 +585,7 @@ mod recvfrom {
590585
[IoSliceMut::new(&mut buf[..])]
591586
}));
592587

593-
let mut data = RecvMMsg::<SockaddrIn>::preallocate(NUM_MESSAGES_SENT + 2, None);
588+
let mut data = MultHdrs::<SockaddrIn>::preallocate(NUM_MESSAGES_SENT + 2, None);
594589

595590
let res: Vec<RecvMsg<SockaddrIn>> = recvmmsg(rsock, &mut data, msgs.iter(), MsgFlags::MSG_DONTWAIT, None).expect("recvmmsg").collect();
596591
assert_eq!(res.len(), NUM_MESSAGES_SENT);
@@ -1701,7 +1696,7 @@ fn test_recvmmsg_timestampns() {
17011696
let mut buffer = vec![0u8; message.len()];
17021697
let cmsgspace = nix::cmsg_space!(TimeSpec);
17031698
let iov = vec![[IoSliceMut::new(&mut buffer)]];
1704-
let mut data = RecvMMsg::preallocate(1, Some(cmsgspace));
1699+
let mut data = MultHdrs::preallocate(1, Some(cmsgspace));
17051700
let r: Vec<RecvMsg<()>> = recvmmsg(in_socket, &mut data, iov.iter(), flags, None).unwrap().collect();
17061701
let rtime = match r[0].cmsgs().next() {
17071702
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,

0 commit comments

Comments
 (0)