Skip to content

Commit

Permalink
Merge pull request #1 from mcroomp/simplifywrite
Browse files Browse the repository at this point in the history
Simplify the write path for VP8
  • Loading branch information
mcroomp authored Oct 19, 2024
2 parents e8dcb33 + b2913a8 commit 1f56eb5
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 69 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "cabac"
version = "0.8.0"
version = "0.9.0"
edition = "2021"

description = "Context-adaptive binary arithmetic coding library"
Expand All @@ -19,6 +19,7 @@ byteorder = "1.4"

[dev-dependencies]
criterion = "0.5"
rand = "0.8"

[lib]

Expand Down
79 changes: 37 additions & 42 deletions src/vp8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS”

use std::io::{Read, Result, Write};

use byteorder::WriteBytesExt;

use crate::traits::{CabacReader, CabacWriter};

const BITS_IN_BYTE: i32 = 8;
Expand Down Expand Up @@ -247,7 +249,8 @@ pub struct VP8Writer<W> {
range: u32,
bits_left: i32,
writer: W,
buffer: Vec<u8>,
num_buffered_bytes: u32,
buffered_byte: u8,
}

impl<W: Write> VP8Writer<W> {
Expand All @@ -256,8 +259,9 @@ impl<W: Write> VP8Writer<W> {
low_value: 0,
range: 255,
bits_left: -24,
buffer: Vec::new(),
writer: writer,
num_buffered_bytes: 0,
buffered_byte: 0,
};

let mut dummy_branch = VP8Context::default();
Expand All @@ -266,23 +270,7 @@ impl<W: Write> VP8Writer<W> {
Ok(retval)
}

/// When buffer is full and is going to be sent to output, preserve buffer data that
/// is not final and should carried over to the next buffer.
fn flush_non_final_data(&mut self) -> Result<()> {
// carry over buffer data that might be not final
let mut i = self.buffer.len() - 1;
while self.buffer[i] == 0xFF {
assert!(i > 0);
i -= 1;
}

self.writer.write_all(&self.buffer[..i])?;
self.buffer.drain(..i);

Ok(())
}

#[inline(always)]
#[inline]
fn send_to_output(
&mut self,
shift: &mut i32,
Expand All @@ -291,31 +279,42 @@ impl<W: Write> VP8Writer<W> {
) -> Result<()> {
let offset = *shift - *tmp_count;

if ((*tmp_low_value << (offset - 1)) & 0x80000000) != 0 {
let mut x = self.buffer.len() - 1;
let last_byte = *tmp_low_value >> (24 - offset);

while self.buffer[x] == 0xFF {
self.buffer[x] = 0;
if (last_byte & 0x100) != 0 {
self.flush_buffered_bytes(1)?;
}

assert!(x > 0);
x -= 1;
}
let last_byte = last_byte as u8;

self.buffer[x] += 1;
}
if last_byte == 0xff {
self.num_buffered_bytes += 1;
} else {
self.flush_buffered_bytes(0)?;

self.buffer.push((*tmp_low_value >> (24 - offset)) as u8);
self.buffered_byte = last_byte;
self.num_buffered_bytes = 1;
}

*tmp_low_value <<= offset;
*shift = *tmp_count;
*tmp_low_value &= 0xffffff;
*tmp_count -= 8;

// check if we're out of buffer space, if yes - send the buffer to output,
if self.buffer.len() > 65536 - 128 {
self.flush_non_final_data()?;
}
Ok(())
}

fn flush_buffered_bytes(&mut self, carry: u8) -> Result<()> {
if self.num_buffered_bytes > 0 {
self.writer
.write_u8(self.buffered_byte.wrapping_add(carry))?;
self.num_buffered_bytes -= 1;

while self.num_buffered_bytes > 0 {
self.writer.write_u8(0xffu8.wrapping_add(carry))?;
self.num_buffered_bytes -= 1;
}
}
Ok(())
}
}
Expand Down Expand Up @@ -401,17 +400,13 @@ impl<W: Write> CabacWriter<VP8Context> for VP8Writer<W> {
}

fn finish(&mut self) -> Result<()> {
for _i in 0..32 {
let mut dummy_branch = VP8Context::default();
self.put(false, &mut dummy_branch)?;
}

// Ensure there's no ambigous collision with any index marker bytes
if (self.buffer.last().unwrap() & 0xe0) == 0xc0 {
self.buffer.push(0);
// pad the rest of the stream so we don't have to
// worry about carrying the last byte
while self.low_value > 0 {
self.put_bypass(false)?;
}

self.writer.write_all(&self.buffer[..])?;
self.flush_buffered_bytes(0)?;

Ok(())
}
Expand Down
99 changes: 73 additions & 26 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,25 @@ fn test_permutation_vp8(pattern: u64, num_bits: u8, bypass_index: u8) {
}
}

#[derive(Clone, Copy)]
enum Seq {
Normal(bool),
Normal(bool, usize),
Bypass(bool),
}

fn test_seq_vp8(seq: &[Seq]) {
let mut output = Vec::new();
{
let mut context = VP8Context::default();
let mut context = Vec::new();
for _i in 0..16 {
context.push(VP8Context::default());
}
let mut writer = VP8Writer::new(&mut output).unwrap();

for s in seq {
for &s in seq {
match s {
Seq::Normal(b) => writer.put(*b, &mut context).unwrap(),
Seq::Bypass(b) => writer.put_bypass(*b).unwrap(),
Seq::Normal(b, c) => writer.put(b, &mut context[c]).unwrap(),
Seq::Bypass(b) => writer.put_bypass(b).unwrap(),
}
}

Expand All @@ -106,16 +110,20 @@ fn test_seq_vp8(seq: &[Seq]) {

// now try reading it
{
let mut context = VP8Context::default();
let mut context = Vec::new();
for _ in 0..16 {
context.push(VP8Context::default());
}

let mut reader = VP8Reader::new(Cursor::new(&output)).unwrap();

for s in seq {
for &s in seq {
match s {
Seq::Normal(b) => {
assert_eq!(*b, reader.get(&mut context).unwrap())
Seq::Normal(b, c) => {
assert_eq!(b, reader.get(&mut context[c]).unwrap())
}
Seq::Bypass(b) => {
assert_eq!(*b, reader.get_bypass().unwrap())
assert_eq!(b, reader.get_bypass().unwrap())
}
}
}
Expand All @@ -125,13 +133,17 @@ fn test_seq_vp8(seq: &[Seq]) {
fn test_seq_h265(seq: &[Seq]) {
let mut output = Vec::new();
{
let mut context = H265Context::default();
let mut context = Vec::new();
for _ in 0..16 {
context.push(H265Context::default());
}

let mut writer = H265Writer::new(&mut output);

for s in seq {
for &s in seq {
match s {
Seq::Normal(b) => writer.put(*b, &mut context).unwrap(),
Seq::Bypass(b) => writer.put_bypass(*b).unwrap(),
Seq::Normal(b, c) => writer.put(b, &mut context[c]).unwrap(),
Seq::Bypass(b) => writer.put_bypass(b).unwrap(),
}
}

Expand All @@ -140,16 +152,20 @@ fn test_seq_h265(seq: &[Seq]) {

// now try reading it
{
let mut context = H265Context::default();
let mut context = Vec::new();
for _ in 0..16 {
context.push(H265Context::default());
}

let mut reader = H265Reader::new(Cursor::new(&output)).unwrap();

for s in seq {
for &s in seq {
match s {
Seq::Normal(b) => {
assert_eq!(*b, reader.get(&mut context).unwrap())
Seq::Normal(b, c) => {
assert_eq!(b, reader.get(&mut context[c]).unwrap())
}
Seq::Bypass(b) => {
assert_eq!(*b, reader.get_bypass().unwrap())
assert_eq!(b, reader.get_bypass().unwrap())
}
}
}
Expand Down Expand Up @@ -216,17 +232,22 @@ fn test_basic_permutations_h264() {

#[test]
fn test_random_sequences() {
for i in 1..27 {
let mut seed: u32 = 27 + i;
use rand::Rng;

let mut rng = rand::thread_rng();

let probs: [f64; 16] = [
0.001, 0.01, 0.1, 0.11, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 0.91, 0.99, 0.999, 0.9999, 1.0,
];

for _ in 1..1000 {
let mut seq = Vec::new();

for _i in 0..10000 {
seed = seed.wrapping_mul(10000019) + 7;
for _ in 0..1000 {
let ctx = rng.gen_range(0..16);

seq.push(match seed % 4 {
0 => Seq::Normal(true),
1 => Seq::Normal(false),
seq.push(match rng.gen_range(0..4) {
0 | 1 => Seq::Normal(rng.gen_bool(probs[ctx]), ctx),
2 => Seq::Bypass(false),
_ => Seq::Bypass(true),
});
Expand All @@ -236,3 +257,29 @@ fn test_random_sequences() {
test_seq_vp8(&seq);
}
}

#[test]
fn test_all_0() {
let all_0 = vec![Seq::Normal(false, 0); 10000];

test_seq_h265(&all_0);
test_seq_vp8(&all_0);
}

#[test]
fn test_all_1() {
let all_1 = vec![Seq::Normal(true, 0); 10000];

test_seq_h265(&all_1);
test_seq_vp8(&all_1);
}

#[test]
fn test_alt() {
let mut seq = Vec::new();
for i in 0..10000 {
seq.push(Seq::Normal(i % 2 == 0, 0));
}
test_seq_h265(&seq);
test_seq_vp8(&seq);
}

0 comments on commit 1f56eb5

Please sign in to comment.