Skip to content

Commit a758c64

Browse files
committed
Refactor Socket::send_vectored
1 parent 2497129 commit a758c64

File tree

4 files changed

+75
-49
lines changed

4 files changed

+75
-49
lines changed

src/socket.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,12 +422,19 @@ impl Socket {
422422
sys::send(self.inner, buf, flags)
423423
}
424424

425-
/// Identical to [`send_with_flags`] but writes from a slice of buffers.
425+
/// Send data to the connected peer. Returns the amount of bytes written.
426+
#[cfg(not(target_os = "redox"))]
427+
pub fn send_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
428+
self.send_vectored_with_flags(bufs, 0)
429+
}
430+
431+
/// Identical to [`send_vectored`] but allows for specification of arbitrary
432+
/// flags to the underlying `sendmsg`/`WSASend` call.
426433
///
427-
/// [`send_with_flags`]: #method.send_with_flags
434+
/// [`send_vectored`]: Socket::send_vectored
428435
#[cfg(not(target_os = "redox"))]
429-
pub fn send_vectored(&self, bufs: &[IoSlice<'_>], flags: i32) -> io::Result<usize> {
430-
self.inner().send_vectored(bufs, flags)
436+
pub fn send_vectored_with_flags(&self, bufs: &[IoSlice<'_>], flags: i32) -> io::Result<usize> {
437+
sys::send_vectored(self.inner, bufs, flags)
431438
}
432439

433440
/// Sends out-of-band (OOB) data on the socket to connected peer
@@ -909,6 +916,7 @@ impl Read for Socket {
909916
self.recv(buf)
910917
}
911918

919+
#[cfg(not(target_os = "redox"))]
912920
fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
913921
self.recv_vectored(bufs).map(|(n, _)| n)
914922
}
@@ -919,6 +927,7 @@ impl<'a> Read for &'a Socket {
919927
self.recv(buf)
920928
}
921929

930+
#[cfg(not(target_os = "redox"))]
922931
fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
923932
self.recv_vectored(bufs).map(|(n, _)| n)
924933
}

src/sys/unix.rs

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ fn recvmsg(
471471
msg_controllen: 0,
472472
msg_flags: 0,
473473
};
474-
syscall!(recvmsg(fd, &mut msg as *mut _, flags))
474+
syscall!(recvmsg(fd, &mut msg, flags))
475475
.map(|n| (n as usize, msg.msg_namelen, RecvFlags(msg.msg_flags)))
476476
}
477477

@@ -485,6 +485,28 @@ pub(crate) fn send(fd: SysSocket, buf: &[u8], flags: c_int) -> io::Result<usize>
485485
.map(|n| n as usize)
486486
}
487487

488+
#[cfg(not(target_os = "redox"))]
489+
pub(crate) fn send_vectored(
490+
fd: SysSocket,
491+
bufs: &[IoSlice<'_>],
492+
flags: c_int,
493+
) -> io::Result<usize> {
494+
let mut msg = libc::msghdr {
495+
msg_name: ptr::null_mut(),
496+
msg_namelen: 0,
497+
// Safety: we're creating a `*mut` pointer from a reference, which is UB
498+
// once actually used. However the OS should not write to it in the
499+
// `sendmsg` system call.
500+
msg_iov: bufs.as_ptr() as *mut _,
501+
msg_iovlen: min(bufs.len(), IovLen::MAX as usize) as IovLen,
502+
msg_control: ptr::null_mut(),
503+
msg_controllen: 0,
504+
msg_flags: 0,
505+
};
506+
507+
syscall!(sendmsg(fd, &mut msg, flags)).map(|n| n as usize)
508+
}
509+
488510
/// Unix only API.
489511
impl crate::Socket {
490512
/// Accept a new incoming connection from this listener.
@@ -632,22 +654,6 @@ impl Socket {
632654
Ok(n as usize)
633655
}
634656

635-
#[cfg(not(target_os = "redox"))]
636-
pub fn send_vectored(&self, bufs: &[IoSlice<'_>], flags: c_int) -> io::Result<usize> {
637-
let mut msg = libc::msghdr {
638-
msg_name: std::ptr::null_mut(),
639-
msg_namelen: 0,
640-
msg_iov: bufs.as_ptr() as *mut libc::iovec,
641-
msg_iovlen: bufs.len().min(IovLen::MAX as usize) as IovLen,
642-
msg_control: std::ptr::null_mut(),
643-
msg_controllen: 0,
644-
msg_flags: 0,
645-
};
646-
647-
let n = syscall!(sendmsg(self.fd, &mut msg as *mut libc::msghdr, flags))?;
648-
Ok(n as usize)
649-
}
650-
651657
#[cfg(not(target_os = "redox"))]
652658
pub fn send_to_vectored(
653659
&self,

src/sys/windows.rs

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,39 @@ pub(crate) fn send(socket: SysSocket, buf: &[u8], flags: c_int) -> io::Result<us
421421
.map(|n| n as usize)
422422
}
423423

424+
pub(crate) fn send_vectored(
425+
socket: SysSocket,
426+
bufs: &[IoSlice<'_>],
427+
flags: c_int,
428+
) -> io::Result<usize> {
429+
let mut nsent = 0;
430+
syscall!(
431+
WSASend(
432+
socket,
433+
// FIXME: From the `WSASend` docs [1]:
434+
// > For a Winsock application, once the WSASend function is called,
435+
// > the system owns these buffers and the application may not
436+
// > access them.
437+
//
438+
// So what we're doing is actually UB as `bufs` needs to be `&mut
439+
// [IoSlice<'_>]`.
440+
//
441+
// Tracking issue: https://github.com/rust-lang/socket2-rs/issues/129.
442+
//
443+
// [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend
444+
bufs.as_ptr() as *mut _,
445+
min(bufs.len(), DWORD::max_value() as usize) as DWORD,
446+
&mut nsent,
447+
flags as DWORD,
448+
std::ptr::null_mut(),
449+
None,
450+
),
451+
PartialEq::eq,
452+
sock::SOCKET_ERROR
453+
)
454+
.map(|_| nsent as usize)
455+
}
456+
424457
/// Caller must ensure `T` is the correct type for `opt` and `val`.
425458
unsafe fn getsockopt<T>(socket: SysSocket, opt: c_int, val: c_int) -> io::Result<T> {
426459
let mut payload: MaybeUninit<T> = MaybeUninit::uninit();
@@ -494,25 +527,6 @@ impl Socket {
494527
}
495528
}
496529

497-
pub fn send_vectored(&self, bufs: &[IoSlice<'_>], flags: c_int) -> io::Result<usize> {
498-
let mut nsent = 0;
499-
let ret = unsafe {
500-
sock::WSASend(
501-
self.socket,
502-
bufs.as_ptr() as *mut WSABUF,
503-
bufs.len().min(DWORD::MAX as usize) as DWORD,
504-
&mut nsent,
505-
flags as DWORD,
506-
std::ptr::null_mut(),
507-
None,
508-
)
509-
};
510-
match ret {
511-
0 => Ok(nsent as usize),
512-
_ => Err(last_error()),
513-
}
514-
}
515-
516530
pub fn send_to_vectored(
517531
&self,
518532
bufs: &[IoSlice<'_>],

src/tests.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,12 @@ fn send_recv_vectored() {
157157
let (socket_a, socket_b) = udp_pair_connected();
158158

159159
let sent = socket_a
160-
.send_vectored(
161-
&[
162-
IoSlice::new(b"the"),
163-
IoSlice::new(b"weeknight"),
164-
IoSlice::new(b"would"),
165-
IoSlice::new(b"yellow"),
166-
],
167-
0,
168-
)
160+
.send_vectored(&[
161+
IoSlice::new(b"the"),
162+
IoSlice::new(b"weeknight"),
163+
IoSlice::new(b"would"),
164+
IoSlice::new(b"yellow"),
165+
])
169166
.unwrap();
170167
assert_eq!(sent, 23);
171168

0 commit comments

Comments
 (0)