Skip to content

Commit

Permalink
Implement AsyncBufRead
Browse files Browse the repository at this point in the history
Fixes #8.
  • Loading branch information
sagebind committed Feb 17, 2020
1 parent 0bada49 commit 3ab11aa
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 118 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ default-features = false
[dev-dependencies]
criterion = "0.3"
futures = "0.3"
quickcheck = "0.9"
quickcheck_macros = "0.9"

[[bench]]
name = "pipe"
Expand Down
2 changes: 0 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,4 @@
clippy::all,
)]

#![cfg_attr(feature = "nightly", feature(async_await))]

pub mod pipe;
182 changes: 77 additions & 105 deletions src/pipe/chunked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
use futures_channel::mpsc;
use futures_core::{FusedStream, Stream};
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::io::Cursor;
use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
use std::{
io,
io::{BufRead, Cursor, Write},
pin::Pin,
task::{Context, Poll},
};
Expand All @@ -47,7 +47,9 @@ pub(crate) fn new(count: usize) -> (Reader, Writer) {

// Fill up the buffer pool.
for _ in 0..count {
buf_pool_tx.try_send(Cursor::new(Vec::new())).expect("buffer pool overflow");
buf_pool_tx
.try_send(Cursor::new(Vec::new()))
.expect("buffer pool overflow");
}

let reader = Reader {
Expand Down Expand Up @@ -77,65 +79,53 @@ pub(crate) struct Reader {
}

impl AsyncRead for Reader {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
// Fetch the chunk to read from. If we already have one from a previous
// read, use that, otherwise receive the next chunk from the writer.
let mut chunk = match self.chunk.take() {
Some(chunk) => chunk,

None => {
// If the stream has terminated, then do not poll it again.
if self.buf_stream_rx.is_terminated() {
return Poll::Ready(Ok(0));
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: &mut [u8],
) -> Poll<io::Result<usize>> {
// Read into the internal buffer.
match self.as_mut().poll_fill_buf(cx)? {
// Not quite ready yet.
Poll::Pending => Poll::Pending,

match Pin::new(&mut self.buf_stream_rx).poll_next(cx) {
// Wait for a new chunk to be delivered.
Poll::Pending => return Poll::Pending,
// A chunk is available.
Poll::Ready(chunk) => {
// Copy as much of the chunk as we can to the destination
// buffer.
let amt = buf.write(chunk)?;

// Pipe has closed, so return EOF.
Poll::Ready(None) => return Poll::Ready(Ok(0)),
// Mark however much was successfully copied as being consumed.
self.consume(amt);

// Accept the new chunk.
Poll::Ready(Some(buf)) => buf,
}
Poll::Ready(Ok(amt))
}
};

// Do the read.
let len = match Pin::new(&mut chunk).poll_read(cx, buf) {
Poll::Pending => unreachable!(),
Poll::Ready(Ok(len)) => len,
Poll::Ready(Err(e)) => panic!("cursor returned an error: {}", e),
};

// If the chunk is not empty yet, keep it for a future read.
if chunk.position() < chunk.get_ref().len() as u64 {
self.chunk = Some(chunk);
}
}
}

// Otherwise, return it to the writer to be reused.
else {
chunk.set_position(0);
chunk.get_mut().clear();

match self.buf_pool_tx.try_send(chunk) {
Ok(()) => {}

Err(e) => {
impl AsyncBufRead for Reader {
fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
// If the current chunk is consumed, first return it to the writer for
// reuse.
if let Some(chunk) = self.chunk.as_ref() {
if chunk.position() >= chunk.get_ref().len() as u64 {
let mut chunk = self.chunk.take().unwrap();
chunk.set_position(0);
chunk.get_mut().clear();

if let Err(e) = self.buf_pool_tx.try_send(chunk) {
// We pre-fill the buffer pool channel with an exact number
// of buffers, so this can never happen.
if e.is_full() {
panic!("buffer pool overflow")
}

// If the writer disconnects, then we'll just discard this
// buffer and any subsequent buffers until we've read
// everything still in the pipe.
else if e.is_disconnected() {
// Nothing!
}

// Some other error occurred.
else {
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
Expand All @@ -144,7 +134,38 @@ impl AsyncRead for Reader {
}
}

Poll::Ready(Ok(len))
// If we have no current chunk, then attempt to read one.
if self.chunk.is_none() {
// If the stream has terminated, then do not poll it again.
if self.buf_stream_rx.is_terminated() {
return Poll::Ready(Ok(&[]));
}

match Pin::new(&mut self.buf_stream_rx).poll_next(cx) {
// Wait for a new chunk to be delivered.
Poll::Pending => return Poll::Pending,

// Pipe has closed, so return EOF.
Poll::Ready(None) => return Poll::Ready(Ok(&[])),

// Accept the new chunk.
Poll::Ready(buf) => self.chunk = buf,
}
}

// Return the current chunk, if any, as the buffer.
#[allow(unsafe_code)]
Poll::Ready(match unsafe { self.get_unchecked_mut().chunk.as_mut() } {
Some(chunk) => chunk.fill_buf(),
None => Ok(&[]),
})
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
if let Some(chunk) = self.chunk.as_mut() {
// Consume the requested amount from the current chunk.
chunk.consume(amt);
}
}
}

Expand All @@ -158,7 +179,16 @@ pub(crate) struct Writer {
}

impl AsyncWrite for Writer {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
// Do not send empty buffers through the rotation.
if buf.is_empty() {
return Poll::Ready(Ok(0));
}

// If the pipe is closed then return prematurely, otherwise we'd be
// spending time writing the entire buffer only to discover that it is
// closed afterward.
Expand Down Expand Up @@ -204,61 +234,3 @@ impl AsyncWrite for Writer {
Poll::Ready(Ok(()))
}
}

#[cfg(all(test, feature = "nightly"))]
mod tests {
use futures::executor::block_on;
use futures::prelude::*;
use futures::task::noop_waker;
use super::*;

#[test]
fn read_then_write() {
block_on(async {
let (mut reader, mut writer) = new(1);

writer.write_all(b"hello").await.unwrap();

let mut dest = [0; 5];
assert_eq!(reader.read(&mut dest).await.unwrap(), 5);
assert_eq!(&dest, b"hello");
})
}

#[test]
fn reader_still_drainable_after_writer_disconnects() {
block_on(async {
let (mut reader, mut writer) = new(1);

writer.write_all(b"hello").await.unwrap();

drop(writer);

let mut dest = [0; 5];
assert_eq!(reader.read(&mut dest).await.unwrap(), 5);
assert_eq!(&dest, b"hello");

// Continue returning Ok(0) forever.
for _ in 0..3 {
assert_eq!(reader.read(&mut dest).await.unwrap(), 0);
}
})
}

#[test]
fn writer_errors_if_reader_is_dropped() {
let waker = noop_waker();
let mut context = Context::from_waker(&waker);

let (reader, mut writer) = new(2);

drop(reader);

for _ in 0..3 {
match writer.write(b"hello").poll_unpin(&mut context) {
Poll::Ready(Err(e)) => assert_eq!(e.kind(), io::ErrorKind::BrokenPipe),
_ => panic!("expected poll to be ready"),
}
}
}
}
34 changes: 23 additions & 11 deletions src/pipe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! Pipes are like byte-oriented channels that implement I/O traits for reading
//! and writing.
use futures_io::{AsyncRead, AsyncWrite};
use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
use std::fmt;
use std::io;
use std::pin::Pin;
Expand All @@ -23,14 +23,7 @@ const DEFAULT_CHUNK_COUNT: usize = 4;
pub fn pipe() -> (PipeReader, PipeWriter) {
let (reader, writer) = chunked::new(DEFAULT_CHUNK_COUNT);

(
PipeReader {
inner: reader,
},
PipeWriter {
inner: writer,
},
)
(PipeReader { inner: reader }, PipeWriter { inner: writer })
}

/// The reading end of an asynchronous pipe.
Expand All @@ -39,11 +32,26 @@ pub struct PipeReader {
}

impl AsyncRead for PipeReader {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}

impl AsyncBufRead for PipeReader {
#[allow(unsafe_code)]
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
unsafe { self.map_unchecked_mut(|s| &mut s.inner) }.poll_fill_buf(cx)
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
Pin::new(&mut self.inner).consume(amt)
}
}

impl fmt::Debug for PipeReader {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.pad("PipeReader")
Expand All @@ -56,7 +64,11 @@ pub struct PipeWriter {
}

impl AsyncWrite for PipeWriter {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

Expand Down
Loading

0 comments on commit 3ab11aa

Please sign in to comment.