Skip to content

Commit 878e6b4

Browse files
committed
Change Socket::set_no_inherit signature
To match Socket::set_cloexec. Also add a test for it and Type::no_inherit.
1 parent 05a3a82 commit 878e6b4

File tree

2 files changed

+63
-10
lines changed

2 files changed

+63
-10
lines changed

src/sys/windows.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,11 @@ use winapi::ctypes::{c_char, c_long, c_ulong};
2222
use winapi::shared::in6addr::*;
2323
use winapi::shared::inaddr::*;
2424
use winapi::shared::minwindef::DWORD;
25-
#[cfg(feature = "all")]
2625
use winapi::shared::ntdef::HANDLE;
2726
use winapi::shared::ws2def::{self, *};
2827
use winapi::shared::ws2ipdef::*;
29-
#[cfg(feature = "all")]
3028
use winapi::um::handleapi::SetHandleInformation;
3129
use winapi::um::processthreadsapi::GetCurrentProcessId;
32-
#[cfg(feature = "all")]
3330
use winapi::um::winbase;
3431
use winapi::um::winbase::INFINITE;
3532
use winapi::um::winsock2::{self as sock, u_long, SD_BOTH, SD_RECEIVE, SD_SEND};
@@ -91,7 +88,7 @@ impl_debug!(
9188
impl Type {
9289
/// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation.
9390
/// Trying to mimic `Type::cloexec` on windows.
94-
const NO_INHERIT: c_int = 1 << (size_of::<c_int>());
91+
const NO_INHERIT: c_int = 1 << ((size_of::<c_int>() * 8) - 1); // Last bit.
9592

9693
/// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket.
9794
#[cfg(feature = "all")]
@@ -300,12 +297,16 @@ fn ioctlsocket(socket: SysSocket, cmd: c_long, payload: &mut u_long) -> io::Resu
300297
/// Windows only API.
301298
impl crate::Socket {
302299
/// Sets `HANDLE_FLAG_INHERIT` to zero using `SetHandleInformation`.
303-
#[cfg(feature = "all")]
304-
pub fn set_no_inherit(&self) -> io::Result<()> {
300+
pub fn set_no_inherit(&self, no_inherit: bool) -> io::Result<()> {
305301
// NOTE: can't use `syscall!` because it expects the function in the
306302
// `sock::` path.
307-
let res =
308-
unsafe { SetHandleInformation(self.inner as HANDLE, winbase::HANDLE_FLAG_INHERIT, 0) };
303+
let res = unsafe {
304+
SetHandleInformation(
305+
self.inner as HANDLE,
306+
winbase::HANDLE_FLAG_INHERIT,
307+
!no_inherit as _,
308+
)
309+
};
309310
if res == 0 {
310311
// Zero means error.
311312
Err(io::Error::last_os_error())

tests/socket.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
1+
#[cfg(windows)]
2+
use std::io;
13
#[cfg(unix)]
24
use std::os::unix::io::AsRawFd;
5+
#[cfg(windows)]
6+
use std::os::windows::io::AsRawSocket;
7+
8+
#[cfg(windows)]
9+
use winapi::shared::minwindef::DWORD;
10+
#[cfg(windows)]
11+
use winapi::um::handleapi::GetHandleInformation;
12+
#[cfg(windows)]
13+
use winapi::um::winbase::HANDLE_FLAG_INHERIT;
314

415
use socket2::{Domain, Socket, Type};
516

@@ -28,7 +39,7 @@ fn set_nonblocking() {
2839
))]
2940
#[test]
3041
fn type_nonblocking() {
31-
let ty = Type::Stream.nonblocking();
42+
let ty = Type::STREAM.nonblocking();
3243
let socket = Socket::new(Domain::IPV4, ty, None).unwrap();
3344
assert_nonblocking(&socket, true);
3445
}
@@ -76,17 +87,58 @@ fn set_cloexec() {
7687
))]
7788
#[test]
7889
fn type_cloexec() {
79-
let ty = Type::Stream.cloexec();
90+
let ty = Type::STREAM.cloexec();
8091
let socket = Socket::new(Domain::IPV4, ty, None).unwrap();
8192
assert_close_on_exec(&socket, true);
8293
}
8394

8495
/// Assert that `CLOEXEC` is set on `socket`.
8596
#[cfg(unix)]
97+
#[track_caller]
8698
pub fn assert_close_on_exec<S>(socket: &S, want: bool)
8799
where
88100
S: AsRawFd,
89101
{
90102
let flags = unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_GETFD) };
91103
assert_eq!(flags & libc::FD_CLOEXEC != 0, want, "CLOEXEC option");
92104
}
105+
106+
#[cfg(windows)]
107+
#[test]
108+
fn set_no_inherit() {
109+
let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
110+
assert_flag_inherit(&socket, false);
111+
112+
socket.set_no_inherit(true).unwrap();
113+
assert_flag_inherit(&socket, true);
114+
115+
socket.set_no_inherit(false).unwrap();
116+
assert_flag_inherit(&socket, false);
117+
}
118+
119+
#[cfg(all(feature = "all", windows))]
120+
#[test]
121+
fn type_no_inherit() {
122+
let ty = Type::STREAM.no_inherit();
123+
let socket = Socket::new(Domain::IPV4, ty, None).unwrap();
124+
assert_flag_inherit(&socket, true);
125+
}
126+
127+
/// Assert that `FLAG_INHERIT` is set on `socket`.
128+
#[cfg(windows)]
129+
#[track_caller]
130+
pub fn assert_flag_inherit<S>(socket: &S, want: bool)
131+
where
132+
S: AsRawSocket,
133+
{
134+
let mut flags: DWORD = 0;
135+
if unsafe { GetHandleInformation(socket.as_raw_socket() as _, &mut flags) } == 0 {
136+
let err = io::Error::last_os_error();
137+
panic!("unexpected error: {}", err);
138+
}
139+
assert_eq!(
140+
flags & HANDLE_FLAG_INHERIT != 0,
141+
want,
142+
"FLAG_INHERIT option"
143+
);
144+
}

0 commit comments

Comments
 (0)