Skip to content

Commit f43253f

Browse files
committed
Refactor Socket::set_nonblocking
1 parent cbbd77d commit f43253f

File tree

3 files changed

+44
-30
lines changed

3 files changed

+44
-30
lines changed

src/socket.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,14 @@ impl Socket {
250250

251251
/// Moves this TCP stream into or out of nonblocking mode.
252252
///
253-
/// On Unix this corresponds to calling fcntl, and on Windows this
254-
/// corresponds to calling ioctlsocket.
253+
/// # Notes
254+
///
255+
/// On Unix this corresponds to calling `fcntl` (un)setting `O_NONBLOCK`.
256+
///
257+
/// On Windows this corresponds to calling `ioctlsocket` (un)setting
258+
/// `FIONBIO`.
255259
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
256-
self.inner().set_nonblocking(nonblocking)
260+
sys::set_nonblocking(self.inner, nonblocking)
257261
}
258262

259263
/// Shuts down the read, write, or both halves of this connection.

src/sys/unix.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,14 @@ pub(crate) fn take_error(fd: SysSocket) -> io::Result<Option<io::Error>> {
357357
}
358358
}
359359

360+
pub(crate) fn set_nonblocking(fd: SysSocket, nonblocking: bool) -> io::Result<()> {
361+
if nonblocking {
362+
fcntl_add(fd, libc::O_NONBLOCK)
363+
} else {
364+
fcntl_remove(fd, libc::O_NONBLOCK)
365+
}
366+
}
367+
360368
/// Unix only API.
361369
impl crate::Socket {
362370
/// Accept a new incoming connection from this listener.
@@ -431,6 +439,18 @@ fn fcntl_add(fd: SysSocket, flag: c_int) -> io::Result<()> {
431439
}
432440
}
433441

442+
/// Remove `flag` to the current set flags of `F_GETFD`.
443+
fn fcntl_remove(fd: SysSocket, flag: c_int) -> io::Result<()> {
444+
let previous = syscall!(fcntl(fd, libc::F_GETFD))?;
445+
let new = previous & !flag;
446+
if new != previous {
447+
syscall!(fcntl(fd, libc::F_SETFD, new)).map(|_| ())
448+
} else {
449+
// Flag was already set.
450+
Ok(())
451+
}
452+
}
453+
434454
/// Caller must ensure `T` is the correct type for `opt` and `val`.
435455
unsafe fn getsockopt<T>(fd: SysSocket, opt: c_int, val: c_int) -> io::Result<T> {
436456
let mut payload: MaybeUninit<T> = MaybeUninit::uninit();
@@ -472,19 +492,6 @@ pub struct Socket {
472492
}
473493

474494
impl Socket {
475-
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
476-
let previous = syscall!(fcntl(self.fd, libc::F_GETFL))?;
477-
let new = if nonblocking {
478-
previous | libc::O_NONBLOCK
479-
} else {
480-
previous & !libc::O_NONBLOCK
481-
};
482-
if new != previous {
483-
syscall!(fcntl(self.fd, libc::F_SETFL, new))?;
484-
}
485-
Ok(())
486-
}
487-
488495
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
489496
let how = match how {
490497
Shutdown::Write => libc::SHUT_WR,

src/sys/windows.rs

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use std::ptr;
1818
use std::sync::Once;
1919
use std::time::Duration;
2020

21-
use winapi::ctypes::{c_char, c_ulong};
21+
use winapi::ctypes::{c_char, c_long, c_ulong};
2222
use winapi::shared::in6addr::*;
2323
use winapi::shared::inaddr::*;
2424
use winapi::shared::minwindef::DWORD;
@@ -32,7 +32,7 @@ use winapi::um::processthreadsapi::GetCurrentProcessId;
3232
#[cfg(feature = "all")]
3333
use winapi::um::winbase;
3434
use winapi::um::winbase::INFINITE;
35-
use winapi::um::winsock2 as sock;
35+
use winapi::um::winsock2::{self as sock, u_long};
3636

3737
use crate::{RecvFlags, SockAddr, Type};
3838

@@ -261,6 +261,12 @@ pub(crate) fn take_error(socket: SysSocket) -> io::Result<Option<io::Error>> {
261261
}
262262
}
263263

264+
pub(crate) fn set_nonblocking(socket: SysSocket, nonblocking: bool) -> io::Result<()> {
265+
let mut nonblocking = nonblocking as u_long;
266+
ioctlsocket(socket, sock::FIONBIO, &mut nonblocking)
267+
}
268+
269+
/// Caller must ensure `T` is the correct type for `opt` and `val`.
264270
unsafe fn getsockopt<T>(socket: SysSocket, opt: c_int, val: c_int) -> io::Result<T> {
265271
let mut payload: MaybeUninit<T> = MaybeUninit::uninit();
266272
let mut len = mem::size_of::<T>() as c_int;
@@ -276,6 +282,15 @@ unsafe fn getsockopt<T>(socket: SysSocket, opt: c_int, val: c_int) -> io::Result
276282
})
277283
}
278284

285+
fn ioctlsocket(socket: SysSocket, cmd: c_long, payload: &mut u_long) -> io::Result<()> {
286+
syscall!(
287+
ioctlsocket(socket, cmd, payload),
288+
PartialEq::eq,
289+
sock::SOCKET_ERROR
290+
)
291+
.map(|_| ())
292+
}
293+
279294
/// Windows only API.
280295
impl crate::Socket {
281296
/// Sets `HANDLE_FLAG_INHERIT` to zero using `SetHandleInformation`.
@@ -300,18 +315,6 @@ pub struct Socket {
300315
}
301316

302317
impl Socket {
303-
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
304-
unsafe {
305-
let mut nonblocking = nonblocking as c_ulong;
306-
let r = sock::ioctlsocket(self.socket, sock::FIONBIO as c_int, &mut nonblocking);
307-
if r == 0 {
308-
Ok(())
309-
} else {
310-
Err(io::Error::last_os_error())
311-
}
312-
}
313-
}
314-
315318
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
316319
let how = match how {
317320
Shutdown::Write => SD_SEND,

0 commit comments

Comments
 (0)