Skip to content

Commit

Permalink
Avoid unwrapping in into_inner()
Browse files Browse the repository at this point in the history
  • Loading branch information
Muon committed Jan 20, 2025
1 parent 76ef68a commit 6a19bdb
Showing 1 changed file with 39 additions and 24 deletions.
63 changes: 39 additions & 24 deletions rust/src/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{
use bimap::BiHashMap;
use binrw::prelude::*;
use byteorder::{WriteBytesExt, LE};
use zstd::stream::{raw as zraw, zio};

use crate::{
chunk_sink::{ChunkMode, ChunkSink},
Expand Down Expand Up @@ -669,7 +670,7 @@ impl<W: Write + Seek> Writer<W> {
) -> McapResult<()> {
self.finish_chunk()?;

let WriteMode::Raw(w) = self.writer.take().expect(Self::WHERE_WRITER) else {
let WriteMode::Raw(w) = self.writer.take().expect(Self::WRITER_IS_NONE) else {
unreachable!(
"since finish_chunk was called, write mode is guaranteed to be raw at this point"
);
Expand Down Expand Up @@ -786,13 +787,19 @@ impl<W: Write + Seek> Writer<W> {
Ok(())
}

/// `.expect()` message when we go to write and self.writer is `None`,
/// which should only happen when [`Writer::finish()`] was called.
const WHERE_WRITER: &'static str = "Trying to write a record on a finished MCAP";
const WRITER_IS_NONE: &'static str = "unreachable: self.writer should never be None";

fn assert_not_finished(&self) {
assert!(
!self.is_finished,
"{}",
"Trying to write a record on a finished MCAP"
);
}

/// Starts a new chunk if we haven't done so already.
fn start_chunk(&mut self) -> McapResult<&mut ChunkWriter<W>> {
assert!(!self.is_finished, "{}", Self::WHERE_WRITER);
self.assert_not_finished();

// It is not possible to start writing a chunk if we're still writing an attachment. Return
// an error instead.
Expand All @@ -808,7 +815,7 @@ impl<W: Write + Seek> Writer<W> {
// Rust forbids moving values out of a &mut reference. We made self.writer an Option so we
// can work around this by using take() to temporarily replace it with None while we
// construct the ChunkWriter.
self.writer = Some(match self.writer.take().expect(Self::WHERE_WRITER) {
self.writer = Some(match self.writer.take().expect(Self::WRITER_IS_NONE) {
WriteMode::Raw(w) => {
// It's chunkin time.
WriteMode::Chunk(ChunkWriter::new(
Expand All @@ -830,15 +837,15 @@ impl<W: Write + Seek> Writer<W> {

/// Finish the current chunk, if we have one.
fn finish_chunk(&mut self) -> McapResult<&mut CountingCrcWriter<W>> {
assert!(!self.is_finished, "{}", Self::WHERE_WRITER);
self.assert_not_finished();
// If we're currently writing an attachment then we're not writing a chunk. Return an
// error instead.
if let Some(WriteMode::Attachment(..)) = self.writer {
return Err(McapError::AttachmentNotInProgress);
}

// See start_chunk() for why we use take() here.
self.writer = Some(match self.writer.take().expect(Self::WHERE_WRITER) {
self.writer = Some(match self.writer.take().expect(Self::WRITER_IS_NONE) {
WriteMode::Chunk(c) => {
let (w, mode, index) = c.finish()?;
self.chunk_indexes.push(index);
Expand Down Expand Up @@ -1068,19 +1075,11 @@ impl<W: Write + Seek> Writer<W> {
/// to ensure all data was sent to the filesystem.
pub fn into_inner(mut self) -> W {
self.is_finished = true;
match self.writer.take().expect(Self::WHERE_WRITER) {
// Peel away all the layers of the writer to get the underlying stream.
match self.writer.take().expect(Self::WRITER_IS_NONE) {
WriteMode::Raw(w) => w.finalize().0,
WriteMode::Attachment(w) => w.writer.finalize().0.finalize().0,
WriteMode::Chunk(w) => {
w.compressor
.finalize()
.0
.finish()
.expect("compression error")
.finalize()
.0
.inner
}
WriteMode::Chunk(w) => w.compressor.finalize().0.into_inner().finalize().0.inner,
}
}
}
Expand All @@ -1093,8 +1092,10 @@ impl<W: Write + Seek> Drop for Writer<W> {

enum Compressor<W: Write> {
Null(W),
// zstd's Encoder wrapper doesn't let us get the inner writer without calling finish(), so use
// zio::Writer directly instead.
#[cfg(feature = "zstd")]
Zstd(zstd::Encoder<'static, W>),
Zstd(zio::Writer<W, zraw::Encoder<'static>>),
#[cfg(feature = "lz4")]
Lz4(lz4::Encoder<W>),
}
Expand All @@ -1104,7 +1105,10 @@ impl<W: Write> Compressor<W> {
Ok(match self {
Compressor::Null(w) => w,
#[cfg(feature = "zstd")]
Compressor::Zstd(w) => w.finish()?,
Compressor::Zstd(mut w) => {
w.finish()?;
w.into_inner().0
}
#[cfg(feature = "lz4")]
Compressor::Lz4(w) => {
let (output, result) = w.finish();
Expand All @@ -1113,6 +1117,16 @@ impl<W: Write> Compressor<W> {
}
})
}

fn into_inner(self) -> W {
match self {
Compressor::Null(w) => w,
#[cfg(feature = "zstd")]
Compressor::Zstd(w) => w.into_inner().0,
#[cfg(feature = "lz4")]
Compressor::Lz4(w) => w.finish().0,
}
}
}

impl<W: Write> Write for Compressor<W> {
Expand Down Expand Up @@ -1200,10 +1214,11 @@ impl<W: Write + Seek> ChunkWriter<W> {
#[cfg(feature = "zstd")]
Some(Compression::Zstd) => {
#[allow(unused_mut)]
let mut enc = zstd::Encoder::new(sink, 0)?;
let mut enc = zraw::Encoder::with_dictionary(0, &[])?;
// Enable multithreaded encoding on non-WASM targets.
#[cfg(not(target_arch = "wasm32"))]
enc.multithread(num_cpus::get_physical() as u32)?;
Compressor::Zstd(enc)
enc.set_parameter(zraw::CParameter::NbWorkers(num_cpus::get_physical() as u32))?;
Compressor::Zstd(zio::Writer::new(sink, enc))
}
#[cfg(feature = "lz4")]
Some(Compression::Lz4) => Compressor::Lz4(
Expand Down

0 comments on commit 6a19bdb

Please sign in to comment.