Skip to content

Commit cbeb0d2

Browse files
author
Gleb Pomykalov
committed
Reduce unsafety
1 parent 101a405 commit cbeb0d2

File tree

1 file changed

+74
-73
lines changed

1 file changed

+74
-73
lines changed

src/sys/socket/mod.rs

Lines changed: 74 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -774,17 +774,13 @@ impl<'a> ControlMessage<'a> {
774774
pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage],
775775
flags: MsgFlags, addr: Option<&SockAddr>) -> Result<usize>
776776
{
777-
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
778-
779777
let capacity = cmsgs.iter().map(|c| c.space()).sum();
780778

781779
// First size the buffer needed to hold the cmsgs. It must be zeroed,
782780
// because subsequent code will not clear the padding bytes.
783781
let mut cmsg_buffer = vec![0u8; capacity];
784782

785-
pack_mhdr_to_send(mhdr.as_mut_ptr(), &mut cmsg_buffer[..], &iov, &cmsgs, addr);
786-
787-
let mhdr = unsafe { mhdr.assume_init() };
783+
let mhdr = pack_mhdr_to_send(&mut cmsg_buffer[..], &iov, &cmsgs, addr);
788784

789785
let ret = unsafe { libc::sendmsg(fd, &mhdr, flags.bits()) };
790786

@@ -836,8 +832,7 @@ pub struct SendMmsgData<'a, I, C>
836832
))]
837833
pub fn sendmmsg<'a, I, C>(
838834
fd: RawFd,
839-
data: impl std::iter::IntoIterator<Item=&'a SendMmsgData<'a, I, C>,
840-
IntoIter=impl ExactSizeIterator + Iterator<Item=&'a SendMmsgData<'a, I, C>>>,
835+
data: impl std::iter::IntoIterator<Item=&'a SendMmsgData<'a, I, C>>,
841836
flags: MsgFlags
842837
) -> Result<Vec<usize>>
843838
where
@@ -846,31 +841,29 @@ pub fn sendmmsg<'a, I, C>(
846841
{
847842
let iter = data.into_iter();
848843

849-
let num_messages = iter.len();
844+
let size_hint = iter.size_hint();
845+
let reserve_items = size_hint.1.unwrap_or(size_hint.0);
850846

851-
let mut output = Vec::<libc::mmsghdr>::with_capacity(num_messages);
852-
unsafe {
853-
output.set_len(num_messages);
854-
}
847+
let mut output = Vec::<libc::mmsghdr>::with_capacity(reserve_items);
855848

856849
let mut cmsgs_buffer = vec![0u8; 0];
857850

858-
iter.enumerate().for_each(|(i, d)| {
859-
let element = &mut output[i];
860-
851+
for d in iter {
861852
let cmsgs_start = cmsgs_buffer.len();
862853
let cmsgs_required_capacity: usize = d.cmsgs.as_ref().iter().map(|c| c.space()).sum();
863854
let cmsgs_buffer_need_capacity = cmsgs_start + cmsgs_required_capacity;
864855
cmsgs_buffer.resize(cmsgs_buffer_need_capacity, 0);
865856

866-
pack_mhdr_to_send(
867-
&mut element.msg_hdr,
868-
&mut cmsgs_buffer[cmsgs_start..],
869-
&d.iov,
870-
&d.cmsgs,
871-
d.addr.as_ref()
872-
);
873-
});
857+
output.push(libc::mmsghdr {
858+
msg_hdr: pack_mhdr_to_send(
859+
&mut cmsgs_buffer[cmsgs_start..],
860+
&d.iov,
861+
&d.cmsgs,
862+
d.addr.as_ref()
863+
),
864+
msg_len: 0,
865+
});
866+
};
874867

875868
let ret = unsafe { libc::sendmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _) };
876869

@@ -947,25 +940,27 @@ pub fn recvmmsg<'a, I>(
947940
let num_messages = iter.len();
948941

949942
let mut output: Vec<libc::mmsghdr> = Vec::with_capacity(num_messages);
950-
let mut address: Vec<sockaddr_storage> = Vec::with_capacity(num_messages);
951943

952-
unsafe {
953-
output.set_len(num_messages);
954-
address.set_len(num_messages);
955-
}
944+
// Addresses should be pre-allocated and never change the address during building
945+
// of the input data for `recvmmsg`
946+
let mut addresses: Vec<sockaddr_storage> = vec![unsafe { mem::zeroed() }; num_messages];
956947

957948
let results: Vec<_> = iter.enumerate().map(|(i, d)| {
958-
let element = &mut output[i];
959-
960-
let msg_controllen = unsafe {
949+
let (msg_controllen, mhdr) = unsafe {
961950
pack_mhdr_to_receive(
962-
&mut element.msg_hdr,
963951
d.iov.as_ref(),
964952
&mut d.cmsg_buffer,
965-
&mut address[i]
953+
&mut addresses[i],
966954
)
967955
};
968956

957+
output.push(
958+
libc::mmsghdr {
959+
msg_hdr: mhdr,
960+
msg_len: 0,
961+
}
962+
);
963+
969964
(msg_controllen as usize, &mut d.cmsg_buffer)
970965
}).collect();
971966

@@ -981,15 +976,15 @@ pub fn recvmmsg<'a, I>(
981976

982977
Ok(output
983978
.into_iter()
984-
.zip(address.into_iter())
979+
.zip(addresses.into_iter())
985980
.zip(results.into_iter())
986-
.map(|((mmsghdr, mut address), (msg_controllen, cmsg_buffer))| {
981+
.map(|((mmsghdr, address), (msg_controllen, cmsg_buffer))| {
987982
unsafe {
988983
read_mhdr(
989984
mmsghdr.msg_hdr,
990985
r as isize,
991986
msg_controllen,
992-
&mut address,
987+
address,
993988
cmsg_buffer
994989
)
995990
}
@@ -1001,7 +996,7 @@ unsafe fn read_mhdr<'a, 'b>(
1001996
mhdr: msghdr,
1002997
r: isize,
1003998
msg_controllen: usize,
1004-
address: *mut sockaddr_storage,
999+
address: sockaddr_storage,
10051000
cmsg_buffer: &'a mut Option<&'b mut Vec<u8>>
10061001
) -> RecvMsg<'b> {
10071002
let cmsghdr = {
@@ -1020,7 +1015,7 @@ unsafe fn read_mhdr<'a, 'b>(
10201015
};
10211016

10221017
let address = sockaddr_storage_to_addr(
1023-
&*address ,
1018+
&address ,
10241019
mhdr.msg_namelen as usize
10251020
).ok();
10261021

@@ -1034,42 +1029,46 @@ unsafe fn read_mhdr<'a, 'b>(
10341029
}
10351030

10361031
unsafe fn pack_mhdr_to_receive<'a, I>(
1037-
out: *mut msghdr,
10381032
iov: I,
10391033
cmsg_buffer: &mut Option<&mut Vec<u8>>,
10401034
address: *mut sockaddr_storage,
1041-
) -> usize
1035+
) -> (usize, msghdr)
10421036
where
10431037
I: AsRef<[IoVec<&'a mut [u8]>]> + 'a,
10441038
{
10451039
let (msg_control, msg_controllen) = cmsg_buffer.as_mut()
10461040
.map(|v| (v.as_mut_ptr(), v.capacity()))
10471041
.unwrap_or((ptr::null_mut(), 0));
10481042

1049-
(*out).msg_name = address as *mut c_void;
1050-
(*out).msg_namelen = mem::size_of::<sockaddr_storage>() as socklen_t;
1051-
(*out).msg_iov = iov.as_ref().as_ptr() as *mut iovec;
1052-
(*out).msg_iovlen = iov.as_ref().len() as _;
1053-
(*out).msg_control = msg_control as *mut c_void;
1054-
(*out).msg_controllen = msg_controllen as _;
1055-
(*out).msg_flags = 0;
1043+
let mhdr = {
1044+
// Musl's msghdr has private fields, so this is the only way to
1045+
// initialize it.
1046+
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
1047+
let p = mhdr.as_mut_ptr();
1048+
(*p).msg_name = address as *mut c_void;
1049+
(*p).msg_namelen = mem::size_of::<sockaddr_storage>() as socklen_t;
1050+
(*p).msg_iov = iov.as_ref().as_ptr() as *mut iovec;
1051+
(*p).msg_iovlen = iov.as_ref().len() as _;
1052+
(*p).msg_control = msg_control as *mut c_void;
1053+
(*p).msg_controllen = msg_controllen as _;
1054+
(*p).msg_flags = 0;
1055+
mhdr.assume_init()
1056+
};
10561057

1057-
msg_controllen
1058+
(msg_controllen, mhdr)
10581059
}
10591060

1060-
10611061
fn pack_mhdr_to_send<'a, I, C>(
1062-
out: *mut msghdr,
10631062
cmsg_buffer: &mut [u8],
10641063
iov: I,
10651064
cmsgs: C,
10661065
addr: Option<&SockAddr>
1067-
)
1066+
) -> msghdr
10681067
where
10691068
I: AsRef<[IoVec<&'a [u8]>]>,
10701069
C: AsRef<[ControlMessage<'a>]>
10711070
{
1072-
let cmsg_capacity = cmsg_buffer.len();
1071+
let capacity = cmsg_buffer.len();
10731072

10741073
// Next encode the sending address, if provided
10751074
let (name, namelen) = match addr {
@@ -1081,38 +1080,43 @@ fn pack_mhdr_to_send<'a, I, C>(
10811080
};
10821081

10831082
// The message header must be initialized before the individual cmsgs.
1084-
let cmsg_ptr = if cmsg_capacity > 0 {
1083+
let cmsg_ptr = if capacity > 0 {
10851084
cmsg_buffer.as_ptr() as *mut c_void
10861085
} else {
10871086
ptr::null_mut()
10881087
};
10891088

1090-
// Musl's msghdr has private fields, so this is the only way to
1091-
// initialize it.
1092-
unsafe {
1093-
(*out).msg_name = name as *mut _;
1094-
(*out).msg_namelen = namelen;
1089+
let mhdr = unsafe {
1090+
// Musl's msghdr has private fields, so this is the only way to
1091+
// initialize it.
1092+
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
1093+
let p = mhdr.as_mut_ptr();
1094+
(*p).msg_name = name as *mut _;
1095+
(*p).msg_namelen = namelen;
10951096
// transmute iov into a mutable pointer. sendmsg doesn't really mutate
10961097
// the buffer, but the standard says that it takes a mutable pointer
1097-
(*out).msg_iov = iov.as_ref().as_ptr() as *mut _;
1098-
(*out).msg_iovlen = iov.as_ref().len() as _;
1099-
(*out).msg_control = cmsg_ptr;
1100-
(*out).msg_controllen = cmsg_capacity as _;
1101-
(*out).msg_flags = 0;
1102-
}
1098+
(*p).msg_iov = iov.as_ref().as_ptr() as *mut _;
1099+
(*p).msg_iovlen = iov.as_ref().len() as _;
1100+
(*p).msg_control = cmsg_ptr;
1101+
(*p).msg_controllen = capacity as _;
1102+
(*p).msg_flags = 0;
1103+
mhdr.assume_init()
1104+
};
11031105

11041106
// Encode each cmsg. This must happen after initializing the header because
11051107
// CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields.
11061108
// CMSG_FIRSTHDR is always safe
1107-
let mut pmhdr: *mut cmsghdr = unsafe { CMSG_FIRSTHDR(out) };
1109+
let mut pmhdr: *mut cmsghdr = unsafe { CMSG_FIRSTHDR(&mhdr as *const msghdr) };
11081110
for cmsg in cmsgs.as_ref() {
11091111
assert_ne!(pmhdr, ptr::null_mut());
11101112
// Safe because we know that pmhdr is valid, and we initialized it with
11111113
// sufficient space
1112-
unsafe { cmsg.encode_into(pmhdr); }
1114+
unsafe { cmsg.encode_into(pmhdr) };
11131115
// Safe because mhdr is valid
1114-
pmhdr = unsafe { CMSG_NXTHDR(out, pmhdr) };
1116+
pmhdr = unsafe { CMSG_NXTHDR(&mhdr as *const msghdr, pmhdr) };
11151117
}
1118+
1119+
mhdr
11161120
}
11171121

11181122
/// Receive message in scatter-gather vectors from a socket, and
@@ -1133,20 +1137,17 @@ pub fn recvmsg<'a>(fd: RawFd, iov: &[IoVec<&mut [u8]>],
11331137
mut cmsg_buffer: Option<&'a mut Vec<u8>>,
11341138
flags: MsgFlags) -> Result<RecvMsg<'a>>
11351139
{
1136-
let mut out = mem::MaybeUninit::<msghdr>::zeroed();
11371140
let mut address = mem::MaybeUninit::uninit();
11381141

1139-
let msg_controllen = unsafe {
1140-
pack_mhdr_to_receive(out.as_mut_ptr(), &iov, &mut cmsg_buffer, address.as_mut_ptr())
1142+
let (msg_controllen, mut mhdr) = unsafe {
1143+
pack_mhdr_to_receive(&iov, &mut cmsg_buffer, address.as_mut_ptr())
11411144
};
11421145

1143-
let mut mhdr = unsafe { out.assume_init() };
1144-
11451146
let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) };
11461147

11471148
let r = Errno::result(ret)?;
11481149

1149-
Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address.as_mut_ptr(), &mut cmsg_buffer) })
1150+
Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address.assume_init(), &mut cmsg_buffer) })
11501151
}
11511152

11521153

0 commit comments

Comments
 (0)