Skip to content

Commit cbbd77d

Browse files
committed
Refactor Socket::take_error
1 parent 05bae6d commit cbbd77d

File tree

3 files changed

+55
-26
lines changed

3 files changed

+55
-26
lines changed

src/socket.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ impl Socket {
245245
/// the field in the process. This can be useful for checking errors between
246246
/// calls.
247247
pub fn take_error(&self) -> io::Result<Option<io::Error>> {
248-
self.inner().take_error()
248+
sys::take_error(self.inner)
249249
}
250250

251251
/// Moves this TCP stream into or out of nonblocking mode.

src/sys/unix.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#[cfg(not(target_os = "redox"))]
1010
use std::io::{IoSlice, IoSliceMut};
1111
use std::io::{Read, Write};
12-
use std::mem::{self, size_of_val};
12+
use std::mem::{self, size_of_val, MaybeUninit};
1313
use std::net::Shutdown;
1414
use std::net::{self, Ipv4Addr, Ipv6Addr};
1515
#[cfg(feature = "all")]
@@ -349,6 +349,14 @@ pub(crate) fn try_clone(fd: SysSocket) -> io::Result<SysSocket> {
349349
syscall!(fcntl(fd, libc::F_DUPFD_CLOEXEC, 0))
350350
}
351351

352+
pub(crate) fn take_error(fd: SysSocket) -> io::Result<Option<io::Error>> {
353+
match unsafe { getsockopt::<c_int>(fd, libc::SOL_SOCKET, libc::SO_ERROR) } {
354+
Ok(0) => Ok(None),
355+
Ok(errno) => Ok(Some(io::Error::from_raw_os_error(errno))),
356+
Err(err) => Err(err),
357+
}
358+
}
359+
352360
/// Unix only API.
353361
impl crate::Socket {
354362
/// Accept a new incoming connection from this listener.
@@ -423,6 +431,25 @@ fn fcntl_add(fd: SysSocket, flag: c_int) -> io::Result<()> {
423431
}
424432
}
425433

434+
/// Caller must ensure `T` is the correct type for `opt` and `val`.
435+
unsafe fn getsockopt<T>(fd: SysSocket, opt: c_int, val: c_int) -> io::Result<T> {
436+
let mut payload: MaybeUninit<T> = MaybeUninit::uninit();
437+
let mut len = mem::size_of::<T>() as libc::socklen_t;
438+
syscall!(getsockopt(
439+
fd,
440+
opt,
441+
val,
442+
payload.as_mut_ptr().cast(),
443+
&mut len,
444+
))
445+
.map(|_| {
446+
debug_assert_eq!(len as usize, mem::size_of::<T>());
447+
// Safety: `getsockopt` initialised `payload` for us.
448+
payload.assume_init()
449+
})
450+
}
451+
452+
/// Caller must ensure `T` is the correct type for `opt` and `val`.
426453
#[cfg(all(feature = "all", target_vendor = "apple"))]
427454
unsafe fn setsockopt<T>(fd: SysSocket, opt: c_int, val: c_int, payload: T) -> io::Result<()>
428455
where
@@ -435,8 +462,8 @@ where
435462
val,
436463
payload,
437464
mem::size_of::<T>() as libc::socklen_t,
438-
))?;
439-
Ok(())
465+
))
466+
.map(|_| ())
440467
}
441468

442469
#[repr(transparent)] // Required during rewriting.
@@ -445,17 +472,6 @@ pub struct Socket {
445472
}
446473

447474
impl Socket {
448-
pub fn take_error(&self) -> io::Result<Option<io::Error>> {
449-
unsafe {
450-
let raw: c_int = self.getsockopt(libc::SOL_SOCKET, libc::SO_ERROR)?;
451-
if raw == 0 {
452-
Ok(None)
453-
} else {
454-
Ok(Some(io::Error::from_raw_os_error(raw as i32)))
455-
}
456-
}
457-
}
458-
459475
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
460476
let previous = syscall!(fcntl(self.fd, libc::F_GETFL))?;
461477
let new = if nonblocking {

src/sys/windows.rs

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pub(crate) use winapi::um::ws2tcpip::socklen_t;
7272
/// Helper macro to execute a system call that returns an `io::Result`.
7373
macro_rules! syscall {
7474
($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{
75+
#[allow(unused_unsafe)]
7576
let res = unsafe { sock::$fn($($arg, )*) };
7677
if $err_test(&res, &$err_value) {
7778
Err(io::Error::last_os_error())
@@ -252,6 +253,29 @@ pub(crate) fn try_clone(socket: SysSocket) -> io::Result<SysSocket> {
252253
)
253254
}
254255

256+
pub(crate) fn take_error(socket: SysSocket) -> io::Result<Option<io::Error>> {
257+
match unsafe { getsockopt::<c_int>(socket, SOL_SOCKET, SO_ERROR) } {
258+
Ok(0) => Ok(None),
259+
Ok(errno) => Ok(Some(io::Error::from_raw_os_error(errno))),
260+
Err(err) => Err(err),
261+
}
262+
}
263+
264+
unsafe fn getsockopt<T>(socket: SysSocket, opt: c_int, val: c_int) -> io::Result<T> {
265+
let mut payload: MaybeUninit<T> = MaybeUninit::uninit();
266+
let mut len = mem::size_of::<T>() as c_int;
267+
syscall!(
268+
getsockopt(socket, opt, val, payload.as_mut_ptr().cast(), &mut len,),
269+
PartialEq::eq,
270+
sock::SOCKET_ERROR
271+
)
272+
.map(|_| {
273+
debug_assert_eq!(len as usize, mem::size_of::<T>());
274+
// Safety: `getsockopt` initialised `payload` for us.
275+
payload.assume_init()
276+
})
277+
}
278+
255279
/// Windows only API.
256280
impl crate::Socket {
257281
/// Sets `HANDLE_FLAG_INHERIT` to zero using `SetHandleInformation`.
@@ -276,17 +300,6 @@ pub struct Socket {
276300
}
277301

278302
impl Socket {
279-
pub fn take_error(&self) -> io::Result<Option<io::Error>> {
280-
unsafe {
281-
let raw: c_int = self.getsockopt(SOL_SOCKET, SO_ERROR)?;
282-
if raw == 0 {
283-
Ok(None)
284-
} else {
285-
Ok(Some(io::Error::from_raw_os_error(raw as i32)))
286-
}
287-
}
288-
}
289-
290303
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
291304
unsafe {
292305
let mut nonblocking = nonblocking as c_ulong;

0 commit comments

Comments
 (0)