Skip to content

Fix incorrect io::Take's limit resulting from io::copy specialization #79650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion library/std/src/sys/unix/fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,8 @@ pub fn copy(from: &Path, to: &Path) -> io::Result<u64> {
use super::kernel_copy::{copy_regular_files, CopyResult};

match copy_regular_files(reader.as_raw_fd(), writer.as_raw_fd(), max_len) {
CopyResult::Ended(result) => result,
CopyResult::Ended(bytes) => Ok(bytes),
CopyResult::Error(e, _) => Err(e),
CopyResult::Fallback(written) => match io::copy::generic_copy(&mut reader, &mut writer) {
Ok(bytes) => Ok(bytes + written),
Err(e) => Err(e),
Expand Down
54 changes: 42 additions & 12 deletions library/std/src/sys/unix/kernel_copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,11 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {

if input_meta.copy_file_range_candidate() && output_meta.copy_file_range_candidate() {
let result = copy_regular_files(readfd, writefd, max_write);
result.update_take(reader);

match result {
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
CopyResult::Ended(err) => return err,
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
CopyResult::Error(e, _) => return Err(e),
CopyResult::Fallback(bytes) => written += bytes,
}
}
Expand All @@ -182,20 +183,22 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {
// fall back to the generic copy loop.
if input_meta.potential_sendfile_source() {
let result = sendfile_splice(SpliceMode::Sendfile, readfd, writefd, max_write);
result.update_take(reader);

match result {
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
CopyResult::Ended(err) => return err,
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
CopyResult::Error(e, _) => return Err(e),
CopyResult::Fallback(bytes) => written += bytes,
}
}

if input_meta.maybe_fifo() || output_meta.maybe_fifo() {
let result = sendfile_splice(SpliceMode::Splice, readfd, writefd, max_write);
result.update_take(reader);

match result {
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
CopyResult::Ended(err) => return err,
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
CopyResult::Error(e, _) => return Err(e),
CopyResult::Fallback(0) => { /* use the fallback below */ }
CopyResult::Fallback(_) => {
unreachable!("splice should not return > 0 bytes on the fallback path")
Expand Down Expand Up @@ -225,6 +228,9 @@ trait CopyRead: Read {
Ok(0)
}

/// Updates `Take` wrappers to remove the number of bytes copied.
fn taken(&mut self, _bytes: u64) {}

/// The minimum of the limit of all `Take<_>` wrappers, `u64::MAX` otherwise.
/// This method does not account for data `BufReader` buffers and would underreport
/// the limit of a `Take<BufReader<Take<_>>>` type. Thus its result is only valid
Expand All @@ -251,6 +257,10 @@ where
(**self).drain_to(writer, limit)
}

fn taken(&mut self, bytes: u64) {
(**self).taken(bytes);
}

fn min_limit(&self) -> u64 {
(**self).min_limit()
}
Expand Down Expand Up @@ -407,6 +417,11 @@ impl<T: CopyRead> CopyRead for Take<T> {
Ok(bytes_drained)
}

fn taken(&mut self, bytes: u64) {
self.set_limit(self.limit() - bytes);
self.get_mut().taken(bytes);
}

fn min_limit(&self) -> u64 {
min(Take::limit(self), self.get_ref().min_limit())
}
Expand All @@ -432,6 +447,10 @@ impl<T: CopyRead> CopyRead for BufReader<T> {
Ok(bytes as u64 + inner_bytes)
}

fn taken(&mut self, bytes: u64) {
self.get_mut().taken(bytes);
}

fn min_limit(&self) -> u64 {
self.get_ref().min_limit()
}
Expand All @@ -457,10 +476,21 @@ fn fd_to_meta<T: AsRawFd>(fd: &T) -> FdMeta {
}

pub(super) enum CopyResult {
Ended(Result<u64>),
Ended(u64),
Error(Error, u64),
Fallback(u64),
}

impl CopyResult {
fn update_take(&self, reader: &mut impl CopyRead) {
match *self {
CopyResult::Fallback(bytes)
| CopyResult::Ended(bytes)
| CopyResult::Error(_, bytes) => reader.taken(bytes),
}
}
}

/// linux-specific implementation that will attempt to use copy_file_range for copy offloading
/// as the name says, it only works on regular files
///
Expand Down Expand Up @@ -527,7 +557,7 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
// - copying from an overlay filesystem in docker. reported to occur on fedora 32.
return CopyResult::Fallback(0);
}
Ok(0) => return CopyResult::Ended(Ok(written)), // reached EOF
Ok(0) => return CopyResult::Ended(written), // reached EOF
Ok(ret) => written += ret as u64,
Err(err) => {
return match err.raw_os_error() {
Expand All @@ -545,12 +575,12 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
assert_eq!(written, 0);
CopyResult::Fallback(0)
}
_ => CopyResult::Ended(Err(err)),
_ => CopyResult::Error(err, written),
};
}
}
}
CopyResult::Ended(Ok(written))
CopyResult::Ended(written)
}

#[derive(PartialEq)]
Expand Down Expand Up @@ -623,10 +653,10 @@ fn sendfile_splice(mode: SpliceMode, reader: RawFd, writer: RawFd, len: u64) ->
Some(os_err) if mode == SpliceMode::Sendfile && os_err == libc::EOVERFLOW => {
CopyResult::Fallback(written)
}
_ => CopyResult::Ended(Err(err)),
_ => CopyResult::Error(err, written),
};
}
}
}
CopyResult::Ended(Ok(written))
CopyResult::Ended(written)
}
13 changes: 10 additions & 3 deletions library/std/src/sys/unix/kernel_copy/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,15 @@ fn copy_specialization() -> Result<()> {
assert_eq!(sink.buffer(), b"wxyz");

let copied = crate::io::copy(&mut source, &mut sink)?;
assert_eq!(copied, 10);
assert_eq!(sink.buffer().len(), 0);
assert_eq!(copied, 10, "copy obeyed limit imposed by Take");
assert_eq!(sink.buffer().len(), 0, "sink buffer was flushed");
assert_eq!(source.limit(), 0, "outer Take was exhausted");
assert_eq!(source.get_ref().buffer().len(), 0, "source buffer should be drained");
assert_eq!(
source.get_ref().get_ref().limit(),
1,
"inner Take allowed reading beyond end of file, some bytes should be left"
);

let mut sink = sink.into_inner()?;
sink.seek(SeekFrom::Start(0))?;
Expand Down Expand Up @@ -210,7 +217,7 @@ fn bench_socket_pipe_socket_copy(b: &mut test::Bencher) {
);

match probe {
CopyResult::Ended(Ok(1)) => {
CopyResult::Ended(1) => {
// splice works
}
_ => {
Expand Down