Skip to content

Commit e887d43

Browse files
committed
Refactor Socket::try_clone
1 parent 0c002e8 commit e887d43

File tree

3 files changed

+43
-85
lines changed

3 files changed

+43
-85
lines changed

src/socket.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,19 @@ impl Socket {
223223

224224
/// Creates a new independently owned handle to the underlying socket.
225225
///
226-
/// The returned `TcpStream` is a reference to the same stream that this
227-
/// object references. Both handles will read and write the same stream of
228-
/// data, and options set on one stream will be propagated to the other
229-
/// stream.
226+
/// # Notes
227+
///
228+
/// On Unix this uses `F_DUPFD_CLOEXEC` and thus sets the `FD_CLOEXEC` on
229+
/// the returned socket.
230+
///
231+
/// On Windows this uses `WSA_FLAG_NO_HANDLE_INHERIT` setting inheriting to
232+
/// false.
233+
///
234+
/// On Windows this can **not** be used function cannot be used on a
235+
/// QOS-enabled socket, see
236+
/// https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsaduplicatesocketw.
230237
pub fn try_clone(&self) -> io::Result<Socket> {
231-
self.inner()
232-
.try_clone()
233-
.map(|s| Socket { inner: s.inner() })
238+
sys::try_clone(self.inner).map(|inner| Socket { inner })
234239
}
235240

236241
/// Get the value of the `SO_ERROR` option on this socket.

src/sys/unix.rs

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
2020
use std::path::Path;
2121
#[cfg(feature = "all")]
2222
use std::ptr;
23-
use std::sync::atomic::{AtomicBool, Ordering};
2423
use std::time::Duration;
2524
use std::{cmp, fmt, io};
2625

@@ -346,6 +345,10 @@ pub(crate) fn getpeername(fd: SysSocket) -> io::Result<SockAddr> {
346345
.map(|_| unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) })
347346
}
348347

348+
pub(crate) fn try_clone(fd: SysSocket) -> io::Result<SysSocket> {
349+
syscall!(fcntl(fd, libc::F_DUPFD_CLOEXEC, 0))
350+
}
351+
349352
/// Unix only API.
350353
impl crate::Socket {
351354
/// Accept a new incoming connection from this listener.
@@ -408,6 +411,7 @@ impl crate::Socket {
408411
}
409412
}
410413

414+
/// Add `flag` to the current set flags of `F_GETFD`.
411415
fn fcntl_add(fd: SysSocket, flag: c_int) -> io::Result<()> {
412416
let previous = syscall!(fcntl(fd, libc::F_GETFD))?;
413417
let new = previous | flag;
@@ -441,35 +445,6 @@ pub struct Socket {
441445
}
442446

443447
impl Socket {
444-
pub fn try_clone(&self) -> io::Result<Socket> {
445-
// implementation lifted from libstd
446-
#[cfg(any(target_os = "android", target_os = "haiku"))]
447-
use libc::F_DUPFD as F_DUPFD_CLOEXEC;
448-
#[cfg(not(any(target_os = "android", target_os = "haiku")))]
449-
use libc::F_DUPFD_CLOEXEC;
450-
451-
static CLOEXEC_FAILED: AtomicBool = AtomicBool::new(false);
452-
if !CLOEXEC_FAILED.load(Ordering::Relaxed) {
453-
match syscall!(fcntl(self.fd, F_DUPFD_CLOEXEC, 0)) {
454-
Ok(fd) => {
455-
let fd = unsafe { Socket::from_raw_fd(fd) };
456-
if cfg!(target_os = "linux") {
457-
set_cloexec(fd.as_raw_fd())?;
458-
}
459-
return Ok(fd);
460-
}
461-
Err(ref e) if e.raw_os_error() == Some(libc::EINVAL) => {
462-
CLOEXEC_FAILED.store(true, Ordering::Relaxed);
463-
}
464-
Err(e) => return Err(e),
465-
}
466-
}
467-
let fd = syscall!(fcntl(self.fd, libc::F_DUPFD, 0))?;
468-
let fd = unsafe { Socket::from_raw_fd(fd) };
469-
set_cloexec(fd.as_raw_fd())?;
470-
Ok(fd)
471-
}
472-
473448
pub fn take_error(&self) -> io::Result<Option<io::Error>> {
474449
unsafe {
475450
let raw: c_int = self.getsockopt(libc::SOL_SOCKET, libc::SO_ERROR)?;
@@ -1191,15 +1166,6 @@ fn max_len() -> usize {
11911166
}
11921167
}
11931168

1194-
fn set_cloexec(fd: c_int) -> io::Result<()> {
1195-
let previous = syscall!(fcntl(fd, libc::F_GETFD))?;
1196-
let new = previous | libc::FD_CLOEXEC;
1197-
if new != previous {
1198-
syscall!(fcntl(fd, libc::F_SETFD, new))?;
1199-
}
1200-
Ok(())
1201-
}
1202-
12031169
fn dur2timeval(dur: Option<Duration>) -> io::Result<libc::timeval> {
12041170
match dur {
12051171
Some(dur) => {

src/sys/windows.rs

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use std::cmp;
1010
use std::fmt;
1111
use std::io;
1212
use std::io::{IoSlice, IoSliceMut, Read, Write};
13-
use std::mem::{self, size_of_val};
13+
use std::mem::{self, size_of_val, MaybeUninit};
1414
use std::net::Shutdown;
1515
use std::net::{self, Ipv4Addr, Ipv6Addr};
1616
use std::os::windows::prelude::*;
@@ -36,13 +36,11 @@ use winapi::um::winsock2 as sock;
3636

3737
use crate::{RecvFlags, SockAddr};
3838

39-
const HANDLE_FLAG_INHERIT: DWORD = 0x00000001;
4039
const MSG_PEEK: c_int = 0x2;
4140
const SD_BOTH: c_int = 2;
4241
const SD_RECEIVE: c_int = 0;
4342
const SD_SEND: c_int = 1;
4443
const SIO_KEEPALIVE_VALS: DWORD = 0x98000004;
45-
const WSA_FLAG_OVERLAPPED: DWORD = 0x01;
4644

4745
pub use winapi::ctypes::c_int;
4846

@@ -151,7 +149,7 @@ pub(crate) fn socket(family: c_int, ty: c_int, protocol: c_int) -> io::Result<Sy
151149
protocol,
152150
ptr::null_mut(),
153151
0,
154-
WSA_FLAG_OVERLAPPED,
152+
sock::WSA_FLAG_OVERLAPPED,
155153
),
156154
PartialEq::eq,
157155
sock::INVALID_SOCKET
@@ -209,6 +207,30 @@ pub(crate) fn getpeername(socket: SysSocket) -> io::Result<SockAddr> {
209207
.map(|_| unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) })
210208
}
211209

210+
pub(crate) fn try_clone(socket: SysSocket) -> io::Result<SysSocket> {
211+
let mut info: MaybeUninit<sock::WSAPROTOCOL_INFOW> = MaybeUninit::uninit();
212+
syscall!(
213+
WSADuplicateSocketW(socket, GetCurrentProcessId(), info.as_mut_ptr()),
214+
PartialEq::eq,
215+
sock::SOCKET_ERROR
216+
)?;
217+
// Safety: `WSADuplicateSocketW` intialised `info` for us.
218+
let mut info = unsafe { info.assume_init() };
219+
220+
syscall!(
221+
WSASocketW(
222+
info.iAddressFamily,
223+
info.iSocketType,
224+
info.iProtocol,
225+
&mut info,
226+
0,
227+
sock::WSA_FLAG_OVERLAPPED | sock::WSA_FLAG_NO_HANDLE_INHERIT,
228+
),
229+
PartialEq::eq,
230+
sock::INVALID_SOCKET
231+
)
232+
}
233+
212234
/// Windows only API.
213235
impl crate::Socket {
214236
/// Sets `HANDLE_FLAG_INHERIT` to zero using `SetHandleInformation`.
@@ -233,30 +255,6 @@ pub struct Socket {
233255
}
234256

235257
impl Socket {
236-
pub fn try_clone(&self) -> io::Result<Socket> {
237-
unsafe {
238-
let mut info: sock::WSAPROTOCOL_INFOW = mem::zeroed();
239-
let r = sock::WSADuplicateSocketW(self.socket, GetCurrentProcessId(), &mut info);
240-
if r != 0 {
241-
return Err(io::Error::last_os_error());
242-
}
243-
let socket = sock::WSASocketW(
244-
info.iAddressFamily,
245-
info.iSocketType,
246-
info.iProtocol,
247-
&mut info,
248-
0,
249-
WSA_FLAG_OVERLAPPED,
250-
);
251-
let socket = match socket {
252-
sock::INVALID_SOCKET => return Err(last_error()),
253-
n => Socket::from_raw_socket(n as RawSocket),
254-
};
255-
socket.set_no_inherit()?;
256-
Ok(socket)
257-
}
258-
}
259-
260258
pub fn take_error(&self) -> io::Result<Option<io::Error>> {
261259
unsafe {
262260
let raw: c_int = self.getsockopt(SOL_SOCKET, SO_ERROR)?;
@@ -844,17 +842,6 @@ impl Socket {
844842
}
845843
}
846844

847-
fn set_no_inherit(&self) -> io::Result<()> {
848-
unsafe {
849-
let r = SetHandleInformation(self.socket as HANDLE, HANDLE_FLAG_INHERIT, 0);
850-
if r == 0 {
851-
Err(io::Error::last_os_error())
852-
} else {
853-
Ok(())
854-
}
855-
}
856-
}
857-
858845
pub fn inner(self) -> SysSocket {
859846
self.socket
860847
}

0 commit comments

Comments
 (0)