Skip to content

Commit e49bd3a

Browse files
committed
Refactor Socket::recv_from_vectored
And add Socket::recv_from_vectored_with_flags.
1 parent 758355f commit e49bd3a

File tree

4 files changed

+106
-82
lines changed

4 files changed

+106
-82
lines changed

src/socket.rs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ impl Socket {
310310
}
311311

312312
/// Receives data on the socket from the remote address to which it is
313-
/// connected. Unlike [`recv`] that allows passing multiple buffers.
313+
/// connected. Unlike [`recv`] this allows passing multiple buffers.
314314
///
315315
/// The [`connect`] method will connect this socket to a remote address.
316316
/// This method might fail if the socket is not connected.
@@ -367,19 +367,30 @@ impl Socket {
367367
sys::recv_from(self.inner, buf, flags)
368368
}
369369

370-
/// Identical to [`recv_from_with_flags`] but reads into a slice of buffers.
370+
/// Receives data from the socket. Returns the amount of bytes read, the
371+
/// [`RecvFlags`] and the remote address from the data is coming. Unlike
372+
/// [`recv_from`] this allows passing multiple buffers.
371373
///
372-
/// In addition to the number of bytes read, this function returns the flags for the received message.
373-
/// See [`RecvFlags`] for more information about the flags.
374-
///
375-
/// [`recv_from_with_flags`]: #method.recv_from_with_flags
374+
/// [`recv_from`]: Socket::recv_from
376375
#[cfg(not(target_os = "redox"))]
377376
pub fn recv_from_vectored(
378377
&self,
379378
bufs: &mut [IoSliceMut<'_>],
379+
) -> io::Result<(usize, RecvFlags, SockAddr)> {
380+
self.recv_from_vectored_with_flags(bufs, 0)
381+
}
382+
383+
/// Identical to [`recv_from_vectored`] but allows for specification of
384+
/// arbitrary flags to the underlying `recvmsg`/`WSARecvFrom` call.
385+
///
386+
/// [`recv_from_vectored`]: Socket::recv_from_vectored
387+
#[cfg(not(target_os = "redox"))]
388+
pub fn recv_from_vectored_with_flags(
389+
&self,
390+
bufs: &mut [IoSliceMut<'_>],
380391
flags: i32,
381392
) -> io::Result<(usize, RecvFlags, SockAddr)> {
382-
self.inner().recv_from_vectored(bufs, flags)
393+
sys::recv_from_vectored(self.inner, bufs, flags)
383394
}
384395

385396
/// Receives data from the socket, without removing it from the queue.

src/sys/unix.rs

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -431,16 +431,48 @@ pub(crate) fn recv_vectored(
431431
bufs: &mut [IoSliceMut<'_>],
432432
flags: c_int,
433433
) -> io::Result<(usize, RecvFlags)> {
434+
recvmsg(fd, ptr::null_mut(), bufs, flags).map(|(n, _, recv_flags)| (n, recv_flags))
435+
}
436+
437+
#[cfg(not(target_os = "redox"))]
438+
pub fn recv_from_vectored(
439+
fd: SysSocket,
440+
bufs: &mut [IoSliceMut<'_>],
441+
flags: c_int,
442+
) -> io::Result<(usize, RecvFlags, SockAddr)> {
443+
let mut storage: MaybeUninit<libc::sockaddr_storage> = MaybeUninit::zeroed();
444+
recvmsg(fd, storage.as_mut_ptr(), bufs, flags).map(|(n, addrlen, recv_flags)| {
445+
// Safety: `recvmsg` wrote an address of `addrlen` bytes for us. The
446+
// remaining bytes are initialised to zero (which is valid for
447+
// `sockaddr_storage`).
448+
let addr = SockAddr::from_raw(unsafe { storage.assume_init() }, addrlen);
449+
(n as usize, recv_flags, addr)
450+
})
451+
}
452+
453+
/// Returns the (bytes received, sending address len, `RecvFlags`).
454+
fn recvmsg(
455+
fd: SysSocket,
456+
msg_name: *mut sockaddr_storage,
457+
bufs: &mut [IoSliceMut<'_>],
458+
flags: c_int,
459+
) -> io::Result<(usize, libc::socklen_t, RecvFlags)> {
460+
let msg_namelen = if msg_name.is_null() {
461+
0
462+
} else {
463+
size_of::<libc::sockaddr_storage>() as libc::socklen_t
464+
};
434465
let mut msg = libc::msghdr {
435-
msg_name: ptr::null_mut(),
436-
msg_namelen: 0,
466+
msg_name: msg_name.cast(),
467+
msg_namelen,
437468
msg_iov: bufs.as_mut_ptr().cast(),
438469
msg_iovlen: min(bufs.len(), IovLen::MAX as usize) as IovLen,
439470
msg_control: ptr::null_mut(),
440471
msg_controllen: 0,
441472
msg_flags: 0,
442473
};
443-
syscall!(recvmsg(fd, &mut msg as *mut _, flags)).map(|n| (n as usize, RecvFlags(msg.msg_flags)))
474+
syscall!(recvmsg(fd, &mut msg as *mut _, flags))
475+
.map(|n| (n as usize, msg.msg_namelen, RecvFlags(msg.msg_flags)))
444476
}
445477

446478
/// Unix only API.
@@ -578,29 +610,6 @@ pub struct Socket {
578610
}
579611

580612
impl Socket {
581-
#[cfg(not(target_os = "redox"))]
582-
pub fn recv_from_vectored(
583-
&self,
584-
bufs: &mut [IoSliceMut<'_>],
585-
flags: c_int,
586-
) -> io::Result<(usize, RecvFlags, SockAddr)> {
587-
let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
588-
let mut msg = libc::msghdr {
589-
msg_name: &mut storage as *mut libc::sockaddr_storage as *mut c_void,
590-
msg_namelen: mem::size_of_val(&storage) as socklen_t,
591-
msg_iov: bufs.as_mut_ptr().cast(),
592-
msg_iovlen: bufs.len().min(IovLen::MAX as usize) as IovLen,
593-
msg_control: std::ptr::null_mut(),
594-
msg_controllen: 0,
595-
msg_flags: 0,
596-
};
597-
598-
let n = syscall!(recvmsg(self.fd, &mut msg as *mut _, flags))?;
599-
let addr =
600-
unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, msg.msg_namelen) };
601-
Ok((n as usize, RecvFlags(msg.msg_flags), addr))
602-
}
603-
604613
pub fn send(&self, buf: &[u8], flags: c_int) -> io::Result<usize> {
605614
let n = syscall!(send(
606615
self.fd,

src/sys/windows.rs

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,52 @@ pub(crate) fn recv_from(
361361
}
362362
}
363363

364+
pub(crate) fn recv_from_vectored(
365+
socket: SysSocket,
366+
bufs: &mut [IoSliceMut<'_>],
367+
flags: c_int,
368+
) -> io::Result<(usize, RecvFlags, SockAddr)> {
369+
let mut storage: MaybeUninit<SOCKADDR_STORAGE> = MaybeUninit::zeroed();
370+
let mut addrlen = size_of_val(&storage) as socklen_t;
371+
let mut nread = 0;
372+
let mut flags = flags as DWORD;
373+
let res = syscall!(
374+
WSARecvFrom(
375+
socket,
376+
bufs.as_mut_ptr().cast(),
377+
min(bufs.len(), DWORD::max_value() as usize) as DWORD,
378+
&mut nread,
379+
&mut flags,
380+
storage.as_mut_ptr().cast(),
381+
&mut addrlen,
382+
ptr::null_mut(),
383+
None,
384+
),
385+
PartialEq::eq,
386+
sock::SOCKET_ERROR
387+
);
388+
match res {
389+
Ok(_) => {
390+
// Safety: `WSARecvFrom` wrote an address of `addrlen` bytes for us.
391+
// The remaining bytes are initialised to zero (which is valid for
392+
// `sockaddr_storage`).
393+
let addr = SockAddr::from_raw(unsafe { storage.assume_init() }, addrlen);
394+
Ok((nread as usize, RecvFlags(0), addr))
395+
}
396+
Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => {
397+
// Safety: see above.
398+
let addr = SockAddr::from_raw(unsafe { storage.assume_init() }, addrlen);
399+
Ok((nread as usize, RecvFlags(0), addr))
400+
}
401+
Err(ref err) if err.raw_os_error() == Some(sock::WSAEMSGSIZE as i32) => {
402+
// Safety: see above.
403+
let addr = SockAddr::from_raw(unsafe { storage.assume_init() }, addrlen);
404+
Ok((nread as usize, RecvFlags(MSG_TRUNC), addr))
405+
}
406+
Err(err) => Err(err),
407+
}
408+
}
409+
364410
/// Caller must ensure `T` is the correct type for `opt` and `val`.
365411
unsafe fn getsockopt<T>(socket: SysSocket, opt: c_int, val: c_int) -> io::Result<T> {
366412
let mut payload: MaybeUninit<T> = MaybeUninit::uninit();
@@ -414,45 +460,6 @@ pub struct Socket {
414460
}
415461

416462
impl Socket {
417-
pub fn recv_from_vectored(
418-
&self,
419-
bufs: &mut [IoSliceMut<'_>],
420-
flags: c_int,
421-
) -> io::Result<(usize, RecvFlags, SockAddr)> {
422-
let mut nread = 0;
423-
let mut flags = flags as DWORD;
424-
let mut storage: SOCKADDR_STORAGE = unsafe { mem::zeroed() };
425-
let mut addrlen = mem::size_of_val(&storage) as c_int;
426-
let ret = unsafe {
427-
sock::WSARecvFrom(
428-
self.socket,
429-
bufs.as_mut_ptr() as *mut WSABUF,
430-
bufs.len().min(DWORD::MAX as usize) as DWORD,
431-
&mut nread,
432-
&mut flags,
433-
&mut storage as *mut SOCKADDR_STORAGE as *mut SOCKADDR,
434-
&mut addrlen,
435-
ptr::null_mut(),
436-
None,
437-
)
438-
};
439-
440-
let flags;
441-
if ret == 0 {
442-
flags = RecvFlags(0);
443-
} else {
444-
let error = last_error();
445-
if error.raw_os_error() == Some(sock::WSAEMSGSIZE) {
446-
flags = RecvFlags(MSG_TRUNC)
447-
} else {
448-
return Err(error);
449-
}
450-
}
451-
452-
let addr = unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, addrlen) };
453-
Ok((nread as usize, flags, addr))
454-
}
455-
456463
pub fn send(&self, buf: &[u8], flags: c_int) -> io::Result<usize> {
457464
unsafe {
458465
let n = {

src/tests.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -228,15 +228,12 @@ fn send_from_recv_to_vectored() {
228228
let mut men = [0u8; 3];
229229
let mut swear = [0u8; 5];
230230
let (received, flags, addr) = socket_b
231-
.recv_from_vectored(
232-
&mut [
233-
IoSliceMut::new(&mut surgeon),
234-
IoSliceMut::new(&mut has),
235-
IoSliceMut::new(&mut men),
236-
IoSliceMut::new(&mut swear),
237-
],
238-
0,
239-
)
231+
.recv_from_vectored(&mut [
232+
IoSliceMut::new(&mut surgeon),
233+
IoSliceMut::new(&mut has),
234+
IoSliceMut::new(&mut men),
235+
IoSliceMut::new(&mut swear),
236+
])
240237
.unwrap();
241238

242239
assert_eq!(received, 18);
@@ -291,7 +288,7 @@ fn recv_from_vectored_truncated() {
291288
let mut buffer = [0u8; 24];
292289

293290
let (received, flags, addr) = socket_b
294-
.recv_from_vectored(&mut [IoSliceMut::new(&mut buffer)], 0)
291+
.recv_from_vectored(&mut [IoSliceMut::new(&mut buffer)])
295292
.unwrap();
296293
assert_eq!(received, 24);
297294
assert_eq!(flags.is_truncated(), true);

0 commit comments

Comments
 (0)