From 3ab11aa3249dc414847d22b636309da15d605db2 Mon Sep 17 00:00:00 2001 From: "Stephen M. Coakley" Date: Sun, 16 Feb 2020 22:05:50 -0600 Subject: [PATCH] Implement AsyncBufRead Fixes #8. --- Cargo.toml | 2 + src/lib.rs | 2 - src/pipe/chunked.rs | 182 +++++++++++++++++++------------------------- src/pipe/mod.rs | 34 ++++++--- tests/pipe.rs | 113 +++++++++++++++++++++++++++ 5 files changed, 215 insertions(+), 118 deletions(-) create mode 100644 tests/pipe.rs diff --git a/Cargo.toml b/Cargo.toml index 2df2a92..318af27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,8 @@ default-features = false [dev-dependencies] criterion = "0.3" futures = "0.3" +quickcheck = "0.9" +quickcheck_macros = "0.9" [[bench]] name = "pipe" diff --git a/src/lib.rs b/src/lib.rs index aa61741..8cc6f7f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,4 @@ clippy::all, )] -#![cfg_attr(feature = "nightly", feature(async_await))] - pub mod pipe; diff --git a/src/pipe/chunked.rs b/src/pipe/chunked.rs index e8e2f99..de2c0bd 100644 --- a/src/pipe/chunked.rs +++ b/src/pipe/chunked.rs @@ -22,10 +22,10 @@ use futures_channel::mpsc; use futures_core::{FusedStream, Stream}; -use futures_io::{AsyncRead, AsyncWrite}; -use futures_util::io::Cursor; +use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite}; use std::{ io, + io::{BufRead, Cursor, Write}, pin::Pin, task::{Context, Poll}, }; @@ -47,7 +47,9 @@ pub(crate) fn new(count: usize) -> (Reader, Writer) { // Fill up the buffer pool. for _ in 0..count { - buf_pool_tx.try_send(Cursor::new(Vec::new())).expect("buffer pool overflow"); + buf_pool_tx + .try_send(Cursor::new(Vec::new())) + .expect("buffer pool overflow"); } let reader = Reader { @@ -77,65 +79,53 @@ pub(crate) struct Reader { } impl AsyncRead for Reader { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { - // Fetch the chunk to read from. If we already have one from a previous - // read, use that, otherwise receive the next chunk from the writer. - let mut chunk = match self.chunk.take() { - Some(chunk) => chunk, - - None => { - // If the stream has terminated, then do not poll it again. - if self.buf_stream_rx.is_terminated() { - return Poll::Ready(Ok(0)); - } + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &mut [u8], + ) -> Poll> { + // Read into the internal buffer. + match self.as_mut().poll_fill_buf(cx)? { + // Not quite ready yet. + Poll::Pending => Poll::Pending, - match Pin::new(&mut self.buf_stream_rx).poll_next(cx) { - // Wait for a new chunk to be delivered. - Poll::Pending => return Poll::Pending, + // A chunk is available. + Poll::Ready(chunk) => { + // Copy as much of the chunk as we can to the destination + // buffer. + let amt = buf.write(chunk)?; - // Pipe has closed, so return EOF. - Poll::Ready(None) => return Poll::Ready(Ok(0)), + // Mark however much was successfully copied as being consumed. + self.consume(amt); - // Accept the new chunk. - Poll::Ready(Some(buf)) => buf, - } + Poll::Ready(Ok(amt)) } - }; - - // Do the read. - let len = match Pin::new(&mut chunk).poll_read(cx, buf) { - Poll::Pending => unreachable!(), - Poll::Ready(Ok(len)) => len, - Poll::Ready(Err(e)) => panic!("cursor returned an error: {}", e), - }; - - // If the chunk is not empty yet, keep it for a future read. - if chunk.position() < chunk.get_ref().len() as u64 { - self.chunk = Some(chunk); } + } +} - // Otherwise, return it to the writer to be reused. - else { - chunk.set_position(0); - chunk.get_mut().clear(); - - match self.buf_pool_tx.try_send(chunk) { - Ok(()) => {} - - Err(e) => { +impl AsyncBufRead for Reader { + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // If the current chunk is consumed, first return it to the writer for + // reuse. + if let Some(chunk) = self.chunk.as_ref() { + if chunk.position() >= chunk.get_ref().len() as u64 { + let mut chunk = self.chunk.take().unwrap(); + chunk.set_position(0); + chunk.get_mut().clear(); + + if let Err(e) = self.buf_pool_tx.try_send(chunk) { // We pre-fill the buffer pool channel with an exact number // of buffers, so this can never happen. if e.is_full() { panic!("buffer pool overflow") } - // If the writer disconnects, then we'll just discard this // buffer and any subsequent buffers until we've read // everything still in the pipe. else if e.is_disconnected() { // Nothing! } - // Some other error occurred. else { return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())); @@ -144,7 +134,38 @@ impl AsyncRead for Reader { } } - Poll::Ready(Ok(len)) + // If we have no current chunk, then attempt to read one. + if self.chunk.is_none() { + // If the stream has terminated, then do not poll it again. + if self.buf_stream_rx.is_terminated() { + return Poll::Ready(Ok(&[])); + } + + match Pin::new(&mut self.buf_stream_rx).poll_next(cx) { + // Wait for a new chunk to be delivered. + Poll::Pending => return Poll::Pending, + + // Pipe has closed, so return EOF. + Poll::Ready(None) => return Poll::Ready(Ok(&[])), + + // Accept the new chunk. + Poll::Ready(buf) => self.chunk = buf, + } + } + + // Return the current chunk, if any, as the buffer. + #[allow(unsafe_code)] + Poll::Ready(match unsafe { self.get_unchecked_mut().chunk.as_mut() } { + Some(chunk) => chunk.fill_buf(), + None => Ok(&[]), + }) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + if let Some(chunk) = self.chunk.as_mut() { + // Consume the requested amount from the current chunk. + chunk.consume(amt); + } } } @@ -158,7 +179,16 @@ pub(crate) struct Writer { } impl AsyncWrite for Writer { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // Do not send empty buffers through the rotation. + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + // If the pipe is closed then return prematurely, otherwise we'd be // spending time writing the entire buffer only to discover that it is // closed afterward. @@ -204,61 +234,3 @@ impl AsyncWrite for Writer { Poll::Ready(Ok(())) } } - -#[cfg(all(test, feature = "nightly"))] -mod tests { - use futures::executor::block_on; - use futures::prelude::*; - use futures::task::noop_waker; - use super::*; - - #[test] - fn read_then_write() { - block_on(async { - let (mut reader, mut writer) = new(1); - - writer.write_all(b"hello").await.unwrap(); - - let mut dest = [0; 5]; - assert_eq!(reader.read(&mut dest).await.unwrap(), 5); - assert_eq!(&dest, b"hello"); - }) - } - - #[test] - fn reader_still_drainable_after_writer_disconnects() { - block_on(async { - let (mut reader, mut writer) = new(1); - - writer.write_all(b"hello").await.unwrap(); - - drop(writer); - - let mut dest = [0; 5]; - assert_eq!(reader.read(&mut dest).await.unwrap(), 5); - assert_eq!(&dest, b"hello"); - - // Continue returning Ok(0) forever. - for _ in 0..3 { - assert_eq!(reader.read(&mut dest).await.unwrap(), 0); - } - }) - } - - #[test] - fn writer_errors_if_reader_is_dropped() { - let waker = noop_waker(); - let mut context = Context::from_waker(&waker); - - let (reader, mut writer) = new(2); - - drop(reader); - - for _ in 0..3 { - match writer.write(b"hello").poll_unpin(&mut context) { - Poll::Ready(Err(e)) => assert_eq!(e.kind(), io::ErrorKind::BrokenPipe), - _ => panic!("expected poll to be ready"), - } - } - } -} diff --git a/src/pipe/mod.rs b/src/pipe/mod.rs index 10c4916..63b35ac 100644 --- a/src/pipe/mod.rs +++ b/src/pipe/mod.rs @@ -3,7 +3,7 @@ //! Pipes are like byte-oriented channels that implement I/O traits for reading //! and writing. -use futures_io::{AsyncRead, AsyncWrite}; +use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite}; use std::fmt; use std::io; use std::pin::Pin; @@ -23,14 +23,7 @@ const DEFAULT_CHUNK_COUNT: usize = 4; pub fn pipe() -> (PipeReader, PipeWriter) { let (reader, writer) = chunked::new(DEFAULT_CHUNK_COUNT); - ( - PipeReader { - inner: reader, - }, - PipeWriter { - inner: writer, - }, - ) + (PipeReader { inner: reader }, PipeWriter { inner: writer }) } /// The reading end of an asynchronous pipe. @@ -39,11 +32,26 @@ pub struct PipeReader { } impl AsyncRead for PipeReader { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { Pin::new(&mut self.inner).poll_read(cx, buf) } } +impl AsyncBufRead for PipeReader { + #[allow(unsafe_code)] + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + unsafe { self.map_unchecked_mut(|s| &mut s.inner) }.poll_fill_buf(cx) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + Pin::new(&mut self.inner).consume(amt) + } +} + impl fmt::Debug for PipeReader { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.pad("PipeReader") @@ -56,7 +64,11 @@ pub struct PipeWriter { } impl AsyncWrite for PipeWriter { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { Pin::new(&mut self.inner).poll_write(cx, buf) } diff --git a/tests/pipe.rs b/tests/pipe.rs new file mode 100644 index 0000000..781b12c --- /dev/null +++ b/tests/pipe.rs @@ -0,0 +1,113 @@ +use futures::{ + executor::block_on, + join, + prelude::*, +}; +use quickcheck_macros::quickcheck; +use sluice::pipe::pipe; +use std::io; + +#[test] +fn read_empty() { + block_on(async { + let (mut reader, writer) = pipe(); + drop(writer); + + let mut out = String::new(); + reader.read_to_string(&mut out).await.unwrap(); + assert_eq!(out, ""); + }) +} + +#[test] +fn read_then_write() { + block_on(async { + let (mut reader, mut writer) = pipe(); + + writer.write_all(b"hello world").await.unwrap(); + + let mut dest = [0; 6]; + + assert_eq!(reader.read(&mut dest).await.unwrap(), 6); + assert_eq!(&dest, b"hello "); + + assert_eq!(reader.read(&mut dest).await.unwrap(), 5); + assert_eq!(&dest[..5], b"world"); + }) +} + +#[test] +fn reader_still_drainable_after_writer_disconnects() { + block_on(async { + let (mut reader, mut writer) = pipe(); + + writer.write_all(b"hello").await.unwrap(); + + drop(writer); + + let mut dest = [0; 5]; + assert_eq!(reader.read(&mut dest).await.unwrap(), 5); + assert_eq!(&dest, b"hello"); + + // Continue returning Ok(0) forever. + for _ in 0..3 { + assert_eq!(reader.read(&mut dest).await.unwrap(), 0); + } + }) +} + +#[test] +fn writer_errors_if_reader_is_dropped() { + block_on(async { + let (reader, mut writer) = pipe(); + + drop(reader); + + for _ in 0..3 { + assert_eq!(writer.write(b"hello").await.unwrap_err().kind(), io::ErrorKind::BrokenPipe); + } + }) +} + +#[test] +fn pipe_lots_of_data() { + block_on(async { + let data = vec![0xff; 1_000_000]; + let (mut reader, mut writer) = pipe(); + + join!( + async { + writer.write_all(&data).await.unwrap(); + writer.close().await.unwrap(); + }, + async { + let mut out = Vec::new(); + reader.read_to_end(&mut out).await.unwrap(); + assert_eq!(&out[..], &data[..]); + }, + ); + }) +} + +#[quickcheck] +fn read_write_chunks_random(chunks: u16) { + block_on(async { + let data = [0; 8192]; + let (mut reader, mut writer) = pipe(); + + join!( + async { + for chunk in 0..chunks { + writer.write_all(&data).await.unwrap(); + } + }, + async { + for chunk in 0..chunks { + let mut buf = data.clone(); + reader.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..], &data[..]); + } + }, + ); + }) +}