Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/tracel-ai/burn into refacto…
Browse files Browse the repository at this point in the history
…r/wgpu-v23
  • Loading branch information
AsherJingkongChen committed Nov 30, 2024
2 parents 23d8679 + 3dc4b43 commit 6d57c23
Show file tree
Hide file tree
Showing 101 changed files with 2,592 additions and 842 deletions.
247 changes: 125 additions & 122 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
hip-jit = ["burn/hip-jit"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
hip-jit = ["burn/hip-jit"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
Expand Down Expand Up @@ -54,6 +54,8 @@ strum_macros = { workspace = true }
sysinfo = { workspace = true, features = ["serde"] }
wgpu = { workspace = true }
wsl = { workspace = true }
tracing-subscriber = { workspace = true }
log = { workspace = true }

[dev-dependencies]
rstest = { workspace = true }
Expand Down
147 changes: 138 additions & 9 deletions backend-comparison/benches/conv2d.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::hint::black_box;

use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

pub struct Conv2dBenchmark<B: Backend> {
suffix: &'static str,
input_shape: Shape,
weight_shape: Shape,
bias_shape: Shape,
Expand All @@ -16,7 +19,7 @@ impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
type Args = (Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 1>);

fn name(&self) -> String {
"conv2d".into()
format!("conv2d-{}", self.suffix)
}

fn shapes(&self) -> Vec<Vec<usize>> {
Expand Down Expand Up @@ -50,6 +53,10 @@ impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
/// Synchronizes this benchmark's device by delegating to the backend's `sync`.
fn sync(&self) {
    let device = &self.device;
    B::sync(device)
}

/// Returns the sample count for this benchmark, which is fixed at 40.
fn num_samples(&self) -> usize {
    const SAMPLE_COUNT: usize = 40;
    SAMPLE_COUNT
}
}

#[allow(dead_code)]
Expand All @@ -75,6 +82,7 @@ fn bench<B: Backend>(
let groups = 1;
let options = ConvOptions::new(strides, padding, dilations, groups);
let benchmark = Conv2dBenchmark::<B> {
suffix: "input_16x512x512_weight_16x3x3_stride_1",
input_shape: [batch_size, channels_in, height_in, width_in].into(),
weight_shape: [
channels_out,
Expand All @@ -88,14 +96,135 @@ fn bench<B: Backend>(
device: device.clone(),
};

save::<B>(
vec![run_benchmark(benchmark)],
device,
feature_name,
url,
token,
)
.unwrap();
let conv1 = Conv2dBenchmark::<B> {
suffix: "input_3x227x227_weight_96x11x11_stride_4",
input_shape: [batch_size, 3, 227, 227].into(),
weight_shape: [96, 3, 11, 11].into(),
bias_shape: [96].into(),
options: ConvOptions::new([4, 4], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv2 = Conv2dBenchmark::<B> {
suffix: "input_3x231x231_weight_96x11x11_stride_4",
input_shape: [batch_size, 3, 231, 231].into(),
weight_shape: [96, 3, 11, 11].into(),
bias_shape: [96].into(),
options: ConvOptions::new([4, 4], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv3 = Conv2dBenchmark::<B> {
suffix: "input_3x227x227_weight_64x7x7_stride_2",
input_shape: [batch_size, 3, 227, 227].into(),
weight_shape: [64, 3, 7, 7].into(),
bias_shape: [64].into(),
options: ConvOptions::new([2, 2], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv4 = Conv2dBenchmark::<B> {
suffix: "input_64x224x224_weight_64x7x7_stride_2",
input_shape: [batch_size, 64, 224, 224].into(),
weight_shape: [64, 64, 7, 7].into(),
bias_shape: [64].into(),
options: ConvOptions::new([2, 2], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv5 = Conv2dBenchmark::<B> {
suffix: "input_96x24x24_weight_256x5x5_stride_1",
input_shape: [batch_size, 96, 24, 24].into(),
weight_shape: [256, 96, 5, 5].into(),
bias_shape: [256].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv6 = Conv2dBenchmark::<B> {
suffix: "input_256x12x12_weight_512x3x3_stride_1",
input_shape: [batch_size, 256, 12, 12].into(),
weight_shape: [512, 256, 3, 3].into(),
bias_shape: [512].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv7 = Conv2dBenchmark::<B> {
suffix: "input_3x224x224_weight_64x3x3_stride_1",
input_shape: [batch_size, 3, 224, 224].into(),
weight_shape: [64, 3, 3, 3].into(),
bias_shape: [64].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv8 = Conv2dBenchmark::<B> {
suffix: "input_64x112x112_weight_128x3x3_stride_1",
input_shape: [batch_size, 64, 112, 112].into(),
weight_shape: [128, 64, 3, 3].into(),
bias_shape: [128].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv9 = Conv2dBenchmark::<B> {
suffix: "input_64x56x56_weight_64x3x3_stride_1",
input_shape: [batch_size, 64, 56, 56].into(),
weight_shape: [64, 64, 3, 3].into(),
bias_shape: [64].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv10 = Conv2dBenchmark::<B> {
suffix: "input_128x28x28_weight_128x3x3_stride_1",
input_shape: [batch_size, 128, 28, 28].into(),
weight_shape: [128, 128, 3, 3].into(),
bias_shape: [128].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv11 = Conv2dBenchmark::<B> {
suffix: "input_256x14x14_weight_256x3x3_stride_1",
input_shape: [batch_size, 256, 14, 14].into(),
weight_shape: [256, 256, 3, 3].into(),
bias_shape: [256].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv12 = Conv2dBenchmark::<B> {
suffix: "input_512x7x7_weight_512x3x3_stride_1",
input_shape: [batch_size, 512, 7, 7].into(),
weight_shape: [512, 512, 3, 3].into(),
bias_shape: [512].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let conv13 = Conv2dBenchmark::<B> {
suffix: "input_96x224x224_weight_64x1x1_stride_1",
input_shape: [batch_size, 96, 224, 224].into(),
weight_shape: [64, 96, 1, 1].into(),
bias_shape: [64].into(),
options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1),
device: device.clone(),
};

let benches = vec![
benchmark, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8, conv9, conv10, conv11,
conv12, conv13,
];
let mut results = Vec::new();

for bench in benches {
let result = black_box(run_benchmark(bench));
results.push(result);
}

save::<B>(results, device, feature_name, url, token).unwrap();
}

fn main() {
Expand Down
42 changes: 18 additions & 24 deletions backend-comparison/benches/matmul.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn::tensor::{backend::Backend, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;

Expand All @@ -21,17 +21,13 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
vec![self.shape_lhs.dims.clone(), self.shape_rhs.dims.clone()]
}

fn num_samples(&self) -> usize {
10
}

fn execute(&self, (lhs, rhs): Self::Args) {
lhs.clone().matmul(rhs.clone());
lhs.matmul(rhs);
}

fn prepare(&self) -> Self::Args {
let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default, &self.device);
let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default, &self.device);
let lhs = Tensor::zeros(self.shape_lhs.clone(), &self.device);
let rhs = Tensor::zeros(self.shape_rhs.clone(), &self.device);

(lhs, rhs)
}
Expand All @@ -48,24 +44,22 @@ fn bench<B: Backend>(
url: Option<&str>,
token: Option<&str>,
) {
const D: usize = 3;
let batch_size = 8;
let m = 2048;
let k = 2048;
let n = 2048;
let shape_lhs = [batch_size, m, k].into();
let shape_rhs = [batch_size, k, n].into();
let benchmarks = [
(3, 4096, 4096, 4096),
(8, 2048, 2048, 2048),
(2, 4096, 4096, 512),
]
.into_iter()
.map(|(b, m, n, k)| {
let shape_lhs = [b, m, k].into();
let shape_rhs = [b, k, n].into();

let benchmark = MatmulBenchmark::<B, D>::new(shape_lhs, shape_rhs, device.clone());
MatmulBenchmark::<B, 3>::new(shape_lhs, shape_rhs, device.clone())
})
.map(run_benchmark)
.collect();

save::<B>(
vec![run_benchmark(benchmark)],
device,
feature_name,
url,
token,
)
.unwrap();
save::<B>(benchmarks, device, feature_name, url, token).unwrap();
}

fn main() {
Expand Down
27 changes: 27 additions & 0 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use std::error::Error;

use tracing_subscriber::filter::LevelFilter;

pub mod burnbenchapp;
pub mod persistence;

Expand Down Expand Up @@ -26,10 +30,33 @@ pub fn get_sharing_url(args: &[String]) -> Option<&str> {
get_argument(args, "--sharing-url")
}

/// Installs the global `tracing_subscriber` formatter used by the benchmarks.
///
/// The subscriber is configured with a maximum level of `DEBUG` and without
/// timestamps. When installation succeeds, the panic hook is also wrapped via
/// `update_panic_hook` so panics are reported through the `log` facade.
///
/// # Errors
///
/// Returns the error produced by `try_init` unchanged; in that case the panic
/// hook is left untouched.
pub fn init_log() -> Result<(), Box<dyn Error + Send + Sync>> {
    tracing_subscriber::fmt()
        .with_max_level(LevelFilter::DEBUG)
        .without_time()
        .try_init()?;

    // Only reached when the subscriber was installed successfully.
    update_panic_hook();
    Ok(())
}

fn update_panic_hook() {
let hook = std::panic::take_hook();

std::panic::set_hook(Box::new(move |info| {
log::error!("PANIC => {}", info.to_string());
hook(info);
}));
}

#[macro_export]
macro_rules! bench_on_backend {
() => {
use std::env;
backend_comparison::init_log().unwrap();

let args: Vec<String> = env::args().collect();
let url = backend_comparison::get_sharing_url(&args);
let token = backend_comparison::get_sharing_token(&args);
Expand Down
1 change: 1 addition & 0 deletions crates/burn-autodiff/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
type IntElem = B::IntElem;

type BoolTensorPrimitive = B::BoolTensorPrimitive;
type BoolElem = B::BoolElem;

type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
type QuantizedEncoding = B::QuantizedEncoding;
Expand Down
10 changes: 5 additions & 5 deletions crates/burn-autodiff/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,18 @@ macro_rules! testgen_all {

pub type FloatType = <TestBackend as burn_tensor::backend::Backend>::FloatElem;
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolTensorPrimitive;
pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolElem;

::paste::paste! {
$(mod [<$float _ty>] {
pub use super::*;

pub type TestBackend = TestBackend2<$float, IntType>;
pub type TestBackend = TestBackend2<$float, IntType, BoolType>;
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
pub type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
pub type TestTensor<const D: usize> = TestTensor2<$float, IntType, D>;
pub type TestTensorInt<const D: usize> = TestTensorInt2<$float, IntType, D>;
pub type TestTensorBool<const D: usize> = TestTensorBool2<$float, IntType, D>;
pub type TestTensor<const D: usize> = TestTensor2<$float, IntType, BoolType, D>;
pub type TestTensorInt<const D: usize> = TestTensorInt2<$float, IntType, BoolType, D>;
pub type TestTensorBool<const D: usize> = TestTensorBool2<$float, IntType, BoolType, D>;

type FloatType = $float;

Expand Down
1 change: 1 addition & 0 deletions crates/burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
type IntElem = I;

type BoolTensorPrimitive = CandleTensor;
type BoolElem = u32;

type QuantizedTensorPrimitive = CandleQTensor;
type QuantizedEncoding = u8;
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ blas-netlib = ["burn-ndarray?/blas-netlib"]
metal = ["burn-candle?/metal"]
openblas = ["burn-ndarray?/blas-openblas"]
openblas-system = ["burn-ndarray?/blas-openblas-system"]
template = ["burn-wgpu?/template"]
remote = ["burn-remote/client"]
router = ["burn-router"]
server = ["burn-remote/server"]
template = ["burn-wgpu?/template"]

candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
Expand Down Expand Up @@ -138,10 +138,10 @@ burn-candle = { path = "../burn-candle", version = "0.16.0", optional = true }
burn-cuda = { path = "../burn-cuda", version = "0.16.0", optional = true, default-features = false }
burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-features = false }
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false }
burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true }
burn-router = { path = "../burn-router", version = "0.16.0", default-features = false, optional = true }
burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false }

data-encoding = { workspace = true }
uuid = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions crates/burn-core/src/data/dataloader/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ where
}
}

/// Test batcher
#[cfg(test)]
#[derive(new, Clone)]
pub struct TestBatcher;
Expand Down
Loading

0 comments on commit 6d57c23

Please sign in to comment.