
feat: add loom test for the counter #6888

Merged · 14 commits · Dec 15, 2022
80 changes: 79 additions & 1 deletion Cargo.lock

(Cargo.lock is a generated file; its diff is not rendered by default.)

1 change: 1 addition & 0 deletions ci/scripts/run-unit-test.sh
@@ -4,6 +4,7 @@ set -euo pipefail
echo "+++ Run unit tests with coverage"
# use tee to disable progress bar
NEXTEST_PROFILE=ci cargo llvm-cov nextest --lcov --output-path lcov.info --features failpoints,sync_point 2> >(tee);
NEXTEST_PROFILE=ci RUSTFLAGS="--cfg loom" cargo nextest run --test loom
if [[ "$RUN_SQLSMITH" -eq "1" ]]; then
NEXTEST_PROFILE=ci cargo nextest run run_sqlsmith_on_frontend --features "failpoints sync_point enable_sqlsmith_unit_test" 2> >(tee);
fi
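The `RUSTFLAGS="--cfg loom" cargo nextest run --test loom` line above runs the loom-gated tests in a separate invocation, because loom models are only compiled when the `loom` cfg flag is set. For orientation, a minimal loom test (an illustrative sketch, not code from this PR) wraps its body in `loom::model`, which re-executes the closure under every thread interleaving the checker explores:

// Illustrative sketch only: `loom::model` re-runs the closure for every
// interleaving of the spawned threads, so the assertion must hold under
// all schedules.
#[cfg(loom)]
#[test]
fn relaxed_increments_are_not_lost() {
    use loom::sync::atomic::{AtomicUsize, Ordering};
    use loom::sync::Arc;
    use loom::thread;

    loom::model(|| {
        let counter = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..2)
            .map(|_| {
                let counter = counter.clone();
                thread::spawn(move || {
                    // An atomic read-modify-write cannot lose an update.
                    counter.fetch_add(1, Ordering::Relaxed);
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::Relaxed), 2);
    });
}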
6 changes: 5 additions & 1 deletion src/utils/task_stats_alloc/Cargo.toml
@@ -15,6 +15,10 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [
"time",
"signal",
] }
workspace-hack = { path = "../../workspace-hack" }
#workspace-hack = { path = "../../workspace-hack" }

[dev-dependencies]


[target.'cfg(loom)'.dependencies]
loom = {version = "0.5", features = ["futures", "checkpoint"]}
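The `[target.'cfg(loom)'.dependencies]` table pulls loom in only when the crate is compiled with `--cfg loom`. A common companion pattern (a sketch of typical loom usage, not necessarily how `task_stats_alloc` is organized) is to alias the synchronization primitives so the same code runs against `std` in normal builds and against loom's instrumented types under the model:

// Hypothetical `sync` module showing the usual cfg(loom) aliasing trick.
#[cfg(loom)]
pub(crate) use loom::sync::atomic::{AtomicUsize, Ordering};
#[cfg(loom)]
pub(crate) use loom::sync::Arc;

#[cfg(not(loom))]
pub(crate) use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(not(loom))]
pub(crate) use std::sync::Arc;

The test added below takes a simpler route: it re-declares the counter with loom types directly, leaving the production crate untouched.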
101 changes: 101 additions & 0 deletions src/utils/task_stats_alloc/tests/loom.rs
@@ -0,0 +1,101 @@
#![feature(allocator_api)]
#![cfg(loom)]

// use task_stats_alloc::*;
use std::alloc::System;
use std::borrow::BorrowMut;
use std::hint::black_box;
use std::ptr::NonNull;
use std::time::Duration;

use loom::sync::atomic::{fence, AtomicUsize, Ordering};
use loom::sync::Arc;
use loom::thread;
use task_stats_alloc::{allocation_stat, BYTES_ALLOCATED};
use tokio::runtime::Handle;

#[repr(transparent)]
#[derive(Clone, Debug)]
pub struct TaskLocalBytesAllocated(Option<NonNull<AtomicUsize>>);

impl Default for TaskLocalBytesAllocated {
fn default() -> Self {
Self(Some(
NonNull::new(Box::leak(Box::new_in(0.into(), System))).unwrap(),
))
}
}

impl TaskLocalBytesAllocated {
pub fn new() -> Self {
Self::default()
}

/// Create an invalid counter.
pub const fn invalid() -> Self {
Self(None)
}

/// Adds to the current counter.
#[inline(always)]
pub(crate) fn add(&self, val: usize) {
if let Some(bytes) = self.0 {
let bytes_ref = unsafe { bytes.as_ref() };
bytes_ref.fetch_add(val, Ordering::Relaxed);
}
}

/// Subtracts from the counter value, and drops the counter when the count reaches zero.
#[inline(always)]
pub(crate) fn sub(&self, val: usize, atomic: Arc<AtomicUsize>) {
if let Some(bytes) = self.0 {
let bytes_ref = unsafe { bytes.as_ref() };
// Decrement the counter; the thread whose decrement brings it to zero is responsible for deleting it.
let old_bytes = bytes_ref.fetch_sub(val, Ordering::Relaxed);
// If the counter reaches zero, delete the counter. Note that we've ensured there are no
// zero deltas in `wrap_layout`, so there will be no more uses of the counter.
if old_bytes == val {
// No fence is needed here. Atomically bump the test flag so we can observe that
// the counter is dropped exactly once.
atomic.fetch_add(1, Ordering::Relaxed);
unsafe { Box::from_raw_in(bytes.as_ptr(), System) };
}
}
}

#[inline(always)]
pub fn val(&self) -> usize {
let bytes_ref = self.0.as_ref().expect("bytes is invalid");
let bytes_ref = unsafe { bytes_ref.as_ref() };
bytes_ref.load(Ordering::Relaxed)
}
}

#[test]
fn test_to_avoid_double_drop() {
loom::model(|| {
let bytes_num = 3;
let num = Arc::new(TaskLocalBytesAllocated(Some(
NonNull::new(Box::leak(Box::new_in(bytes_num.into(), System))).unwrap(),
)));

// The flag is incremented when the counter is dropped, so the test can observe how many times the drop happens.
let flag_num = Arc::new(AtomicUsize::new(0));

let ths: Vec<_> = (0..bytes_num)
.map(|_| {
let num = num.clone();
let flag_num = flag_num.clone();
thread::spawn(move || {
num.sub(1, flag_num);
})
})
.collect();

for th in ths {
th.join().unwrap();
}

// Ensure the counter is dropped exactly once.
assert_eq!(flag_num.load(Ordering::Relaxed), 1);
});
}
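In this model the drop decision hinges on the value returned by a single `fetch_sub`: the deltas are all non-zero and sum to the initial count, so exactly one thread can observe `old_bytes == val`, which is what the final assertion checks. If a model grows too many threads or branches to explore exhaustively, loom can bound the search; the sketch below assumes loom's `Builder` API with its `preemption_bound` field and is illustrative only, not part of this PR:

// Illustrative sketch: bounding loom's state-space exploration.
// `preemption_bound` limits how many forced preemptions loom considers,
// trading completeness for runtime.
#[cfg(loom)]
#[test]
fn bounded_model() {
    use loom::sync::atomic::{AtomicUsize, Ordering};
    use loom::sync::Arc;
    use loom::thread;

    let mut builder = loom::model::Builder::new();
    builder.preemption_bound = Some(3);

    builder.check(|| {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter2 = counter.clone();

        let handle = thread::spawn(move || {
            counter2.fetch_add(1, Ordering::Relaxed);
        });
        handle.join().unwrap();

        assert_eq!(counter.load(Ordering::Relaxed), 1);
    });
}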