Skip to content

Commit 38d7035

Browse files
author
Stjepan Glavina
committed
Shut down the socket on close()
1 parent 07b6cca commit 38d7035

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,8 @@ impl<T: Write> AsyncWrite for Async<T> {
623623
poll_once(cx, self.write_with_mut(|io| io.flush()))
624624
}
625625

626-
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
627-
self.poll_flush(cx)
626+
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
627+
Poll::Ready(self.source.shutdown_write())
628628
}
629629
}
630630

@@ -652,8 +652,8 @@ where
652652
poll_once(cx, self.write_with(|io| (&*io).flush()))
653653
}
654654

655-
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
656-
self.poll_flush(cx)
655+
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
656+
Poll::Ready(self.source.shutdown_write())
657657
}
658658
}
659659

src/parking.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
use std::collections::BTreeMap;
88
use std::fmt;
99
use std::io;
10-
use std::mem;
10+
use std::mem::{self, ManuallyDrop};
11+
use std::net::{Shutdown, TcpStream};
1112
#[cfg(unix)]
12-
use std::os::unix::io::RawFd;
13+
use std::os::unix::io::{FromRawFd, RawFd};
1314
#[cfg(windows)]
14-
use std::os::windows::io::RawSocket;
15+
use std::os::windows::io::{FromRawSocket, RawSocket};
1516
use std::panic;
1617
use std::sync::atomic::{AtomicUsize, Ordering};
1718
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
@@ -739,4 +740,21 @@ impl Source {
739740
})
740741
.await
741742
}
743+
744+
/// Shuts down the write side of the socket.
745+
///
746+
/// If this source is not a socket, the `shutdown` syscall error is ignored.
747+
pub(crate) fn shutdown_write(&self) -> io::Result<()> {
748+
// This may not be a TCP stream, but that's okay - all we do is call `shutdown()` on it.
749+
#[cfg(unix)]
750+
let stream = unsafe { ManuallyDrop::new(TcpStream::from_raw_fd(self.raw)) };
751+
#[cfg(windows)]
752+
let stream = unsafe { ManuallyDrop::new(TcpStream::from_raw_socket(self.raw)) };
753+
754+
// The only actual error may be ENOTCONN.
755+
match stream.shutdown(Shutdown::Write) {
756+
Err(err) if err.kind() == io::ErrorKind::NotConnected => Err(err),
757+
_ => Ok(()),
758+
}
759+
}
742760
}

tests/async.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::time::Duration;
99

1010
use async_io::{Async, Timer};
1111
use blocking::block_on;
12-
use futures::{AsyncReadExt, AsyncWriteExt, StreamExt};
12+
use futures::{future, AsyncReadExt, AsyncWriteExt, StreamExt};
1313
#[cfg(unix)]
1414
use tempfile::tempdir;
1515

@@ -337,3 +337,15 @@ fn tcp_duplex() -> io::Result<()> {
337337
Ok(())
338338
})
339339
}
340+
341+
#[test]
342+
fn close() -> io::Result<()> {
343+
block_on(async {
344+
let (mut reader, mut writer) = Async::<UnixStream>::pair()?;
345+
let mut buf = Vec::new();
346+
347+
// The writer must be closed in order for `read_to_end()` to finish.
348+
future::try_join(reader.read_to_end(&mut buf), writer.close()).await?;
349+
Ok(())
350+
})
351+
}

0 commit comments

Comments
 (0)