Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 21b72eb

Browse files
Merge pull request #894 from jbr/dont-poll-after-eof-in-io-copy
io::copy: don't poll the reader again after eof while waiting for the writer to flush
2 parents 996ff48 + ae817ca commit 21b72eb

File tree

2 files changed

+100
-4
lines changed

2 files changed

+100
-4
lines changed

‎src/io/copy.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ use crate::io::{self, BufRead, BufReader, Read, Write};
77
use crate::task::{Context, Poll};
88
use crate::utils::Context as _;
99

10+
// Note: There are two otherwise-identical implementations of this
11+
// function because unstable has removed the `?Sized` bound for the
12+
// reader and writer and accepts `R` and `W` instead of `&mut R` and
13+
// `&mut W`. If making a change to either of the implementations,
14+
// ensure that you copy it into the other.
15+
1016
/// Copies the entire contents of a reader into a writer.
1117
///
1218
/// This function will continuously read data from `reader` and then
@@ -57,6 +63,7 @@ where
5763
#[pin]
5864
writer: W,
5965
amt: u64,
66+
reader_eof: bool
6067
}
6168
}
6269

@@ -69,13 +76,20 @@ where
6976

7077
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
7178
let mut this = self.project();
79+
7280
loop {
73-
let buffer = futures_core::ready!(this.reader.as_mut().poll_fill_buf(cx))?;
74-
if buffer.is_empty() {
81+
if *this.reader_eof {
7582
futures_core::ready!(this.writer.as_mut().poll_flush(cx))?;
7683
return Poll::Ready(Ok(*this.amt));
7784
}
7885

86+
let buffer = futures_core::ready!(this.reader.as_mut().poll_fill_buf(cx))?;
87+
88+
if buffer.is_empty() {
89+
*this.reader_eof = true;
90+
continue;
91+
}
92+
7993
let i = futures_core::ready!(this.writer.as_mut().poll_write(cx, buffer))?;
8094
if i == 0 {
8195
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
@@ -89,6 +103,7 @@ where
89103
let future = CopyFuture {
90104
reader: BufReader::new(reader),
91105
writer,
106+
reader_eof: false,
92107
amt: 0,
93108
};
94109
future.await.context(|| String::from("io::copy failed"))
@@ -144,6 +159,7 @@ where
144159
#[pin]
145160
writer: W,
146161
amt: u64,
162+
reader_eof: bool
147163
}
148164
}
149165

@@ -156,13 +172,20 @@ where
156172

157173
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158174
let mut this = self.project();
175+
159176
loop {
160-
let buffer = futures_core::ready!(this.reader.as_mut().poll_fill_buf(cx))?;
161-
if buffer.is_empty() {
177+
if *this.reader_eof {
162178
futures_core::ready!(this.writer.as_mut().poll_flush(cx))?;
163179
return Poll::Ready(Ok(*this.amt));
164180
}
165181

182+
let buffer = futures_core::ready!(this.reader.as_mut().poll_fill_buf(cx))?;
183+
184+
if buffer.is_empty() {
185+
*this.reader_eof = true;
186+
continue;
187+
}
188+
166189
let i = futures_core::ready!(this.writer.as_mut().poll_write(cx, buffer))?;
167190
if i == 0 {
168191
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
@@ -176,6 +199,7 @@ where
176199
let future = CopyFuture {
177200
reader: BufReader::new(reader),
178201
writer,
202+
reader_eof: false,
179203
amt: 0,
180204
};
181205
future.await.context(|| String::from("io::copy failed"))

‎tests/io_copy.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use std::{
2+
io::Result,
3+
pin::Pin,
4+
task::{Context, Poll},
5+
};
6+
7+
struct ReaderThatPanicsAfterEof {
8+
read_count: usize,
9+
has_sent_eof: bool,
10+
max_read: usize,
11+
}
12+
13+
impl async_std::io::Read for ReaderThatPanicsAfterEof {
14+
fn poll_read(
15+
mut self: Pin<&mut Self>,
16+
_cx: &mut Context<'_>,
17+
buf: &mut [u8],
18+
) -> Poll<Result<usize>> {
19+
if self.has_sent_eof {
20+
panic!("this should be unreachable because we should not poll after eof (Ready(Ok(0)))")
21+
} else if self.read_count >= self.max_read {
22+
self.has_sent_eof = true;
23+
Poll::Ready(Ok(0))
24+
} else {
25+
self.read_count += 1;
26+
Poll::Ready(Ok(buf.len()))
27+
}
28+
}
29+
}
30+
31+
struct WriterThatTakesAWhileToFlush {
32+
max_flush: usize,
33+
flush_count: usize,
34+
}
35+
36+
impl async_std::io::Write for WriterThatTakesAWhileToFlush {
37+
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
38+
Poll::Ready(Ok(buf.len()))
39+
}
40+
41+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
42+
self.flush_count += 1;
43+
if self.flush_count >= self.max_flush {
44+
Poll::Ready(Ok(()))
45+
} else {
46+
cx.waker().wake_by_ref();
47+
Poll::Pending
48+
}
49+
}
50+
51+
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
52+
Poll::Ready(Ok(()))
53+
}
54+
}
55+
56+
#[test]
57+
fn io_copy_does_not_poll_after_eof() {
58+
async_std::task::block_on(async {
59+
let mut reader = ReaderThatPanicsAfterEof {
60+
has_sent_eof: false,
61+
max_read: 10,
62+
read_count: 0,
63+
};
64+
65+
let mut writer = WriterThatTakesAWhileToFlush {
66+
flush_count: 0,
67+
max_flush: 10,
68+
};
69+
70+
assert!(async_std::io::copy(&mut reader, &mut writer).await.is_ok());
71+
})
72+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /