Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add into_inner() on Rust writer #1314

Merged
merged 2 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rust/src/io_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ impl<W> CountingCrcWriter<W> {
/// Consumes the writer, returning the inner stream together with the CRC
/// `Hasher` state accumulated so far.
pub fn finalize(self) -> (W, Hasher) {
(self.inner, self.hasher)
}

/// Returns the CRC accumulated so far without consuming the writer.
pub fn current_checksum(&self) -> u32 {
    // finalize() takes the hasher by value, so run it on a snapshot to keep
    // the running state intact for subsequent writes.
    let snapshot = self.hasher.clone();
    snapshot.finalize()
}
}

impl<W: Write> Write for CountingCrcWriter<W> {
Expand Down
158 changes: 127 additions & 31 deletions rust/src/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use std::{
use bimap::BiHashMap;
use binrw::prelude::*;
use byteorder::{WriteBytesExt, LE};
#[cfg(feature = "zstd")]
use zstd::stream::{raw as zraw, zio};

use crate::{
chunk_sink::{ChunkMode, ChunkSink},
Expand Down Expand Up @@ -311,6 +313,7 @@ struct SchemaContent<'a> {
/// and check for errors when done; otherwise the result will be unwrapped on drop.
pub struct Writer<W: Write + Seek> {
writer: Option<WriteMode<W>>,
is_finished: bool,
chunk_mode: ChunkMode,
options: WriteOptions,
schemas: BiHashMap<SchemaContent<'static>, u16>,
Expand Down Expand Up @@ -363,6 +366,7 @@ impl<W: Write + Seek> Writer<W> {

Ok(Self {
writer: Some(WriteMode::Raw(writer)),
is_finished: false,
options: opts,
chunk_mode,
schemas: Default::default(),
Expand Down Expand Up @@ -667,10 +671,8 @@ impl<W: Write + Seek> Writer<W> {
) -> McapResult<()> {
self.finish_chunk()?;

let prev_writer = self.writer.take().expect(Self::WHERE_WRITER);

let WriteMode::Raw(w) = prev_writer else {
panic!(
let WriteMode::Raw(w) = self.writer.take().expect(Self::WRITER_IS_NONE) else {
unreachable!(
"since finish_chunk was called, write mode is guaranteed to be raw at this point"
);
};
Expand Down Expand Up @@ -786,31 +788,35 @@ impl<W: Write + Seek> Writer<W> {
Ok(())
}

/// `.expect()` message when we go to write and self.writer is `None`,
/// which should only happen when [`Writer::finish()`] was called.
const WHERE_WRITER: &'static str = "Trying to write a record on a finished MCAP";
const WRITER_IS_NONE: &'static str = "unreachable: self.writer should never be None";

/// Asserts that [`Writer::finish()`] has not been called yet.
///
/// Panics if the MCAP is already finished: writing further records at that
/// point is a caller bug, not a recoverable error.
fn assert_not_finished(&self) {
    // The message contains no braces, so it can serve directly as the format
    // string — the redundant `"{}"` indirection is unnecessary.
    assert!(
        !self.is_finished,
        "Trying to write a record on a finished MCAP"
    );
}

/// Starts a new chunk if we haven't done so already.
fn start_chunk(&mut self) -> McapResult<&mut ChunkWriter<W>> {
self.assert_not_finished();

// It is not possible to start writing a chunk if we're still writing an attachment. Return
// an error instead.
if let Some(WriteMode::Attachment(..)) = self.writer {
return Err(McapError::AttachmentNotInProgress);
}

// Some Rust tricky: we can't move the writer out of self.writer,
// leave that empty for a bit, and then replace it with a ChunkWriter.
// (That would leave it in an unspecified state if we bailed here!)
// Instead briefly swap it out for a null writer while we set up the chunker
// The writer will only be None if finish() was called.
assert!(
self.options.use_chunks,
"Trying to write to a chunk when chunking is disabled"
);

let prev_writer = self.writer.take().expect(Self::WHERE_WRITER);

self.writer = Some(match prev_writer {
// Rust forbids moving values out of a &mut reference. We made self.writer an Option so we
// can work around this by using take() to temporarily replace it with None while we
// construct the ChunkWriter.
self.writer = Some(match self.writer.take().expect(Self::WRITER_IS_NONE) {
WriteMode::Raw(w) => {
// It's chunkin time.
WriteMode::Chunk(ChunkWriter::new(
Expand All @@ -832,16 +838,15 @@ impl<W: Write + Seek> Writer<W> {

/// Finish the current chunk, if we have one.
fn finish_chunk(&mut self) -> McapResult<&mut CountingCrcWriter<W>> {
self.assert_not_finished();
// If we're currently writing an attachment then we're not writing a chunk. Return an
// error instead.
if let Some(WriteMode::Attachment(..)) = self.writer {
return Err(McapError::AttachmentNotInProgress);
}

// See above
let prev_writer = self.writer.take().expect(Self::WHERE_WRITER);

self.writer = Some(match prev_writer {
// See start_chunk() for why we use take() here.
self.writer = Some(match self.writer.take().expect(Self::WRITER_IS_NONE) {
WriteMode::Chunk(c) => {
let (w, mode, index) = c.finish()?;
self.chunk_indexes.push(index);
Expand All @@ -862,28 +867,29 @@ impl<W: Write + Seek> Writer<W> {
///
/// Subsequent calls to other methods will panic.
pub fn finish(&mut self) -> McapResult<()> {
if self.writer.is_none() {
if self.is_finished {
// We already called finish().
// Maybe we're dropping after the user called it?
return Ok(());
}

// Finish any chunk we were working on and update stats, indexes, etc.
self.finish_chunk()?;
self.is_finished = true;

// Grab the writer - self.writer becoming None makes subsequent writes fail.
let writer = match self.writer.take() {
let writer = match &mut self.writer {
// We called finish_chunk() above, so we're back to raw writes for
// the summary section.
Some(WriteMode::Raw(w)) => w,
_ => unreachable!(),
};
let (mut writer, data_section_crc) = writer.finalize();
let data_section_crc = data_section_crc.finalize();
let data_section_crc = writer.current_checksum();
let writer = writer.get_mut();

// We're done with the data secton!
write_record(
&mut writer,
writer,
&Record::DataEnd(records::DataEnd { data_section_crc }),
)?;

Expand Down Expand Up @@ -952,8 +958,8 @@ impl<W: Write + Seek> Writer<W> {
}

// Write all schemas.
let schemas_start = summary_start;
if self.options.repeat_schemas && !all_schemas.is_empty() {
let schemas_start: u64 = summary_start;
for schema in all_schemas.iter() {
write_record(&mut ccw, schema)?;
}
Expand Down Expand Up @@ -1053,14 +1059,30 @@ impl<W: Write + Seek> Writer<W> {
ccw.write_u64::<LE>(summary_start)?;
ccw.write_u64::<LE>(summary_offset_start)?;

let (mut writer, summary_crc) = ccw.finalize();
let (writer, summary_crc) = ccw.finalize();

writer.write_u32::<LE>(summary_crc.finalize())?;

writer.write_all(MAGIC)?;
writer.flush()?;
Ok(())
}

/// Consumes this writer and returns the underlying stream. Note that unless
/// [`Self::finish()`] was called first, the underlying stream __will not
/// contain a complete MCAP.__
///
/// Use this if you wish to handle any errors returned when the underlying
/// stream is closed. In particular, if using [`std::fs::File`], you may wish
/// to call [`std::fs::File::sync_all()`] to ensure all data was sent to the
/// filesystem.
pub fn into_inner(mut self) -> W {
    // Mark ourselves finished so a later finish() call (invoked from Drop as
    // well) early-returns instead of touching the now-empty writer slot.
    self.is_finished = true;
    // Strip off each wrapper layer until only the raw stream remains.
    match self.writer.take().expect(Self::WRITER_IS_NONE) {
        WriteMode::Raw(counting) => counting.finalize().0,
        WriteMode::Attachment(attachment) => attachment.writer.finalize().0.finalize().0,
        WriteMode::Chunk(chunk) => {
            chunk.compressor.finalize().0.into_inner().finalize().0.inner
        }
    }
}
}

impl<W: Write + Seek> Drop for Writer<W> {
Expand All @@ -1071,8 +1093,10 @@ impl<W: Write + Seek> Drop for Writer<W> {

enum Compressor<W: Write> {
Null(W),
// zstd's Encoder wrapper doesn't let us get the inner writer without calling finish(), so use
// zio::Writer directly instead.
#[cfg(feature = "zstd")]
Zstd(zstd::Encoder<'static, W>),
Zstd(zio::Writer<W, zraw::Encoder<'static>>),
#[cfg(feature = "lz4")]
Lz4(lz4::Encoder<W>),
}
Expand All @@ -1082,7 +1106,10 @@ impl<W: Write> Compressor<W> {
Ok(match self {
Compressor::Null(w) => w,
#[cfg(feature = "zstd")]
Compressor::Zstd(w) => w.finish()?,
Compressor::Zstd(mut w) => {
w.finish()?;
w.into_inner().0
}
#[cfg(feature = "lz4")]
Compressor::Lz4(w) => {
let (output, result) = w.finish();
Expand All @@ -1091,6 +1118,16 @@ impl<W: Write> Compressor<W> {
}
})
}

/// Unwraps and returns the inner writer, without guaranteeing that the
/// compressed stream written to it is complete.
fn into_inner(self) -> W {
    match self {
        Compressor::Null(inner) => inner,
        #[cfg(feature = "zstd")]
        Compressor::Zstd(writer) => writer.into_inner().0,
        #[cfg(feature = "lz4")]
        Compressor::Lz4(encoder) => {
            // lz4's finish() yields (writer, Result); the Result is
            // deliberately dropped here since into_inner() makes no promise
            // about the stream being valid.
            let (inner, _result) = encoder.finish();
            inner
        }
    }
}
}

impl<W: Write> Write for Compressor<W> {
Expand Down Expand Up @@ -1178,10 +1215,11 @@ impl<W: Write + Seek> ChunkWriter<W> {
#[cfg(feature = "zstd")]
Some(Compression::Zstd) => {
#[allow(unused_mut)]
let mut enc = zstd::Encoder::new(sink, 0)?;
let mut enc = zraw::Encoder::with_dictionary(0, &[])?;
// Enable multithreaded encoding on non-WASM targets.
#[cfg(not(target_arch = "wasm32"))]
enc.multithread(num_cpus::get_physical() as u32)?;
Compressor::Zstd(enc)
enc.set_parameter(zraw::CParameter::NbWorkers(num_cpus::get_physical() as u32))?;
Compressor::Zstd(zio::Writer::new(sink, enc))
}
#[cfg(feature = "lz4")]
Some(Compression::Lz4) => Compressor::Lz4(
Expand Down Expand Up @@ -1510,4 +1548,62 @@ mod tests {
};
assert!(matches!(too_many, McapError::TooManySchemas));
}

#[test]
#[should_panic(expected = "Trying to write a record on a finished MCAP")]
fn panics_if_write_called_after_finish() {
    // Finish the writer up front; any write attempted afterwards must panic.
    let mut writer =
        Writer::new(std::io::Cursor::new(Vec::new())).expect("failed to construct writer");
    writer.finish().expect("failed to finish writer");

    let channel = std::sync::Arc::new(crate::Channel {
        id: 1,
        topic: "chat".into(),
        message_encoding: "json".into(),
        metadata: BTreeMap::new(),
        schema: None,
    });

    let message = crate::Message {
        channel,
        sequence: 0,
        log_time: 0,
        publish_time: 0,
        data: Cow::Owned(Vec::new()),
    };
    writer.write(&message).expect("could not write message");
}

#[test]
fn writes_message_and_checks_stream_length() {
    let mut writer =
        Writer::new(std::io::Cursor::new(Vec::new())).expect("failed to construct writer");

    let channel = std::sync::Arc::new(crate::Channel {
        id: 1,
        topic: "chat".into(),
        message_encoding: "json".into(),
        metadata: BTreeMap::new(),
        schema: None,
    });
    let message = crate::Message {
        channel,
        sequence: 0,
        log_time: 0,
        publish_time: 0,
        data: Cow::Owned(Vec::new()),
    };
    writer.write(&message).expect("could not write message");

    writer.finish().expect("failed to finish writer");

    // into_inner() hands back the Cursor; its position equals the total
    // number of bytes written to the finished MCAP.
    let cursor = writer.into_inner();
    let output_len = cursor
        .stream_position()
        .expect("failed to get stream position");
    assert_eq!(output_len, 487);
}
}
Loading