From 7cdd6b0e70dcb78237d18a92bf365bf75f7e63f9 Mon Sep 17 00:00:00 2001 From: Sylvain Benner Date: Thu, 28 Nov 2024 23:49:13 -0500 Subject: [PATCH 1/4] Fix xtask command with last version (#2566) --- Cargo.lock | 132 ++++++++++++++++----------------- xtask/src/commands/test.rs | 1 + xtask/src/commands/validate.rs | 2 + 3 files changed, 69 insertions(+), 66 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 381b15a512..ace556a7e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -364,7 +364,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.0", + "hyper 1.5.1", "hyper-util", "itoa", "matchit", @@ -378,7 +378,7 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sha1", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "tokio", "tokio-tungstenite", "tower", @@ -402,7 +402,7 @@ dependencies = [ "mime", "pin-project-lite", "rustversion", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "tower-layer", "tower-service", "tracing", @@ -542,9 +542,9 @@ checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" [[package]] name = "blake3" -version = "1.5.4" +version = "1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" dependencies = [ "arrayref", "arrayvec", @@ -1025,9 +1025,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "bytesize" @@ -1334,9 +1334,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.51" +version = "0.1.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" +checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e" dependencies = [ "cc", ] @@ -1527,9 +1527,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -2359,12 +2359,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3087,9 +3087,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" dependencies = [ "atomic-waker", "bytes", @@ -3373,14 +3373,14 @@ dependencies = [ [[package]] name = "hyper" -version = "1.5.0" +version = "1.5.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" +checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -3400,7 +3400,7 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.5.0", + "hyper 1.5.1", "hyper-util", "rustls", "rustls-native-certs 0.8.1", @@ -3431,7 +3431,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper 1.5.0", + "hyper 1.5.1", "hyper-util", "native-tls", "tokio", @@ -3450,7 +3450,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.5.0", + "hyper 1.5.1", "pin-project-lite", "socket2", "tokio", @@ -3825,9 +3825,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "itoap" @@ -3896,9 +3896,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.165" +version = "0.2.166" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb4d3d38eab6c5239a362fa8bae48c03baf980a6e7079f063942d563ef3533e" +checksum = "c2ccc108bbc0b1331bd061864e7cd823c0cab660bbe6970e66e2c0614decde36" [[package]] name = "libfuzzer-sys" @@ -3956,9 +3956,9 @@ checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "litemap" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" [[package]] name = "litrs" @@ -4737,7 +4737,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.5.0", + "hyper 1.5.1", "itertools 0.13.0", "md-5", "parking_lot 0.12.3", @@ -5648,9 +5648,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" [[package]] name = "portable-atomic-util" @@ -6390,11 +6390,11 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.0", + "hyper 1.5.1", "hyper-rustls", "hyper-tls 0.6.0", "hyper-util", @@ -6414,7 +6414,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "system-configuration 0.6.1", "tokio", "tokio-native-tls", @@ -6568,9 +6568,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.40" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -6581,9 +6581,9 @@ dependencies = [ 
[[package]] name = "rustls" -version = "0.23.18" +version = "0.23.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" +checksum = "934b404430bb06b3fae2cba809eb45a1ab1aecd64491213d7c3301b88393f8d1" dependencies = [ "log", "once_cell", @@ -6728,9 +6728,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -7106,9 +7106,9 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", @@ -7279,9 +7279,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] @@ -7648,9 +7648,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokenizers" -version = "0.20.3" +version = "0.20.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b67c92f6d705e2a1d106fb0b28c696f9074901a9c656ee5d9f5de204c39bf7" +checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905" dependencies = [ "aho-corasick", "derive_builder", @@ -7830,9 +7830,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracel-xtask" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4126466aafe1c518cb5c23979c286903cb1d1ff1bc3b76891254a243a0ed1e15" +checksum = "8aab3144d5269d5c34522f9064c096681e446913bec143beeb2d755f6f4e834f" dependencies = [ "anyhow", "clap 4.5.21", @@ -7849,9 +7849,9 @@ dependencies = [ [[package]] name = "tracel-xtask-macros" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35c68844637b0e748a78eaa0b37981ec7cf16016a5886d441233618c1b2e588a" +checksum = "33e0cd750c845b57d39ad8ee75080dbf8c3a18fbd0a07ef3c59339498766c1fb" dependencies = [ "proc-macro2", "quote", @@ -7860,9 +7860,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "log", "pin-project-lite", @@ -7884,9 +7884,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" 
dependencies = [ "proc-macro2", "quote", @@ -8006,9 +8006,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-normalization" @@ -8355,9 +8355,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.6" +version = "0.26.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" +checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" dependencies = [ "rustls-pki-types", ] @@ -8880,9 +8880,9 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yoke" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" dependencies = [ "serde", "stable_deref_trait", @@ -8892,9 +8892,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", @@ -8925,18 +8925,18 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index c88c352979..47e50f80ed 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -123,6 +123,7 @@ pub(crate) fn handle_command( jobs: args.jobs, ci: args.ci, features: args.features.clone(), + no_default_features: args.no_default_features, }, env, ) diff --git a/xtask/src/commands/validate.rs b/xtask/src/commands/validate.rs index dedb40828a..49943a538d 100644 --- a/xtask/src/commands/validate.rs +++ b/xtask/src/commands/validate.rs @@ -56,6 +56,7 @@ pub fn handle_command( command: Some(TestSubCommand::All), ci: true, features: None, + no_default_features: false, }, ExecutionEnvironment::Std, )?; @@ -103,6 +104,7 @@ pub fn handle_command( command: Some(TestSubCommand::All), ci: true, features: None, + no_default_features: false, }, ExecutionEnvironment::NoStd, )?; From 42e7c1f225a7c3e2be502e437fc417c5f61f2589 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 29 Nov 2024 15:21:13 +0100 Subject: [PATCH 2/4] [Feat] 8-bit bool for JitBackend (#2526) --- crates/burn-autodiff/src/backend.rs | 1 + crates/burn-autodiff/src/tests/mod.rs | 10 +- crates/burn-candle/src/backend.rs | 1 + 
crates/burn-cuda/src/lib.rs | 6 +- crates/burn-fusion/src/backend.rs | 6 +- crates/burn-hip/src/lib.rs | 4 +- crates/burn-jit/src/backend.rs | 24 +++- crates/burn-jit/src/element.rs | 24 ++++ crates/burn-jit/src/fusion/base.rs | 30 +++-- .../burn-jit/src/fusion/elemwise/builder.rs | 9 +- .../src/fusion/elemwise/optimization.rs | 6 +- .../burn-jit/src/fusion/on_write/builder.rs | 12 +- crates/burn-jit/src/fusion/on_write/trace.rs | 28 +++-- .../src/fusion/on_write/trace_builder.rs | 20 +-- crates/burn-jit/src/kernel/cast/bool_cast.rs | 16 +-- crates/burn-jit/src/kernel/comparison.rs | 119 +++++++++++------- .../burn-jit/src/kernel/conv/conv2d/base.rs | 18 ++- .../burn-jit/src/kernel/conv/conv2d/col2im.rs | 37 ++++-- .../burn-jit/src/kernel/conv/conv2d/direct.rs | 5 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 24 ++-- .../src/kernel/conv/conv2d/implicit_gemm.rs | 5 +- .../kernel/conv/conv2d/transpose_direct.rs | 5 +- .../src/kernel/conv/conv2d/tune/conv2d.rs | 14 +-- .../conv/conv2d/tune/conv_transpose2d.rs | 12 +- .../burn-jit/src/kernel/conv/deform_conv2d.rs | 25 ++-- .../kernel/conv/deform_conv_transpose2d.rs | 42 ++++--- crates/burn-jit/src/kernel/index/flip.rs | 16 +-- crates/burn-jit/src/kernel/mask/base.rs | 10 +- crates/burn-jit/src/kernel/mask/mask_fill.rs | 54 ++++---- crates/burn-jit/src/kernel/mask/mask_where.rs | 66 +++++----- crates/burn-jit/src/lib.rs | 2 +- crates/burn-jit/src/ops/activation_ops.rs | 5 +- crates/burn-jit/src/ops/base.rs | 13 +- crates/burn-jit/src/ops/bool_ops.rs | 29 ++--- crates/burn-jit/src/ops/float_ops.rs | 38 +++--- crates/burn-jit/src/ops/int_ops.rs | 34 ++--- crates/burn-jit/src/ops/module_ops.rs | 10 +- crates/burn-jit/src/ops/qtensor.rs | 4 +- crates/burn-jit/src/ops/transaction.rs | 9 +- crates/burn-jit/src/tensor/base.rs | 4 +- crates/burn-jit/src/tensor/qtensor.rs | 11 +- crates/burn-jit/src/tests/mask_fill.rs | 2 + crates/burn-jit/src/tests/mask_where.rs | 2 + crates/burn-jit/src/tests/mod.rs | 50 ++++---- crates/burn-ndarray/src/backend.rs | 1 + crates/burn-remote/src/client/channel.rs | 2 + crates/burn-router/src/backend.rs | 2 + crates/burn-router/src/channel/base.rs | 2 + crates/burn-router/src/types.rs | 1 + crates/burn-tch/src/backend.rs | 1 + crates/burn-tensor/src/tensor/backend/base.rs | 2 + crates/burn-tensor/src/tests/mod.rs | 57 +++++++-- crates/burn-tensor/src/tests/ops/remainder.rs | 2 +- crates/burn-wgpu/src/lib.rs | 17 ++- .../examples/custom-cubecl-kernel.rs | 2 +- examples/custom-cubecl-kernel/src/backward.rs | 6 +- examples/custom-cubecl-kernel/src/forward.rs | 7 +- .../examples/custom-wgpu-kernel.rs | 2 +- examples/custom-wgpu-kernel/src/backward.rs | 7 +- examples/custom-wgpu-kernel/src/forward.rs | 8 +- 60 files changed, 597 insertions(+), 384 deletions(-) diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs index e1f41e77a2..78bfe301f5 100644 --- a/crates/burn-autodiff/src/backend.rs +++ b/crates/burn-autodiff/src/backend.rs @@ -30,6 +30,7 @@ impl Backend for Autodiff { type IntElem = B::IntElem; type BoolTensorPrimitive = B::BoolTensorPrimitive; + type BoolElem = B::BoolElem; type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive; type QuantizedEncoding = B::QuantizedEncoding; diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs index 1980550157..438adfb98e 100644 --- a/crates/burn-autodiff/src/tests/mod.rs +++ b/crates/burn-autodiff/src/tests/mod.rs @@ -90,18 +90,18 @@ macro_rules! 
testgen_all { pub type FloatType = ::FloatElem; pub type IntType = ::IntElem; - pub type BoolType = ::BoolTensorPrimitive; + pub type BoolType = ::BoolElem; ::paste::paste! { $(mod [<$float _ty>] { pub use super::*; - pub type TestBackend = TestBackend2<$float, IntType>; + pub type TestBackend = TestBackend2<$float, IntType, BoolType>; pub type TestAutodiffBackend = burn_autodiff::Autodiff; pub type TestAutodiffTensor = burn_tensor::Tensor; - pub type TestTensor = TestTensor2<$float, IntType, D>; - pub type TestTensorInt = TestTensorInt2<$float, IntType, D>; - pub type TestTensorBool = TestTensorBool2<$float, IntType, D>; + pub type TestTensor = TestTensor2<$float, IntType, BoolType, D>; + pub type TestTensorInt = TestTensorInt2<$float, IntType, BoolType, D>; + pub type TestTensorBool = TestTensorBool2<$float, IntType, BoolType, D>; type FloatType = $float; diff --git a/crates/burn-candle/src/backend.rs b/crates/burn-candle/src/backend.rs index e03b26474c..1ad606b910 100644 --- a/crates/burn-candle/src/backend.rs +++ b/crates/burn-candle/src/backend.rs @@ -168,6 +168,7 @@ impl Backend for Candle { type IntElem = I; type BoolTensorPrimitive = CandleTensor; + type BoolElem = u32; type QuantizedTensorPrimitive = CandleQTensor; type QuantizedEncoding = u8; diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs index 086d00bab7..030c2d9ff1 100644 --- a/crates/burn-cuda/src/lib.rs +++ b/crates/burn-cuda/src/lib.rs @@ -7,10 +7,10 @@ pub use cubecl::cuda::CudaDevice; use cubecl::cuda::CudaRuntime; #[cfg(not(feature = "fusion"))] -pub type Cuda = JitBackend; +pub type Cuda = JitBackend; #[cfg(feature = "fusion")] -pub type Cuda = burn_fusion::Fusion>; +pub type Cuda = burn_fusion::Fusion>; #[cfg(test)] mod tests { @@ -19,5 +19,5 @@ mod tests { pub type TestRuntime = cubecl::cuda::CudaRuntime; pub use half::{bf16, f16}; - burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64]); + burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64], [u8, u32]); } diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index aa72ba7dbc..aa308ad9a7 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -5,7 +5,7 @@ use burn_tensor::{ backend::{Backend, DeviceOps}, ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle}, - Device, + Device, Element, }; use serde::{de::DeserializeOwned, Serialize}; use std::marker::PhantomData; @@ -35,6 +35,8 @@ impl Backend for Fusion { type BoolTensorPrimitive = FusionTensor; + type BoolElem = B::BoolElem; + type QuantizedTensorPrimitive = QFusionTensor; type QuantizedEncoding = B::QuantizedEncoding; @@ -142,6 +144,8 @@ pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug { type FusionDevice: DeviceOps; /// The client to interact with the runtime. type FusionClient: FusionClient; + /// The type that represents booleans on the backend. + type BoolRepr: Element; /// The list of optimizations that will be used to optimize the computational graph. 
fn optimizations( diff --git a/crates/burn-hip/src/lib.rs b/crates/burn-hip/src/lib.rs index 89b91243c6..fc8f704e74 100644 --- a/crates/burn-hip/src/lib.rs +++ b/crates/burn-hip/src/lib.rs @@ -12,11 +12,11 @@ use cubecl::hip::HipRuntime; #[cfg(target_os = "linux")] #[cfg(not(feature = "fusion"))] -pub type Hip = JitBackend; +pub type Hip = JitBackend; #[cfg(target_os = "linux")] #[cfg(feature = "fusion")] -pub type Hip = burn_fusion::Fusion>; +pub type Hip = burn_fusion::Fusion>; // TODO: Hang the computer when AMD isn't available. // diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 23629c4a9e..b455d859a0 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -1,4 +1,5 @@ use crate::{ + element::BoolElement, tensor::{JitTensor, QJitTensor}, FloatElement, IntElement, JitRuntime, }; @@ -18,24 +19,27 @@ pub(crate) static SEED: Mutex> = Mutex::new(None); /// Generic tensor backend that can be compiled just-in-time to any shader runtime #[derive(new)] -pub struct JitBackend { +pub struct JitBackend { _runtime: PhantomData, _float_elem: PhantomData, _int_elem: PhantomData, + _bool_elem: PhantomData, } -impl Backend for JitBackend +impl Backend for JitBackend where R: JitRuntime, R::Server: ComputeServer, R::Device: burn_tensor::backend::DeviceOps, F: FloatElement, I: IntElement, + BT: BoolElement, { type Device = R::Device; type FloatElem = F; type IntElem = I; + type BoolElem = BT; type FloatTensorPrimitive = JitTensor; type IntTensorPrimitive = JitTensor; @@ -63,19 +67,25 @@ where } } -impl core::fmt::Debug for JitBackend { +impl core::fmt::Debug + for JitBackend +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!("JitBackend {{ runtime: {}}}", R::name())) } } -impl Clone for JitBackend { +impl Clone + for JitBackend +{ fn clone(&self) -> Self { Self::new() } } -impl Default for JitBackend { +impl Default + for JitBackend +{ fn default() -> Self { Self::new() } @@ -90,7 +100,9 @@ where } #[cfg(not(feature = "fusion"))] -impl ReprBackend for JitBackend { +impl ReprBackend + for JitBackend +{ type Handle = HandleKind; fn float_tensor(handle: TensorHandle) -> FloatTensor { diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index 939b2fb24e..f0e15352cf 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -13,6 +13,27 @@ pub trait FloatElement: JitElement + Float {} /// The int element type for the jit backend. pub trait IntElement: JitElement + Int {} +/// The element type for booleans for the jit backend. +pub trait BoolElement: JitElement + Int { + /// The true value for the boolean element. + fn true_val() -> Self { + Self::from_int(1) + } + + /// The false value for the boolean element. + fn false_val() -> Self { + Self::from_int(0) + } + + /// New bool element from Rust bool. 
+ fn new_bool(val: bool) -> Self { + match val { + true => Self::true_val(), + false => Self::false_val(), + } + } +} + impl JitElement for u64 {} impl JitElement for u32 {} impl JitElement for u16 {} @@ -36,3 +57,6 @@ impl IntElement for i64 {} impl IntElement for i32 {} impl IntElement for i16 {} impl IntElement for i8 {} + +impl BoolElement for u8 {} +impl BoolElement for u32 {} diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 7968626e89..4572f580b5 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,6 +1,6 @@ use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState}; -use crate::fusion::elemwise::builder::ElementWiseBuilder; use crate::tensor::{JitQuantizationParameters, QJitTensor}; +use crate::{element::BoolElement, fusion::elemwise::builder::ElementWiseBuilder}; use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_tensor::quantization::QuantizationScheme; @@ -30,13 +30,14 @@ pub enum JitOptimizationState { ElementWise(ElemwiseOptimizationState), } -impl burn_fusion::Optimization> for JitOptimization +impl burn_fusion::Optimization> for JitOptimization where R: JitRuntime, + BT: BoolElement, { fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle>) { match self { - Self::ElementWise2(op) => op.execute(context), + Self::ElementWise2(op) => op.execute::(context), } } @@ -61,7 +62,9 @@ where } } -impl ReprBackend for JitBackend { +impl ReprBackend + for JitBackend +{ type Handle = JitFusionHandle; fn float_tensor(handle: TensorHandle) -> burn_tensor::ops::FloatTensor { @@ -122,30 +125,37 @@ impl ReprBackend for JitBackend FusionRuntime for FusionJitRuntime { +impl FusionRuntime for FusionJitRuntime { type OptimizationState = JitOptimizationState; type Optimization = JitOptimization; type FusionHandle = JitFusionHandle; type FusionDevice = R::JitDevice; type FusionClient = MutexFusionClient; + type BoolRepr = BT; fn optimizations( device: R::Device, ) -> Vec>> { - vec![Box::new(ElementWiseBuilder::::new(device.clone()))] + vec![Box::new(ElementWiseBuilder::::new( + device.clone(), + BT::as_elem().into(), + ))] } } /// Fusion runtime for JIT runtimes. 
#[derive(Debug)] -pub struct FusionJitRuntime { +pub struct FusionJitRuntime { _b: PhantomData, + _bool: PhantomData, } -impl FusionBackend for JitBackend { - type FusionRuntime = FusionJitRuntime; +impl FusionBackend + for JitBackend +{ + type FusionRuntime = FusionJitRuntime; - type FullPrecisionBackend = JitBackend; + type FullPrecisionBackend = JitBackend; fn cast_float( tensor: burn_tensor::ops::FloatTensor, diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 6766e3000a..e37196bc2a 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -1,7 +1,10 @@ use burn_fusion::OptimizationBuilder; use crate::{ - fusion::{on_write::builder::FuseOnWriteBuilder, JitOptimization}, + fusion::{ + on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision}, + JitOptimization, + }, JitRuntime, }; @@ -14,13 +17,13 @@ pub(crate) struct ElementWiseBuilder { } impl ElementWiseBuilder { - pub fn new(device: R::Device) -> Self { + pub fn new(device: R::Device, bool_precision: ElemwisePrecision) -> Self { let client = R::client(&device); let props = client.properties(); let max_bindings = props.hardware_properties().max_bindings; Self { - builder: FuseOnWriteBuilder::new(max_bindings), + builder: FuseOnWriteBuilder::new(max_bindings, bool_precision), device, } } diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index f5f3000926..d3e8e35b50 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -1,4 +1,4 @@ -use crate::fusion::on_write::kernel::fuse_on_write; +use crate::{fusion::on_write::kernel::fuse_on_write, BoolElement}; use crate::{fusion::JitFusionHandle, JitRuntime}; use burn_fusion::stream::Context; use burn_tensor::repr::TensorDescription; @@ -28,9 +28,9 @@ pub struct ElemwiseOptimizationState { impl ElemwiseOptimization { /// Execute the optimization. - pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { + pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { self.trace - .run::(&self.client, &self.device, context) + .run::(&self.client, &self.device, context) } /// Number of element wise operations fused. 
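The diffs above (crates/burn-jit/src/element.rs and backend.rs) introduce a `BoolElement` trait and thread a `BT: BoolElement` parameter through `JitBackend` and the fusion runtime so booleans can be stored as either u32 or u8. Below is an illustrative, self-contained sketch, not text from the patch: it mirrors the defaulted `true_val`/`false_val`/`new_bool` methods, with a simplified `from_int` standing in for the crate's `JitElement + Int` bounds, to show how Rust bools map onto the integer storage type.

// Sketch only: trimmed-down analogue of the BoolElement trait added in this patch.
trait BoolElement: Copy {
    // Stand-in for the integer conversion provided by the crate's Int bound.
    fn from_int(val: i64) -> Self;

    // The true value for the boolean element.
    fn true_val() -> Self {
        Self::from_int(1)
    }

    // The false value for the boolean element.
    fn false_val() -> Self {
        Self::from_int(0)
    }

    // New bool element from a Rust bool.
    fn new_bool(val: bool) -> Self {
        if val { Self::true_val() } else { Self::false_val() }
    }
}

impl BoolElement for u8 {
    fn from_int(val: i64) -> Self {
        val as u8
    }
}

impl BoolElement for u32 {
    fn from_int(val: i64) -> Self {
        val as u32
    }
}

fn main() {
    // The 8-bit encoding stores the same information in a quarter of the space
    // used by the previous u32 encoding.
    assert_eq!(<u8 as BoolElement>::new_bool(true), 1u8);
    assert_eq!(<u32 as BoolElement>::new_bool(false), 0u32);
}
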
diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 1bd167af90..287b656274 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -1,5 +1,5 @@ use super::{ - ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, UnaryElemwiseArgs}, + ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, UnaryElemwiseArgs}, trace::FuseOnWriteTrace, trace_builder::FuseOnWriteTraceBuilder, }; @@ -30,9 +30,9 @@ struct TryFuseBuilder { } impl TryFuseBuilder { - fn new(max_bindings: u32) -> Self { + fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self { Self { - builder: FuseOnWriteTraceBuilder::new(), + builder: FuseOnWriteTraceBuilder::new(bool_precision), max_bindings, added_ops: false, } @@ -118,7 +118,7 @@ impl OptimizationBuilder for FuseOnWriteBuilder { fn reset(&mut self) { self.num_ops = 0; self.status = OptimizationStatus::Open; - self.builder = TryFuseBuilder::new(self.max_bindings); + self.builder = TryFuseBuilder::new(self.max_bindings, self.builder.builder.bool_precision); self.current_output_shape.clear(); } @@ -137,9 +137,9 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } impl FuseOnWriteBuilder { - pub fn new(max_bindings: u32) -> Self { + pub fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self { Self { - builder: TryFuseBuilder::new(max_bindings), + builder: TryFuseBuilder::new(max_bindings, bool_precision), num_ops: 0, max_bindings, current_output_shape: Vec::new(), diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 591cc9c347..d9ec09aea8 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -1,6 +1,6 @@ use crate::{ fusion::{on_write::ir::LayoutInfo, strides_dyn_rank, JitFusionHandle}, - JitRuntime, + BoolElement, JitRuntime, }; use super::ir::{Arg, ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}; @@ -90,16 +90,17 @@ struct PotentialInplace<'a> { impl FuseOnWriteTrace { /// Run a trace with the given [runner](TraceRunner). 
- pub fn run>( + pub fn run>( &self, client: &ComputeClient, device: &R::Device, context: &mut Context<'_, JitFusionHandle>, ) { - let analysis = self.analyse::(client, device, context); + let analysis = self.analyse::(client, device, context); let inputs = self.register_inputs(context, &analysis.handle_inputs, analysis.vectorization); - let outputs = self.register_outputs(&analysis.handle_outputs, analysis.vectorization); + let outputs = + self.register_outputs::<_, BT>(&analysis.handle_outputs, analysis.vectorization); let mut ops = Sequence::new(); for op in analysis.reads.into_values() { @@ -126,7 +127,7 @@ impl FuseOnWriteTrace { Runner::run(client, inputs, outputs, config) } - fn analyse<'a, 'c, R: JitRuntime, Runner: TraceRunner>( + fn analyse<'a, 'c, R: JitRuntime, BT: BoolElement, Runner: TraceRunner>( &'a self, client: &ComputeClient, device: &R::Device, @@ -146,7 +147,7 @@ impl FuseOnWriteTrace { }; self.analyse_inputs(context, &mut analysis); - self.analyse_outputs(client, device, context, &mut analysis); + self.analyse_outputs::<_, BT>(client, device, context, &mut analysis); analysis.vectorization = Runner::vectorization( analysis.handle_inputs.iter().map(|item| &item.handle), @@ -189,7 +190,7 @@ impl FuseOnWriteTrace { } } - fn analyse_outputs<'a, 'c, R: JitRuntime>( + fn analyse_outputs<'a, 'c, R: JitRuntime, BT: BoolElement>( &'a self, client: &ComputeClient, device: &R::Device, @@ -273,9 +274,9 @@ impl FuseOnWriteTrace { } } - // We encode bool tensors as u32. + // We encode bool tensors as `B`. let dtype = match tensor_global.dtype { - DType::Bool => DType::U32, + DType::Bool => BT::dtype(), _ => tensor_global.dtype, }; let size = tensor_global.shape.iter().product::() * Elem::from(dtype).size(); @@ -406,7 +407,7 @@ impl FuseOnWriteTrace { inputs } - fn register_outputs<'s, R: JitRuntime>( + fn register_outputs<'s, R: JitRuntime, BT: BoolElement>( &self, handle_outputs: &'s [HandleOutput<'_, R>], vectorization: u8, @@ -473,8 +474,11 @@ impl FuseOnWriteTrace { ElemwisePrecision::U32 => outputs.t_u32.push(arg), ElemwisePrecision::U16 => outputs.t_u16.push(arg), ElemwisePrecision::U8 => outputs.t_u8.push(arg), - // Bools are encoded as u32. - ElemwisePrecision::Bool => outputs.t_u32.push(arg), + ElemwisePrecision::Bool => match BT::dtype() { + DType::U32 => outputs.t_u32.push(arg), + DType::U8 => outputs.t_u8.push(arg), + _ => todo!(), + }, }; } } diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index 06e8d24e15..5cb427814d 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -16,10 +16,11 @@ pub struct FuseOnWriteTraceBuilder { scalars: BTreeMap, ops: Vec, reads: BTreeMap, + pub bool_precision: ElemwisePrecision, } impl FuseOnWriteTraceBuilder { - pub fn new() -> Self { + pub fn new(bool_precision: ElemwisePrecision) -> Self { Self { locals: Locals::default(), outputs: RegisteredTensors::default(), @@ -27,6 +28,7 @@ impl FuseOnWriteTraceBuilder { scalars: BTreeMap::default(), ops: Vec::new(), reads: BTreeMap::new(), + bool_precision, } } @@ -49,9 +51,9 @@ impl FuseOnWriteTraceBuilder { pub fn input(&mut self, tensor: &TensorDescription) -> Arg { let precision = tensor.dtype.into(); - // Bool tensors are encoded as u32. + // Bool tensors are encoded as bool_precision. 
let precision_input = match precision { - ElemwisePrecision::Bool => ElemwisePrecision::U32, + ElemwisePrecision::Bool => self.bool_precision, _ => precision, }; @@ -82,9 +84,9 @@ impl FuseOnWriteTraceBuilder { pub fn output(&mut self, tensor: &TensorDescription) -> Arg { let precision = tensor.dtype.into(); - // Bool tensors are encoded as u32. + // Bool tensors are encoded as bool_precision. let precision_output = match precision { - ElemwisePrecision::Bool => ElemwisePrecision::U32, + ElemwisePrecision::Bool => self.bool_precision, _ => precision, }; @@ -103,9 +105,9 @@ impl FuseOnWriteTraceBuilder { pub fn scalar(&mut self, _: &E, dtype: DType) -> Arg { let precision = dtype.into(); - // Bool scalars are encoded as u32. + // Bool scalars are encoded as bool_precision. let precision = match precision { - ElemwisePrecision::Bool => ElemwisePrecision::U32, + ElemwisePrecision::Bool => self.bool_precision, _ => precision, }; let new_index = self.scalars.get(&precision).copied().unwrap_or(0); @@ -154,9 +156,9 @@ impl FuseOnWriteTraceBuilder { let mark = |var: &Arg, list: &mut Vec<(TensorId, ElemwisePrecision)>| { if let Arg::Local(index, precision) = var { if let Some(tensor_id) = self.locals.find_tensor_id(*precision, *index) { - // Input and outputs tensors are using u32 for booleans. + // Input and outputs tensors are using bool_precision for booleans. let precision = match precision { - ElemwisePrecision::Bool => ElemwisePrecision::U32, + ElemwisePrecision::Bool => self.bool_precision, _ => *precision, }; diff --git a/crates/burn-jit/src/kernel/cast/bool_cast.rs b/crates/burn-jit/src/kernel/cast/bool_cast.rs index 07a915ee1f..74e55888e1 100644 --- a/crates/burn-jit/src/kernel/cast/bool_cast.rs +++ b/crates/burn-jit/src/kernel/cast/bool_cast.rs @@ -1,9 +1,9 @@ -use crate::{tensor::JitTensor, JitElement, JitRuntime}; +use crate::{tensor::JitTensor, BoolElement, JitElement, JitRuntime}; use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim}; #[cube(launch)] -fn bool_cast_kernel(input: &Tensor, output: &mut Tensor) { - if input[ABSOLUTE_POS] >= 1 { +fn bool_cast_kernel(input: &Tensor, output: &mut Tensor) { + if input[ABSOLUTE_POS] >= B::from_int(1) { output[ABSOLUTE_POS] = T::from_int(1); } else { output[ABSOLUTE_POS] = T::from_int(0); @@ -12,11 +12,13 @@ fn bool_cast_kernel(input: &Tensor, output: &mut Tensor) { /// Cast a bool tensor to the given element type. /// -/// This alternative to cast is necessary because bool are represented as u32 +/// This alternative to cast is necessary because bool are represented as u32 or u8 /// where any non-zero value means true. Depending how it was created /// it may hold an uncanny bit combination. Naively casting it would not /// necessarily yield 0 or 1. 
-pub fn bool_cast(tensor: JitTensor) -> JitTensor { +pub fn bool_cast( + tensor: JitTensor, +) -> JitTensor { let num_elems = tensor.shape.num_elements(); let buffer = tensor.client.empty(num_elems * core::mem::size_of::()); let output = JitTensor::new_contiguous( @@ -30,11 +32,11 @@ pub fn bool_cast(tensor: JitTensor) -> JitTens let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim); - bool_cast_kernel::launch::( + bool_cast_kernel::launch::( &tensor.client, cube_count, cube_dim, - tensor.as_tensor_arg::(1), + tensor.as_tensor_arg::(1), output.as_tensor_arg::(1), ); diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index 420a74d81b..007d4200d9 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -1,5 +1,7 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; -use burn_tensor::{DType, Shape}; +use crate::{ + element::JitElement, ops::numeric::empty_device, tensor::JitTensor, BoolElement, JitRuntime, +}; +use burn_tensor::Shape; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, tensor_vectorization_factor, @@ -55,10 +57,10 @@ impl ComparisonOp for LowerOp { } #[cube(launch)] -pub(crate) fn kernel_scalar_cmp>( +pub(crate) fn kernel_scalar_cmp>( input: &Tensor>, scalar: C, - output: &mut Tensor>, + output: &mut Tensor>, ) { let offset_output = ABSOLUTE_POS; @@ -70,10 +72,10 @@ pub(crate) fn kernel_scalar_cmp>( } #[cube(launch)] -pub(crate) fn kernel_cmp>( +pub(crate) fn kernel_cmp>( lhs: &Tensor>, rhs: &Tensor>, - out: &mut Tensor>, + out: &mut Tensor>, #[comptime] rank: Option, #[comptime] to_contiguous_lhs: bool, #[comptime] to_contiguous_rhs: bool, @@ -87,7 +89,7 @@ pub(crate) fn kernel_cmp>( } if to_contiguous_lhs { - offset_lhs = index_offset_with_layout::( + offset_lhs = index_offset_with_layout::( lhs, out, offset_out, @@ -98,7 +100,7 @@ pub(crate) fn kernel_cmp>( } if to_contiguous_rhs { - offset_rhs = index_offset_with_layout::( + offset_rhs = index_offset_with_layout::( rhs, out, offset_out, @@ -111,7 +113,7 @@ pub(crate) fn kernel_cmp>( out[offset_out] = Line::cast_from(O::execute(lhs[offset_lhs], rhs[offset_rhs])); } -pub(crate) fn launch_cmp>( +pub(crate) fn launch_cmp>( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { @@ -141,9 +143,9 @@ pub(crate) fn launch_cmp>( let cube_count = calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); + let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); if same_tensor_type && lhs.can_mut_broadcast(&rhs) { - kernel_cmp::launch::( + kernel_cmp::launch::( &client, cube_count, cube_dim, @@ -161,10 +163,10 @@ pub(crate) fn launch_cmp>( lhs.shape, lhs.device, lhs.strides, - DType::U32, + BT::dtype(), ) } else if same_tensor_type && rhs.can_mut_broadcast(&lhs) { - kernel_cmp::launch::( + kernel_cmp::launch::( &client, cube_count, CubeDim::default(), @@ -182,20 +184,20 @@ pub(crate) fn launch_cmp>( rhs.shape, rhs.device, rhs.strides, - DType::U32, + BT::dtype(), ) } else { - let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != 
output.shape; - kernel_cmp::launch::( + kernel_cmp::launch::( &client, cube_count, CubeDim::default(), lhs.as_tensor_arg::(vectorization_factor), rhs.as_tensor_arg::(vectorization_factor), - output.as_tensor_arg::(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), None, to_contiguous_lhs, to_contiguous_rhs, @@ -205,7 +207,12 @@ pub(crate) fn launch_cmp>( } } -pub(crate) fn launch_scalar_cmp>( +pub(crate) fn launch_scalar_cmp< + R: JitRuntime, + E: JitElement, + BT: BoolElement, + O: ComparisonOp, +>( mut tensor: JitTensor, scalar: E, ) -> JitTensor { @@ -224,9 +231,9 @@ pub(crate) fn launch_scalar_cmp let cube_count = calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); + let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); if same_tensor_type && tensor.can_mut() { - kernel_scalar_cmp::launch::( + kernel_scalar_cmp::launch::( &client, cube_count, cube_dim, @@ -241,70 +248,94 @@ pub(crate) fn launch_scalar_cmp tensor.shape, tensor.device, tensor.strides, - DType::U32, + BT::dtype(), ) } else { - let output = empty_device::( + let output = empty_device::( tensor.client.clone(), tensor.device.clone(), tensor.shape.clone(), ); - kernel_scalar_cmp::launch::( + kernel_scalar_cmp::launch::( &client, cube_count, CubeDim::default(), tensor.as_tensor_arg::(vectorization_factor), ScalarArg::new(scalar), - output.as_tensor_arg::(vectorization_factor), + output.as_tensor_arg::(vectorization_factor), ); output } } -pub fn equal(lhs: JitTensor, rhs: JitTensor) -> JitTensor { - launch_cmp::(lhs, rhs) +pub fn equal( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_cmp::(lhs, rhs) } -pub fn greater(lhs: JitTensor, rhs: JitTensor) -> JitTensor { - launch_cmp::(lhs, rhs) +pub fn greater( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_cmp::(lhs, rhs) } -pub fn greater_equal( +pub fn greater_equal( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { - launch_cmp::(lhs, rhs) + launch_cmp::(lhs, rhs) } -pub fn lower(lhs: JitTensor, rhs: JitTensor) -> JitTensor { - launch_cmp::(lhs, rhs) +pub fn lower( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_cmp::(lhs, rhs) } -pub fn lower_equal( +pub fn lower_equal( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { - launch_cmp::(lhs, rhs) + launch_cmp::(lhs, rhs) } -pub fn equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn equal_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } -pub fn greater_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn greater_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } -pub fn lower_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn lower_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } -pub fn greater_equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn greater_equal_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } -pub fn lower_equal_elem(lhs: JitTensor, rhs: E) -> JitTensor { - launch_scalar_cmp::(lhs, rhs) +pub fn lower_equal_elem( + lhs: JitTensor, + rhs: E, +) -> JitTensor { + launch_scalar_cmp::(lhs, rhs) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/base.rs index 1796389157..9f07d36c55 100644 --- 
a/crates/burn-jit/src/kernel/conv/conv2d/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/base.rs @@ -69,7 +69,7 @@ impl Default for ConvTranspose2dStrategy { /// * `options` - The options to use for the convolution /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. /// -pub fn conv2d( +pub fn conv2d( input: JitTensor, weight: JitTensor, bias: Option>, @@ -77,13 +77,11 @@ pub fn conv2d( strategy: Conv2dStrategy, ) -> JitTensor { match strategy { - Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), + Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), #[cfg(feature = "autotune")] - Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), - Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), - Conv2dStrategy::ImplicitGemm => { - conv2d_implicit_gemm::(input, weight, bias, options) - } + Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), + Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), + Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::(input, weight, bias, options), } } @@ -104,14 +102,14 @@ pub fn conv_transpose2d( ) -> JitTensor { match strategy { ConvTranspose2dStrategy::Direct => { - conv_transpose2d_direct::(input, weight, bias, options) + conv_transpose2d_direct::(input, weight, bias, options) } #[cfg(feature = "autotune")] ConvTranspose2dStrategy::Autotune => { - conv_transpose2d_autotune::(input, weight, bias, options) + conv_transpose2d_autotune::(input, weight, bias, options) } ConvTranspose2dStrategy::Gemm => { - conv_transpose2d_col2im::(input, weight, bias, options) + conv_transpose2d_col2im::(input, weight, bias, options) } } } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 846aa3d8dd..0659561805 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -1,14 +1,18 @@ use burn_tensor::{ - ops::{conv::calculate_conv_transpose_output_size, ConvTransposeOptions, FloatTensorOps as _}, + ops::{conv::calculate_conv_transpose_output_size, ConvTransposeOptions}, Shape, }; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ - kernel::into_contiguous, + kernel::{ + into_contiguous, + matmul::{matmul, MatmulStrategy}, + slice, + }, ops::{numeric::empty_device, reshape, swap_dims}, tensor::JitTensor, - FloatElement, IntElement, JitBackend, JitRuntime, + FloatElement, JitElement, JitRuntime, }; use super::batches_per_run; @@ -20,7 +24,7 @@ use super::batches_per_run; /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -pub fn conv_transpose2d_col2im( +pub fn conv_transpose2d_col2im( input: JitTensor, weight: JitTensor, bias: Option>, @@ -77,12 +81,12 @@ pub fn conv_transpose2d_col2im( let input_shape_run = Shape::new([batches_per_run, input_channels, input_h, input_w]); for run in 0..runs { - let input = JitBackend::::float_narrow(input.clone(), 0, run, 1); + let input = index::(input.clone(), run); let input = reshape(input, input_shape_run.clone()); let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); - let image_slice = JitBackend::::float_narrow(image.clone(), 0, run, 1); + let image_slice = index::(image.clone(), run); let image_slice = reshape(image_slice, im_shape); - execute::( + execute::( input, weight.clone(), bias.clone(), @@ -96,7 +100,7 @@ pub fn 
conv_transpose2d_col2im( } else { let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); let image = empty_device::(input.client.clone(), input.device.clone(), im_shape); - execute::( + execute::( input, weight, bias, @@ -109,8 +113,21 @@ pub fn conv_transpose2d_col2im( } } +pub(crate) fn index(tensor: JitTensor, i: usize) -> JitTensor { + #[allow(clippy::single_range_in_vec_init)] + let mut indices = vec![i..i + 1]; + for dim in tensor.shape.dims[1..].iter() { + indices.push(0..*dim); + } + let new_shape = Shape { + dims: tensor.shape.dims[1..].to_vec(), + }; + let tensor = slice::(tensor, &indices); + reshape(tensor, new_shape) +} + #[allow(clippy::too_many_arguments)] -fn execute( +fn execute( input: JitTensor, weight: JitTensor, bias: Option>, @@ -128,7 +145,7 @@ fn execute( let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]); let input = reshape(input, input_shape); - let columns = JitBackend::::float_matmul(weight, input); + let columns = matmul::(weight, input, MatmulStrategy::default()); let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); col2im::( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index 9a65b6ae51..d5154ecc4b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -11,7 +11,7 @@ use crate::{ reshape, }, tensor::JitTensor, - FloatElement, IntElement, JitRuntime, + FloatElement, JitRuntime, }; #[derive(CubeLaunch)] @@ -120,8 +120,7 @@ fn direct_conv2d_kernel( /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -#[allow(clippy::extra_unused_type_parameters)] -pub fn conv2d_direct( +pub fn conv2d_direct( input: JitTensor, weight: JitTensor, bias: Option>, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 88125f0463..abcb8488fb 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -1,14 +1,16 @@ use burn_tensor::{ - ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps as _}, + ops::{conv::calculate_conv_output_size, ConvOptions}, Shape, }; use cubecl::{calculate_cube_count_elemwise, linalg::matmul, prelude::*}; use crate::{ - kernel::into_contiguous, + kernel::{ + conv::index, into_contiguous, launch_binop, matmul::matmul, matmul::MatmulStrategy, AddOp, + }, ops::{numeric::empty_device, reshape, swap_dims}, tensor::JitTensor, - FloatElement, IntElement, JitBackend, JitRuntime, + FloatElement, JitRuntime, }; #[derive(CubeLaunch)] @@ -178,7 +180,7 @@ fn im2col( /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -pub fn conv2d_im2col( +pub fn conv2d_im2col( input: JitTensor, weight: JitTensor, bias: Option>, @@ -206,7 +208,7 @@ pub fn conv2d_im2col( if kernel_h == 1 && kernel_w == 1 && in_height == out_h && in_width == out_w { // Special case for 1x1 kernels (sometimes used to scale the image by a set of weights) - return execute_1x1_kernel::(input, weight, bias, options); + return execute_1x1_kernel::(input, weight, bias, options); } let batches_per_run = batches_per_run(batch_size, out_h, out_w) @@ -221,9 +223,9 @@ pub fn conv2d_im2col( let input = reshape(input, in_shape); let in_shape_run = Shape::new([batches_per_run, in_channels, in_height, in_width]); for run in 0..runs { - let input = 
JitBackend::::float_narrow(input.clone(), 0, run, 1); + let input = index::(input.clone(), run); let input = reshape(input, in_shape_run.clone()); - let out_slice = JitBackend::::float_narrow(out.clone(), 0, run, 1); + let out_slice = index::(out.clone(), run); let out_slice = reshape(out_slice, matmul_shape.clone()); execute::( input, @@ -245,12 +247,12 @@ pub fn conv2d_im2col( if let Some(bias) = bias { let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); - out = JitBackend::::float_add(out, bias) + out = launch_binop::(out, bias) } out } -fn execute_1x1_kernel( +fn execute_1x1_kernel( input: JitTensor, weight: JitTensor, bias: Option>, @@ -266,12 +268,12 @@ fn execute_1x1_kernel( let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp])); let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]); let input = reshape(input, in_shape); - let out = JitBackend::::float_matmul(weight, input); + let out = matmul::(weight, input, MatmulStrategy::default()); let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width])); if let Some(bias) = bias { let bias = reshape(bias, Shape::new([out_channels, 1, 1, 1])); - out = JitBackend::::float_add(out, bias) + out = launch_binop::(out, bias) } swap_dims(out, 0, 1) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 49a639ef43..6771f2c5e2 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -18,7 +18,7 @@ use crate::{ permute, }, tensor::JitTensor, - FloatElement, IntElement, JitRuntime, + FloatElement, JitRuntime, }; use super::nchw_to_nhwc; @@ -30,8 +30,7 @@ use super::nchw_to_nhwc; /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -#[allow(clippy::extra_unused_type_parameters)] -pub fn conv2d_implicit_gemm( +pub fn conv2d_implicit_gemm( input: JitTensor, weight: JitTensor, bias: Option>, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index 1062241d75..6a97ab8759 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -8,7 +8,7 @@ use crate::{ reshape, }, tensor::JitTensor, - IntElement, JitRuntime, + JitRuntime, }; use burn_tensor::{ops::ConvTransposeOptions, Shape}; @@ -121,8 +121,7 @@ fn conv_transpose2d_direct_kernel( /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// -#[allow(clippy::extra_unused_type_parameters)] -pub fn conv_transpose2d_direct( +pub fn conv_transpose2d_direct( input: JitTensor, weight: JitTensor, bias: Option>, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index 05ec7fd960..4a8122a478 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -16,13 +16,13 @@ use crate::{ prng::random_uniform, }, tensor::JitTensor, - FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId, + FloatElement, JitAutotuneKey, JitRuntime, JitTuneId, }; use super::Conv2dAutotuneKey; /// Executes autotune on conv2d operations -pub fn conv2d_autotune( +pub fn conv2d_autotune( input: JitTensor, weights: JitTensor, bias: Option>, @@ -35,9 +35,7 @@ pub fn conv2d_autotune( TUNER.execute( 
&JitTuneId::new::(&input.device), &client, - Box::new(Conv2dOperations::::new( - input, weights, bias, options, - )), + Box::new(Conv2dOperations::::new(input, weights, bias, options)), ) } @@ -46,7 +44,7 @@ pub fn conv2d_autotune( create_key = create_key::, should_run = should_run )] -pub fn conv2d_operations( +pub fn conv2d_operations( key: JitAutotuneKey, input: JitTensor, weights: JitTensor, @@ -74,8 +72,8 @@ pub fn conv2d_operations( tune_with!(input, weights, bias, options) } -fn should_run( - op: &Conv2dOperations, +fn should_run( + op: &Conv2dOperations, key: &JitAutotuneKey, index: usize, ) -> bool { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs index 3a8c1d04f2..c2d546151a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs @@ -10,13 +10,13 @@ use crate::{ prng::random_uniform, }, tensor::JitTensor, - FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId, + FloatElement, JitAutotuneKey, JitRuntime, JitTuneId, }; use super::ConvTranspose2dAutotuneKey; /// Executes autotune on conv2d operations -pub fn conv_transpose2d_autotune( +pub fn conv_transpose2d_autotune( input: JitTensor, weights: JitTensor, bias: Option>, @@ -29,14 +29,14 @@ pub fn conv_transpose2d_autotune( TUNER.execute( &JitTuneId::new::(&input.device), &client, - Box::new(ConvTranspose2dOperations::::new( + Box::new(ConvTranspose2dOperations::::new( input, weights, bias, options, )), ) } #[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key::, should_run = should_run)] -pub fn conv_transpose2d_operations( +pub fn conv_transpose2d_operations( key: JitAutotuneKey, input: JitTensor, weights: JitTensor, @@ -95,8 +95,8 @@ fn create_key( )) } -fn should_run( - _op: &ConvTranspose2dOperations, +fn should_run( + _op: &ConvTranspose2dOperations, key: &JitAutotuneKey, index: usize, ) -> bool { diff --git a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs index b005a2384c..438850fe72 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs @@ -1,18 +1,22 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use burn_tensor::{ - ops::{conv::calculate_conv_output_size, DeformConvOptions, FloatTensorOps as _}, + ops::{conv::calculate_conv_output_size, DeformConvOptions}, Shape, }; use crate::{ - kernel::into_contiguous, + kernel::{ + into_contiguous, launch_binop, + matmul::{matmul, MatmulStrategy}, + AddOp, + }, ops::{ numeric::{ones_device, zeros_device}, reshape, swap_dims, }, tensor::JitTensor, - FloatElement, IntElement, JitBackend, JitRuntime, + FloatElement, JitRuntime, }; #[derive(CubeLaunch)] @@ -251,7 +255,7 @@ pub(crate) fn deform_im2col( output } -pub(crate) fn deform_conv2d( +pub(crate) fn deform_conv2d( input: JitTensor, offset: JitTensor, weight: JitTensor, @@ -294,24 +298,15 @@ pub(crate) fn deform_conv2d( let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0])); let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); - let out = JitBackend::::float_matmul(weight, columns); + let out = matmul::(weight, columns, MatmulStrategy::default()); let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); let out = swap_dims(out, 0, 1); if let Some(bias) = bias { let bias = reshape(bias, Shape::new([1, 
out_channels, 1, 1])); - JitBackend::::float_add(out, bias) + launch_binop::(out, bias) } else { out } } - -pub(crate) fn index( - tensor: JitTensor, - index: usize, -) -> JitTensor { - let [_, shape_0, shape_1] = tensor.shape.dims(); - let tensor = JitBackend::::float_narrow(tensor, 0, index, 1); - reshape(tensor, Shape::new([shape_0, shape_1])) -} diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index 4022a0bbe2..907b5ef344 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -5,7 +5,12 @@ use burn_tensor::{ use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch}; use crate::{ - kernel::{cast, into_contiguous}, + element::BoolElement, + kernel::{ + cast, into_contiguous, + matmul::{matmul, MatmulStrategy}, + slice_assign, + }, ops::{ numeric::{empty_device, ones_device, zeros_device}, reshape, swap_dims, @@ -18,7 +23,12 @@ use super::{bilinear_interpolate, deform_im2col, index}; /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. #[allow(clippy::single_range_in_vec_init)] -pub(crate) fn deform_conv2d_backward( +pub(crate) fn deform_conv2d_backward< + R: JitRuntime, + E: FloatElement, + I: IntElement, + BT: BoolElement, +>( input: JitTensor, offset: JitTensor, weight: JitTensor, @@ -26,14 +36,14 @@ pub(crate) fn deform_conv2d_backward>, out_grad: JitTensor, options: DeformConvOptions<2>, -) -> DeformConv2dBackward> { +) -> DeformConv2dBackward> { let [_, _, out_h, out_w] = out_grad.shape.dims(); let [_, _, kernel_h, kernel_w] = weight.shape.dims(); let gradient_bias = bias.map(|bias| { - let grad = JitBackend::::float_sum_dim(out_grad.clone(), 0); - let grad = JitBackend::::float_sum_dim(grad, 2); - let grad = JitBackend::::float_sum_dim(grad, 3); + let grad = JitBackend::::float_sum_dim(out_grad.clone(), 0); + let grad = JitBackend::::float_sum_dim(grad, 2); + let grad = JitBackend::::float_sum_dim(grad, 3); reshape(grad, bias.shape) }); @@ -42,7 +52,7 @@ pub(crate) fn deform_conv2d_backward( + let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs::( input.clone(), weight.clone(), offset.clone(), @@ -52,7 +62,7 @@ pub(crate) fn deform_conv2d_backward( + let weight_grad = compute_weight_grad::( input, offset, mask, @@ -71,7 +81,7 @@ pub(crate) fn deform_conv2d_backward( +fn compute_weight_grad( input: JitTensor, offset: JitTensor, mask: Option>, @@ -98,9 +108,9 @@ fn compute_weight_grad( let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); let columns = swap_dims(columns, 1, 2); - let grad_weight = JitBackend::::float_matmul(out_grad, columns); + let grad_weight = matmul::(out_grad, columns, MatmulStrategy::default()); - JitBackend::::float_reshape( + reshape( grad_weight, Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]), ) @@ -108,7 +118,7 @@ fn compute_weight_grad( type InputGradients = (JitTensor, JitTensor, Option>); -fn backward_gradient_inputs( +fn backward_gradient_inputs( image: JitTensor, weight: JitTensor, offset: JitTensor, @@ -138,11 +148,11 @@ fn backward_gradient_inputs( let out_grad = reshape(out_grad, out_grad_shape); for group in 0..groups { - let weight = swap_dims(index::(weight.clone(), group), 0, 1); - let out_grad = index::(out_grad.clone(), group); - let values = JitBackend::::float_matmul(weight, out_grad); + let weight = 
swap_dims(index::(weight.clone(), group), 0, 1); + let out_grad = index::(out_grad.clone(), group); + let values = matmul::(weight, out_grad, MatmulStrategy::default()); let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1])); - columns = JitBackend::::float_slice_assign( + columns = slice_assign::( columns, &[group..group + 1, 0..col_shape_0, 0..col_shape_1], values, diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index e35cac8b2c..583e0346d3 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -1,4 +1,6 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{ + element::JitElement, ops::numeric::empty_device, tensor::JitTensor, BoolElement, JitRuntime, +}; use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked)] @@ -31,7 +33,7 @@ fn flip_kernel( output[ABSOLUTE_POS] = input[offset_input]; } -pub(crate) fn flip( +pub(crate) fn flip( tensor: JitTensor, indices: &[usize], ) -> JitTensor { @@ -40,26 +42,26 @@ pub(crate) fn flip( tensor.device.clone(), tensor.shape.clone(), ); - flip_on_output::(tensor, output, indices) + flip_on_output::(tensor, output, indices) } -pub(crate) fn flip_on_output( +pub(crate) fn flip_on_output( tensor: JitTensor, output: JitTensor, indices: &[usize], ) -> JitTensor { let ndims = tensor.shape.num_dims(); - let mut indices_sequence = SequenceArg::<'_, R, u32>::new(); + let mut indices_sequence = SequenceArg::<'_, R, BT>::new(); for i in 0..ndims { - indices_sequence.push(ScalarArg::new(indices.contains(&i) as u32)); + indices_sequence.push(ScalarArg::new(BT::new_bool(indices.contains(&i)))); } let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); unsafe { - flip_kernel::launch_unchecked::( + flip_kernel::launch_unchecked::( &tensor.client, cube_count, cube_dim, diff --git a/crates/burn-jit/src/kernel/mask/base.rs b/crates/burn-jit/src/kernel/mask/base.rs index 2140972326..d37c6e05bb 100644 --- a/crates/burn-jit/src/kernel/mask/base.rs +++ b/crates/burn-jit/src/kernel/mask/base.rs @@ -1,8 +1,8 @@ use super::{mask_where::MaskWhereStrategy, MaskFillStrategy}; -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use crate::{element::JitElement, tensor::JitTensor, BoolElement, JitRuntime}; /// Execute the mask fill kernel. -pub(crate) fn mask_fill_auto( +pub(crate) fn mask_fill_auto( tensor: JitTensor, mask: JitTensor, value: E, @@ -13,11 +13,11 @@ pub(crate) fn mask_fill_auto( MaskFillStrategy::Readonly }; - super::mask_fill(tensor, mask, value, strategy) + super::mask_fill::(tensor, mask, value, strategy) } /// Execute the mask where kernel. 
-pub(crate) fn mask_where_auto( +pub(crate) fn mask_where_auto( tensor: JitTensor, mask: JitTensor, value: JitTensor, @@ -30,5 +30,5 @@ pub(crate) fn mask_where_auto( MaskWhereStrategy::Readonly }; - super::mask_where::(tensor, mask, value, strategy) + super::mask_where::(tensor, mask, value, strategy) } diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index e8b3f814d9..386e7a5039 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -1,11 +1,16 @@ use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*}; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{ + element::JitElement, + ops::{max_vectorization, numeric::empty_device}, + tensor::JitTensor, + BoolElement, JitRuntime, +}; #[cube(launch)] -fn mask_fill_readonly_kernel( +fn mask_fill_readonly_kernel( input: &Tensor>, - mask: &Tensor>, + mask: &Tensor>, output: &mut Tensor>, value: T, #[comptime] rank: u32, @@ -17,17 +22,15 @@ fn mask_fill_readonly_kernel( let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); let index_mask = index_offset_with_layout(mask, output, ABSOLUTE_POS, 0, rank, true); - if mask[index_mask] >= Line::new(1) { - output[ABSOLUTE_POS] = Line::new(value); - } else { - output[ABSOLUTE_POS] = input[index_input]; - } + let mask = Line::cast_from(mask[index_mask]); + + output[ABSOLUTE_POS] = select_many(mask, Line::new(value), input[index_input]); } #[cube(launch)] -fn mask_fill_inplace_kernel( +fn mask_fill_inplace_kernel( input: &mut Tensor>, - mask: &Tensor>, + mask: &Tensor>, value: T, #[comptime] rank: u32, ) { @@ -36,10 +39,9 @@ fn mask_fill_inplace_kernel( } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); + let mask = Line::cast_from(mask[index_mask]); - if mask[index_mask] >= Line::new(1) { - input[ABSOLUTE_POS] = Line::new(value); - } + input[ABSOLUTE_POS] = select_many(mask, Line::new(value), input[ABSOLUTE_POS]); } #[derive(Clone, Copy, Debug)] @@ -56,19 +58,19 @@ pub enum MaskFillStrategy { } /// Execute the mask fill kernel with the given strategy. 
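As a rough illustration of the branchless masking these hunks switch to (replacing the per-element `if mask >= 1` branch with a vectorized select), a plain-Rust scalar analogue is sketched below. It is not the CubeCL kernel itself; the helper names are illustrative stand-ins, and the real kernels operate on `Line` vectors on the GPU.

```rust
// Illustrative only: scalar analogue of the `select_many` pattern the rewritten
// mask kernels use. Names here are hypothetical, not the CubeCL API.
fn select<T: Copy>(condition: bool, when_true: T, when_false: T) -> T {
    if condition {
        when_true
    } else {
        when_false
    }
}

/// Write `value` into `input[i]` wherever `mask[i]` is non-zero, without an
/// explicit per-element branch in the hot loop body.
fn mask_fill_cpu(input: &mut [f32], mask: &[u8], value: f32) {
    for (x, &m) in input.iter_mut().zip(mask) {
        *x = select(m != 0, value, *x);
    }
}

fn main() {
    let mut data = vec![1.0, 2.0, 3.0, 4.0];
    let mask = vec![0u8, 1, 0, 1];
    mask_fill_cpu(&mut data, &mask, 9.0);
    assert_eq!(data, vec![1.0, 9.0, 3.0, 9.0]);
}
```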
-pub fn mask_fill( +pub fn mask_fill( input: JitTensor, mask: JitTensor, value: E, strategy: MaskFillStrategy, ) -> JitTensor { match strategy { - MaskFillStrategy::Readonly => mask_fill_readonly::(input, mask, value), - MaskFillStrategy::Inplace => mask_fill_inplace::(input, mask, value), + MaskFillStrategy::Readonly => mask_fill_readonly::(input, mask, value), + MaskFillStrategy::Inplace => mask_fill_inplace::(input, mask, value), } } -fn mask_fill_readonly( +fn mask_fill_readonly( input: JitTensor, mask: JitTensor, value: EI, @@ -82,14 +84,15 @@ fn mask_fill_readonly( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); + let vectorization = max_vectorization(&input); - mask_fill_readonly_kernel::launch::( + mask_fill_readonly_kernel::launch::( &input.client, cube_count, cube_dim, - input.as_tensor_arg::(1), - mask.as_tensor_arg::(1), - output.as_tensor_arg::(1), + input.as_tensor_arg::(vectorization), + mask.as_tensor_arg::(vectorization), + output.as_tensor_arg::(vectorization), ScalarArg::new(value), ndims as u32, ); @@ -97,7 +100,7 @@ fn mask_fill_readonly( output } -fn mask_fill_inplace( +fn mask_fill_inplace( input: JitTensor, mask: JitTensor, value: EI, @@ -105,13 +108,14 @@ fn mask_fill_inplace( let ndims = input.shape.num_dims(); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); + let vectorization = max_vectorization(&input); - mask_fill_inplace_kernel::launch::( + mask_fill_inplace_kernel::launch::( &input.client, cube_count, cube_dim, - input.as_tensor_arg::(1), - mask.as_tensor_arg::(1), + input.as_tensor_arg::(vectorization), + mask.as_tensor_arg::(vectorization), ScalarArg::new(value), ndims as u32, ); diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index 73c7c8fcf1..5518e9648b 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -1,11 +1,16 @@ use cubecl::{calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*}; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use crate::{ + element::JitElement, + ops::{max_vectorization, numeric::empty_device}, + tensor::JitTensor, + BoolElement, JitRuntime, +}; #[cube(launch)] -fn mask_where_readonly_kernel( +fn mask_where_readonly_kernel( input: &Tensor>, - mask: &Tensor>, + mask: &Tensor>, value: &Tensor>, output: &mut Tensor>, #[comptime] rank: u32, @@ -17,20 +22,17 @@ fn mask_where_readonly_kernel( let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); let index_mask = index_offset_with_layout(mask, output, ABSOLUTE_POS, 0, rank, true); let index_value = index_offset_with_layout(value, output, ABSOLUTE_POS, 0, rank, true); + let mask = Line::cast_from(mask[index_mask]); - if mask[index_mask] >= Line::new(1) { - output[ABSOLUTE_POS] = value[index_value]; - } else { - output[ABSOLUTE_POS] = input[index_input]; - } + output[ABSOLUTE_POS] = select_many(mask, value[index_value], input[index_input]); } #[cube(launch)] -fn mask_where_inplace_kernel( +fn mask_where_inplace_kernel( input: &mut Tensor>, - mask: &Tensor>, + mask: &Tensor>, value: &Tensor>, - reverse: u32, + reverse: B, #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { @@ -40,9 +42,11 @@ fn mask_where_inplace_kernel( let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); let 
index_value = index_offset_with_layout(value, input, ABSOLUTE_POS, 0, rank, true); - if mask[index_mask] != Line::new(reverse) { - input[ABSOLUTE_POS] = value[index_value]; - } + input[ABSOLUTE_POS] = select( + mask[index_mask] != Line::new(reverse), + value[index_value], + input[ABSOLUTE_POS], + ); } #[derive(Clone, Copy, Debug)] @@ -61,20 +65,20 @@ pub enum MaskWhereStrategy { } /// Execute the mask where kernel with the given strategy. -pub fn mask_where( +pub fn mask_where( input: JitTensor, mask: JitTensor, value: JitTensor, strategy: MaskWhereStrategy, ) -> JitTensor { match strategy { - MaskWhereStrategy::Readonly => mask_where_readonly::(input, mask, value), - MaskWhereStrategy::InplaceLhs => mask_where_inplace::(input, mask, value, false), - MaskWhereStrategy::InplaceRhs => mask_where_inplace::(value, mask, input, true), + MaskWhereStrategy::Readonly => mask_where_readonly::(input, mask, value), + MaskWhereStrategy::InplaceLhs => mask_where_inplace::(input, mask, value, false), + MaskWhereStrategy::InplaceRhs => mask_where_inplace::(value, mask, input, true), } } -fn mask_where_readonly( +fn mask_where_readonly( input: JitTensor, mask: JitTensor, value: JitTensor, @@ -88,22 +92,23 @@ fn mask_where_readonly( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); + let vectorization = max_vectorization(&input); - mask_where_readonly_kernel::launch::( + mask_where_readonly_kernel::launch::( &input.client, cube_count, cube_dim, - input.as_tensor_arg::(1), - mask.as_tensor_arg::(1), - value.as_tensor_arg::(1), - output.as_tensor_arg::(1), + input.as_tensor_arg::(vectorization), + mask.as_tensor_arg::(vectorization), + value.as_tensor_arg::(vectorization), + output.as_tensor_arg::(vectorization), ndims as u32, ); output } -fn mask_where_inplace( +fn mask_where_inplace( input: JitTensor, mask: JitTensor, value: JitTensor, @@ -112,15 +117,16 @@ fn mask_where_inplace( let ndims = input.shape.num_dims(); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(input.shape.num_elements(), cube_dim); + let vectorization = max_vectorization(&input); - mask_where_inplace_kernel::launch::( + mask_where_inplace_kernel::launch::( &input.client, cube_count, cube_dim, - input.as_tensor_arg::(1), - mask.as_tensor_arg::(1), - value.as_tensor_arg::(1), - ScalarArg::new(reverse as u32), + input.as_tensor_arg::(vectorization), + mask.as_tensor_arg::(vectorization), + value.as_tensor_arg::(vectorization), + ScalarArg::new(EM::new_bool(reverse)), ndims as u32, ); diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index 77a67df37a..ba953ae0d0 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -21,7 +21,7 @@ pub mod element; use burn_tensor::backend::{DeviceId, DeviceOps}; use cubecl::{compute::CubeTask, Feature, Runtime}; -pub use element::{FloatElement, IntElement, JitElement}; +pub use element::{BoolElement, FloatElement, IntElement, JitElement}; mod backend; diff --git a/crates/burn-jit/src/ops/activation_ops.rs b/crates/burn-jit/src/ops/activation_ops.rs index 7f6b921d16..eecd6849c8 100644 --- a/crates/burn-jit/src/ops/activation_ops.rs +++ b/crates/burn-jit/src/ops/activation_ops.rs @@ -1,10 +1,11 @@ -use crate::{FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::ops::ActivationOps; -impl ActivationOps for JitBackend +impl ActivationOps for JitBackend where R: 
JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { } diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index 58e3b25c0c..bce600604e 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -1,6 +1,6 @@ -use crate::{element::JitElement, kernel, tensor::JitTensor, JitRuntime}; +use crate::{element::JitElement, kernel, tensor::JitTensor, BoolElement, JitRuntime}; use burn_tensor::{Shape, TensorData}; -use cubecl::{tensor_vectorization_factor, CubeElement}; +use cubecl::tensor_vectorization_factor; pub(crate) fn from_data( data: TensorData, @@ -29,11 +29,16 @@ pub fn into_data_sync(tensor: JitTensor) -> Ten TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) } -pub(crate) async fn bool_into_data(tensor: JitTensor) -> TensorData { +pub(crate) async fn bool_into_data( + tensor: JitTensor, +) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; TensorData::new( - u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(), + BT::from_bytes(&bytes) + .iter() + .map(|i| *i != BT::false_val()) + .collect(), tensor.shape, ) } diff --git a/crates/burn-jit/src/ops/bool_ops.rs b/crates/burn-jit/src/ops/bool_ops.rs index 036913e88d..017e76f2c4 100644 --- a/crates/burn-jit/src/ops/bool_ops.rs +++ b/crates/burn-jit/src/ops/bool_ops.rs @@ -1,31 +1,32 @@ -use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{element::BoolElement, kernel, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor}; use burn_tensor::{ops::BoolTensorOps, Shape, TensorData}; use std::ops::Range; use super::{expand, permute}; -impl BoolTensorOps for JitBackend +impl BoolTensorOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - super::empty::(shape, device) + super::empty::(shape, device) } async fn bool_into_data(tensor: BoolTensor) -> TensorData { - super::bool_into_data(tensor).await + super::bool_into_data::(tensor).await } fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { - let data: TensorData = TensorData::new(data.iter::().collect(), data.shape); - super::from_data::(data, device) + let data: TensorData = TensorData::new(data.iter::().collect(), data.shape); + super::from_data::(data, device) } fn bool_into_int(tensor: BoolTensor) -> IntTensor { - kernel::bool_cast::(tensor) + kernel::bool_cast::(tensor) } fn bool_device(tensor: &BoolTensor) -> Device { @@ -41,7 +42,7 @@ where } fn bool_slice(tensor: BoolTensor, ranges: &[Range]) -> BoolTensor { - kernel::slice::(tensor, ranges) + kernel::slice::(tensor, ranges) } fn bool_slice_assign( @@ -49,19 +50,19 @@ where ranges: &[Range], value: BoolTensor, ) -> BoolTensor { - kernel::slice_assign::(tensor, ranges, value) + kernel::slice_assign::(tensor, ranges, value) } fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { - kernel::equal::(lhs, rhs) + kernel::equal::(lhs, rhs) } fn bool_not(tensor: BoolTensor) -> BoolTensor { - kernel::equal_elem::(tensor, 0) + kernel::equal_elem::(tensor, BT::false_val()) } fn bool_into_float(tensor: BoolTensor) -> FloatTensor { - kernel::bool_cast::(tensor) + kernel::bool_cast::(tensor) } fn bool_swap_dims(mut tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { @@ -72,7 +73,7 @@ where } fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> 
BoolTensor { - kernel::repeat_dim::(tensor, dim, times) + kernel::repeat_dim::(tensor, dim, times) } fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { @@ -84,6 +85,6 @@ where } fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { - kernel::flip::(tensor, axes) + kernel::flip::(tensor, axes) } } diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 52b013ec0e..f97b1609ff 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -1,7 +1,10 @@ use super::{expand, numeric, permute}; -use crate::kernel::matmul::{matmul, MatmulStrategy}; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; use crate::kernel::{self, launch_unary, reduce, unary_op, UnaryOp}; +use crate::{ + element::BoolElement, + kernel::matmul::{matmul, MatmulStrategy}, +}; use crate::{execute_with_dtype, JitBackend}; use crate::{FloatElement, IntElement, JitRuntime}; use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; @@ -11,11 +14,12 @@ use cubecl::prelude::*; use half::{bf16, f16}; use std::ops::Range; -impl FloatTensorOps for JitBackend +impl FloatTensorOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { super::from_data::(data, device) @@ -248,7 +252,7 @@ where execute_with_dtype!( float(tensor.dtype, value.dtype), E, - kernel::mask_where_auto::(tensor, mask, value) + kernel::mask_where_auto::(tensor, mask, value) ) } @@ -260,7 +264,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - kernel::mask_fill_auto::(tensor, mask, value.elem()) + kernel::mask_fill_auto::(tensor, mask, value.elem()) ) } @@ -268,7 +272,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::equal::(lhs, rhs) + kernel::equal::(lhs, rhs) ) } @@ -276,7 +280,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::equal_elem::(lhs, rhs.elem()) + kernel::equal_elem::(lhs, rhs.elem()) ) } @@ -284,7 +288,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::greater::(lhs, rhs) + kernel::greater::(lhs, rhs) ) } @@ -292,7 +296,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::greater_elem::(lhs, rhs.elem()) + kernel::greater_elem::(lhs, rhs.elem()) ) } @@ -300,7 +304,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::greater_equal::(lhs, rhs) + kernel::greater_equal::(lhs, rhs) ) } @@ -308,7 +312,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::greater_equal_elem::(lhs, rhs.elem()) + kernel::greater_equal_elem::(lhs, rhs.elem()) ) } @@ -316,7 +320,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::lower::(lhs, rhs) + kernel::lower::(lhs, rhs) ) } @@ -324,7 +328,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::lower_elem::(lhs, rhs.elem()) + kernel::lower_elem::(lhs, rhs.elem()) ) } @@ -332,7 +336,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - kernel::lower_equal::(lhs, rhs) + kernel::lower_equal::(lhs, rhs) ) } @@ -340,7 +344,7 @@ where execute_with_dtype!( float(lhs.dtype), E, - kernel::lower_equal_elem::(lhs, rhs.elem()) + kernel::lower_equal_elem::(lhs, rhs.elem()) ) } @@ -633,7 +637,11 @@ where } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { - execute_with_dtype!(float(tensor.dtype), E, kernel::flip::(tensor, axes)) + execute_with_dtype!( + float(tensor.dtype), + E, + kernel::flip::(tensor, axes) + ) } fn 
float_cast(tensor: FloatTensor, dtype: FloatDType) -> FloatTensor { diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index cb6603bf80..25bb92521f 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,6 +1,9 @@ use super::{expand, numeric, permute}; -use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; use crate::kernel::{launch_unary, unary_op, UnaryOp}; +use crate::{ + element::BoolElement, + kernel::prng::{random_bernoulli, random_normal, random_uniform}, +}; use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData}; @@ -8,11 +11,12 @@ use cubecl::frontend::Numeric; use cubecl::prelude::*; use std::ops::Range; -impl IntTensorOps for JitBackend +impl IntTensorOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn int_empty(shape: Shape, device: &Device) -> IntTensor { super::empty::(shape, device) @@ -55,7 +59,7 @@ where mask: BoolTensor, value: IntTensor, ) -> IntTensor { - kernel::mask_where_auto::(tensor, mask, value) + kernel::mask_where_auto::(tensor, mask, value) } fn int_mask_fill( @@ -63,7 +67,7 @@ where mask: BoolTensor, value: IntElem, ) -> IntTensor { - kernel::mask_fill_auto(tensor, mask, value) + kernel::mask_fill_auto::(tensor, mask, value) } fn int_gather( @@ -101,43 +105,43 @@ where } fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::equal::(lhs, rhs) + kernel::equal::(lhs, rhs) } fn int_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::equal_elem::(lhs, rhs) + kernel::equal_elem::(lhs, rhs) } fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::greater::(lhs, rhs) + kernel::greater::(lhs, rhs) } fn int_greater_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::greater_elem::(lhs, rhs) + kernel::greater_elem::(lhs, rhs) } fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::greater_equal::(lhs, rhs) + kernel::greater_equal::(lhs, rhs) } fn int_greater_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::greater_equal_elem::(lhs, rhs) + kernel::greater_equal_elem::(lhs, rhs) } fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::lower::(lhs, rhs) + kernel::lower::(lhs, rhs) } fn int_lower_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::lower_elem::(lhs, rhs) + kernel::lower_elem::(lhs, rhs) } fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { - kernel::lower_equal::(lhs, rhs) + kernel::lower_equal::(lhs, rhs) } fn int_lower_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { - kernel::lower_equal_elem::(lhs, rhs) + kernel::lower_equal_elem::(lhs, rhs) } fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { @@ -277,6 +281,6 @@ where } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { - kernel::flip::(tensor, axes) + kernel::flip::(tensor, axes) } } diff --git a/crates/burn-jit/src/ops/module_ops.rs b/crates/burn-jit/src/ops/module_ops.rs index 5539dfc9f2..b5c96058f9 100644 --- a/crates/burn-jit/src/ops/module_ops.rs +++ b/crates/burn-jit/src/ops/module_ops.rs @@ -1,4 +1,5 @@ use crate::{ + element::BoolElement, kernel::{ self, conv::{Conv2dStrategy, ConvTranspose2dStrategy}, @@ -11,11 +12,12 @@ use burn_tensor::ops::{ }; use burn_tensor::ops::{FloatTensor, IntTensor}; -impl ModuleOps for 
JitBackend +impl ModuleOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn conv2d( x: FloatTensor, @@ -23,7 +25,7 @@ where bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { - kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()) + kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()) } fn deform_conv2d( @@ -34,7 +36,7 @@ where bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { - kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options) + kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options) } fn deform_conv2d_backward( @@ -46,7 +48,7 @@ where output_grad: FloatTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward { - kernel::conv::deform_conv2d_backward::( + kernel::conv::deform_conv2d_backward::( x, offset, weight, diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index e5eb4005a6..94b1a6f2ee 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -9,6 +9,7 @@ use burn_tensor::{ }; use crate::{ + element::BoolElement, kernel, tensor::{JitQuantizationParameters, JitTensor, QJitTensor}, FloatElement, IntElement, JitBackend, JitRuntime, @@ -27,11 +28,12 @@ fn packed_tensor>( JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer, DType::U32) } -impl QTensorOps for JitBackend +impl QTensorOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { match data.dtype { diff --git a/crates/burn-jit/src/ops/transaction.rs b/crates/burn-jit/src/ops/transaction.rs index 62477d3ce1..7320186570 100644 --- a/crates/burn-jit/src/ops/transaction.rs +++ b/crates/burn-jit/src/ops/transaction.rs @@ -3,13 +3,14 @@ use burn_tensor::{ DType, TensorData, }; -use crate::{FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; -impl TransactionOps for JitBackend +impl TransactionOps for JitBackend where R: JitRuntime, F: FloatElement, I: IntElement, + BT: BoolElement, { fn tr_execute( transaction: burn_tensor::ops::TransactionPrimitive, @@ -51,7 +52,7 @@ where client = Some(t.client.clone()); } - kinds.push(Kind::Bool(num_bindings, t.shape.into(), DType::U32)); + kinds.push(Kind::Bool(num_bindings, t.shape.into(), BT::dtype())); num_bindings += 1; bindings.push(t.handle.binding()) }); @@ -64,7 +65,7 @@ where .await .into_iter() .map(Some) - .collect::>(); + .collect::>>(); let mut result = TransactionPrimitiveResult::default(); diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 112260c3bf..3eb44b3e02 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -162,9 +162,9 @@ macro_rules! 
execute_with_dtype { type $element = i8; $op } - // NOTE: bool and qfloat dtypes are actually represented as u32 + // NOTE: bool and qfloat dtypes are actually represented as u32/u8 // burn_tensor::DType::Bool => { - // type $element = u32; + // type $element = u32/u8; // $op // } // burn_tensor::DType::QFloat(_) => { diff --git a/crates/burn-jit/src/tensor/qtensor.rs b/crates/burn-jit/src/tensor/qtensor.rs index fdf7068e1a..4ef5f77589 100644 --- a/crates/burn-jit/src/tensor/qtensor.rs +++ b/crates/burn-jit/src/tensor/qtensor.rs @@ -6,7 +6,9 @@ use burn_tensor::{ read_sync, DType, TensorData, TensorMetadata, }; -use crate::{ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{ + element::BoolElement, ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime, +}; use super::JitTensor; @@ -96,10 +98,11 @@ impl Clone for JitQuantizationParameters { } } -impl - From>> for JitQuantizationParameters +impl + From>> + for JitQuantizationParameters { - fn from(value: QuantizationParametersPrimitive>) -> Self { + fn from(value: QuantizationParametersPrimitive>) -> Self { JitQuantizationParameters { scale: value.scale, offset: value.offset, diff --git a/crates/burn-jit/src/tests/mask_fill.rs b/crates/burn-jit/src/tests/mask_fill.rs index 4542bbe3f1..c768373d13 100644 --- a/crates/burn-jit/src/tests/mask_fill.rs +++ b/crates/burn-jit/src/tests/mask_fill.rs @@ -11,6 +11,7 @@ mod tests { let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill::< _, ::FloatElem, + ::BoolElem, >( tensor.into_primitive().tensor(), mask.into_primitive(), @@ -31,6 +32,7 @@ mod tests { let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill::< _, ::FloatElem, + ::BoolElem, >( tensor.into_primitive().tensor(), mask.into_primitive(), diff --git a/crates/burn-jit/src/tests/mask_where.rs b/crates/burn-jit/src/tests/mask_where.rs index befdb76af6..a14993995c 100644 --- a/crates/burn-jit/src/tests/mask_where.rs +++ b/crates/burn-jit/src/tests/mask_where.rs @@ -23,6 +23,7 @@ mod tests { Tensor::::from_primitive(TensorPrimitive::Float(mask_where::< _, ::FloatElem, + ::BoolElem, >( tensor.into_primitive().tensor(), mask.into_primitive(), @@ -44,6 +45,7 @@ mod tests { Tensor::::from_primitive(TensorPrimitive::Float(mask_where::< _, ::FloatElem, + ::BoolElem, >( tensor.into_primitive().tensor(), mask.into_primitive(), diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index b1ee4ce26d..f60edc2a1b 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -38,12 +38,12 @@ pub use serial_test; #[macro_export] macro_rules! testgen_all { () => { - use burn_tensor::{Float, Int}; - $crate::testgen_all!([Float], [Int]); + use burn_tensor::{Float, Int, Bool}; + $crate::testgen_all!([Float], [Int], [Bool]); }; - ([$($float:ident),*], [$($int:ident),*]) => { + ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { mod jit { - burn_jit::testgen_jit!([$($float),*], [$($int),*]); + burn_jit::testgen_jit!([$($float),*], [$($int),*], [$($bool),*]); mod kernel { use super::*; @@ -84,7 +84,7 @@ macro_rules! testgen_all { } } mod jit_fusion { - burn_jit::testgen_jit_fusion!([$($float),*], [$($int),*]); + burn_jit::testgen_jit_fusion!([$($float),*], [$($int),*], [$($bool),*]); } }; } @@ -92,31 +92,31 @@ macro_rules! testgen_all { #[macro_export] macro_rules! 
testgen_jit { () => { - use burn_tensor::{Float, Int}; - $crate::testgen_jit!([Float], [Int]); + use burn_tensor::{Float, Int, Bool}; + $crate::testgen_jit!([Float], [Int], [Bool]); }; - ([$($float:ident),*], [$($int:ident),*]) => { + ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { pub use super::*; use burn_jit::tests::{burn_autodiff, burn_ndarray, burn_tensor, serial_test}; - pub type TestBackend = JitBackend; - pub type TestBackend2 = JitBackend; + pub type TestBackend = JitBackend; + pub type TestBackend2 = JitBackend; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; - pub type TestTensor2 = burn_tensor::Tensor, D>; + pub type TestTensor2 = burn_tensor::Tensor, D>; pub type TestTensorInt = burn_tensor::Tensor; - pub type TestTensorInt2 = - burn_tensor::Tensor, D, burn_tensor::Int>; + pub type TestTensorInt2 = + burn_tensor::Tensor, D, burn_tensor::Int>; pub type TestTensorBool = burn_tensor::Tensor; - pub type TestTensorBool2 = - burn_tensor::Tensor, D, burn_tensor::Bool>; + pub type TestTensorBool2 = + burn_tensor::Tensor, D, burn_tensor::Bool>; pub type ReferenceTensor = burn_tensor::Tensor; - burn_tensor::testgen_all!([$($float),*], [$($int),*]); + burn_tensor::testgen_all!([$($float),*], [$($int),*], [$($bool),*]); burn_autodiff::testgen_all!([$($float),*]); // Not all ops are implemented for quantization yet, notably missing: @@ -135,28 +135,28 @@ macro_rules! testgen_jit_fusion { use burn_tensor::{Float, Int}; $crate::testgen_jit_fusion!([Float], [Int]); }; - ([$($float:ident),*], [$($int:ident),*]) => { + ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { use super::*; use burn_jit::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor}; - pub type TestBackend = burn_fusion::Fusion>; - pub type TestBackend2 = burn_fusion::Fusion>; + pub type TestBackend = burn_fusion::Fusion>; + pub type TestBackend2 = burn_fusion::Fusion>; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; - pub type TestTensor2 = burn_tensor::Tensor, D>; + pub type TestTensor2 = burn_tensor::Tensor, D>; pub type TestTensorInt = burn_tensor::Tensor; - pub type TestTensorInt2 = - burn_tensor::Tensor, D, burn_tensor::Int>; + pub type TestTensorInt2 = + burn_tensor::Tensor, D, burn_tensor::Int>; pub type TestTensorBool = burn_tensor::Tensor; - pub type TestTensorBool2 = - burn_tensor::Tensor, D, burn_tensor::Bool>; + pub type TestTensorBool2 = + burn_tensor::Tensor, D, burn_tensor::Bool>; pub type ReferenceTensor = burn_tensor::Tensor; - burn_tensor::testgen_all!([$($float),*], [$($int),*]); + burn_tensor::testgen_all!([$($float),*], [$($int),*], [$($bool),*]); burn_autodiff::testgen_all!([$($float),*]); // Not all ops are implemented for quantization yet, notably missing: diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index 74957f5f1f..060899b979 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -53,6 +53,7 @@ impl Backend for type IntElem = I; type BoolTensorPrimitive = NdArrayTensor; + type BoolElem = bool; type QuantizedTensorPrimitive = NdArrayQTensor; type QuantizedEncoding = Q; diff --git a/crates/burn-remote/src/client/channel.rs b/crates/burn-remote/src/client/channel.rs index 6c431702af..d7102dd97a 100644 --- a/crates/burn-remote/src/client/channel.rs +++ b/crates/burn-remote/src/client/channel.rs @@ -19,6 +19,8 @@ impl RunnerChannel for WsChannel { type IntElem = i32; + type BoolElem = u32; + fn 
name() -> String { "remote".into() } diff --git a/crates/burn-router/src/backend.rs b/crates/burn-router/src/backend.rs index 6fcf80ce3d..a5ada5e5fd 100644 --- a/crates/burn-router/src/backend.rs +++ b/crates/burn-router/src/backend.rs @@ -55,6 +55,8 @@ impl Backend for BackendRouter { type BoolTensorPrimitive = RouterTensor; + type BoolElem = R::BoolElem; + type QuantizedTensorPrimitive = RouterTensor; type QuantizedEncoding = u32; diff --git a/crates/burn-router/src/channel/base.rs b/crates/burn-router/src/channel/base.rs index 876d273f62..887190ecfa 100644 --- a/crates/burn-router/src/channel/base.rs +++ b/crates/burn-router/src/channel/base.rs @@ -18,6 +18,8 @@ pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized { type FloatElem: Element; /// Int element type. type IntElem: Element; + /// Bool element type. + type BoolElem: Element; /// Name of the channel. fn name() -> String; diff --git a/crates/burn-router/src/types.rs b/crates/burn-router/src/types.rs index 3b694d8779..f36e436638 100644 --- a/crates/burn-router/src/types.rs +++ b/crates/burn-router/src/types.rs @@ -206,6 +206,7 @@ macro_rules! impl_multi_backend_types { type FloatElem = $DefaultBackend::FloatElem; type IntElem = $DefaultBackend::IntElem; + type BoolElem = $DefaultBackend::BoolElem; type Client = MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>; diff --git a/crates/burn-tch/src/backend.rs b/crates/burn-tch/src/backend.rs index cef9a8d586..c294ae0025 100644 --- a/crates/burn-tch/src/backend.rs +++ b/crates/burn-tch/src/backend.rs @@ -101,6 +101,7 @@ impl Backend for LibTorch { type IntElem = i64; type BoolTensorPrimitive = TchTensor; + type BoolElem = bool; type QuantizedTensorPrimitive = TchQTensor; type QuantizedEncoding = Q; diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index f951a9b6f3..973f4ede65 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -83,6 +83,8 @@ pub trait Backend: /// Tensor primitive to be used for all bool operations. type BoolTensorPrimitive: TensorMetadata + 'static; + /// Tensor primitive to be used for all bool operations. + type BoolElem: Element; /// Tensor primitive to be used for all quantized operations. type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static; diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 5b15a45591..8aa41ee24d 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -17,28 +17,28 @@ macro_rules! testgen_all { pub type FloatType = ::FloatElem; pub type IntType = ::IntElem; - pub type BoolType = ::BoolTensorPrimitive; + pub type BoolType = ::BoolElem; $crate::testgen_with_float_param!(); $crate::testgen_no_param!(); } }; - ([$($float:ident),*], [$($int:ident),*]) => { + ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { pub mod tensor { pub use super::*; pub type FloatType = ::FloatElem; pub type IntType = ::IntElem; - pub type BoolType = ::BoolTensorPrimitive; + pub type BoolType = ::BoolElem; ::paste::paste! 
{ $(mod [<$float _ty>] { pub use super::*; - pub type TestBackend = TestBackend2<$float, IntType>; - pub type TestTensor = TestTensor2<$float, IntType, D>; - pub type TestTensorInt = TestTensorInt2<$float, IntType, D>; - pub type TestTensorBool = TestTensorBool2<$float, IntType, D>; + pub type TestBackend = TestBackend2<$float, IntType, BoolType>; + pub type TestTensor = TestTensor2<$float, IntType, BoolType, D>; + pub type TestTensorInt = TestTensorInt2<$float, IntType, BoolType, D>; + pub type TestTensorBool = TestTensorBool2<$float, IntType, BoolType, D>; pub type FloatType = $float; @@ -47,13 +47,25 @@ macro_rules! testgen_all { $(mod [<$int _ty>] { pub use super::*; - pub type TestBackend = TestBackend2; - pub type TestTensor = TestTensor2; - pub type TestTensorInt = TestTensorInt2; - pub type TestTensorBool = TestTensorBool2; + pub type TestBackend = TestBackend2; + pub type TestTensor = TestTensor2; + pub type TestTensorInt = TestTensorInt2; + pub type TestTensorBool = TestTensorBool2; pub type IntType = $int; + $crate::testgen_with_int_param!(); + })* + $(mod [<$bool _bool_ty>] { + pub use super::*; + + pub type TestBackend = TestBackend2; + pub type TestTensor = TestTensor2; + pub type TestTensorInt = TestTensorInt2; + pub type TestTensorBool = TestTensorBool2; + + pub type BoolType = $bool; + $crate::testgen_with_int_param!(); })* } @@ -307,6 +319,29 @@ macro_rules! testgen_with_int_param { }; } +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_with_bool_param { + () => { + burn_tensor::testgen_all_op!(); + burn_tensor::testgen_any_op!(); + burn_tensor::testgen_argwhere_nonzero!(); + burn_tensor::testgen_cast!(); + burn_tensor::testgen_cat!(); + burn_tensor::testgen_expand!(); + burn_tensor::testgen_full!(); + burn_tensor::testgen_map_comparison!(); + burn_tensor::testgen_mask!(); + burn_tensor::testgen_nan!(); + burn_tensor::testgen_repeat_dim!(); + burn_tensor::testgen_repeat!(); + burn_tensor::testgen_reshape!(); + burn_tensor::testgen_stack!(); + burn_tensor::testgen_transpose!(); + burn_tensor::tri_mask!(); + }; +} + #[allow(missing_docs)] #[macro_export] macro_rules! 
testgen_no_param { diff --git a/crates/burn-tensor/src/tests/ops/remainder.rs b/crates/burn-tensor/src/tests/ops/remainder.rs index fa75630fe8..996c71a7b7 100644 --- a/crates/burn-tensor/src/tests/ops/remainder.rs +++ b/crates/burn-tensor/src/tests/ops/remainder.rs @@ -67,7 +67,7 @@ mod tests { fn should_be_zero() { let device = Default::default(); let lhs = Tensor::::from_data(TensorData::from([0.0, 0.0, 0.0]), &device); - let rhs = Tensor::::from_data(TensorData::from([3.5, -2.1, 1e-5]), &device); + let rhs = Tensor::::from_data(TensorData::from([3.5, -2.1, 1e-4]), &device); let output = lhs.remainder(rhs); let expected = TensorData::from([0.0, 0.0, 0.0]); diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 7c26dcc31b..0751ad9f41 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -10,7 +10,7 @@ pub use burn_jit::{ }; pub use burn_jit::{tensor::JitTensor, JitBackend}; -pub use burn_jit::{FloatElement, IntElement}; +pub use burn_jit::{BoolElement, FloatElement, IntElement}; pub use cubecl::flex32; pub use cubecl::ir::CubeDim; pub use cubecl::wgpu::*; @@ -21,8 +21,12 @@ pub type SpirV = cubecl::wgpu::spirv::VkSpirvCompiler; #[cfg(feature = "spirv")] type Compiler = SpirV; +#[cfg(feature = "spirv")] +type Byte = u8; #[cfg(not(feature = "spirv"))] type Compiler = Wgsl; +#[cfg(not(feature = "spirv"))] +type Bool = u32; #[cfg(feature = "fusion")] /// Tensor backend that uses the wgpu crate for executing GPU compute shaders. @@ -56,8 +60,8 @@ type Compiler = Wgsl; /// /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. -pub type Wgpu = - burn_fusion::Fusion, F, I>>; +pub type Wgpu = + burn_fusion::Fusion, F, I, B>>; #[cfg(not(feature = "fusion"))] /// Tensor backend that uses the wgpu crate for executing GPU compute shaders. @@ -91,7 +95,8 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. 
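The `BT: BoolElement` bound threaded through these hunks refers to a trait whose definition is not part of this excerpt. A minimal sketch of the shape such a trait could take is given below; the method names `new_bool` and `false_val` mirror calls visible in the diff, while everything else is an assumption rather than the actual `burn_jit::element::BoolElement` definition.

```rust
// Assumed shape of a bool-element abstraction that lets bool tensors be stored
// as either u8 or u32 on the GPU. Not the real burn-jit trait; `true_val` in
// particular is an invented helper for this sketch.
trait BoolElement: Copy + PartialEq {
    /// Storage value representing `true`.
    fn true_val() -> Self;
    /// Storage value representing `false`.
    fn false_val() -> Self;
    /// Convert a Rust bool into the storage representation.
    fn new_bool(value: bool) -> Self {
        if value {
            Self::true_val()
        } else {
            Self::false_val()
        }
    }
}

impl BoolElement for u8 {
    fn true_val() -> Self {
        1
    }
    fn false_val() -> Self {
        0
    }
}

impl BoolElement for u32 {
    fn true_val() -> Self {
        1
    }
    fn false_val() -> Self {
        0
    }
}

fn main() {
    assert_eq!(u8::new_bool(true), 1);
    assert_eq!(u32::new_bool(false), 0);
}
```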
-pub type Wgpu = JitBackend, F, I>; +pub type Wgpu = + JitBackend, F, I, B>; #[cfg(test)] mod tests { @@ -103,7 +108,7 @@ mod tests { // Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it // breaks a lot of tests from precision issues #[cfg(feature = "spirv")] - burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64]); + burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]); #[cfg(not(feature = "spirv"))] - burn_jit::testgen_all!([f32], [i32]); + burn_jit::testgen_all!([f32], [i32], [u32]); } diff --git a/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs b/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs index 6cb9e6a6ae..de6bfcc7d4 100644 --- a/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs +++ b/examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs @@ -71,7 +71,7 @@ fn autodiff(device: &B::Device) { } fn main() { - type MyBackend = burn::backend::wgpu::JitBackend; + type MyBackend = burn::backend::wgpu::JitBackend; type MyAutodiffBackend = burn::backend::Autodiff; let device = Default::default(); inference::(&device); diff --git a/examples/custom-cubecl-kernel/src/backward.rs b/examples/custom-cubecl-kernel/src/backward.rs index 3c66ae8e0e..a894f4e446 100644 --- a/examples/custom-cubecl-kernel/src/backward.rs +++ b/examples/custom-cubecl-kernel/src/backward.rs @@ -10,10 +10,10 @@ use burn::{ }, tensor::{Shape, TensorMetadata}, }; -use burn_jit::{FloatElement, IntElement, JitBackend, JitRuntime}; +use burn_jit::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; -impl AutodiffBackend - for Autodiff> +impl AutodiffBackend + for Autodiff> { } diff --git a/examples/custom-cubecl-kernel/src/forward.rs b/examples/custom-cubecl-kernel/src/forward.rs index a8bf17fcd7..0e180e231a 100644 --- a/examples/custom-cubecl-kernel/src/forward.rs +++ b/examples/custom-cubecl-kernel/src/forward.rs @@ -3,12 +3,15 @@ use crate::{kernel::fused_matmul_add_relu_kernel, FloatTensor}; use super::Backend; use burn::tensor::Shape; use burn_jit::{ - kernel::into_contiguous, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime, + element::BoolElement, kernel::into_contiguous, tensor::JitTensor, FloatElement, IntElement, + JitBackend, JitRuntime, }; use cubecl::{CubeCount, CubeDim}; /// Implement our custom backend trait for the generic `JitBackend`. 
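Assuming the `Wgpu` alias updated above now defaults to `f32`/`i32` elements with `u32` bools under WGSL (and `u8` under SPIR-V, matching the test type lists), downstream code that spells out the backend type gains one more parameter. A hypothetical usage sketch, requiring the `wgpu` feature of `burn`; the exact default parameters are assumed here because the generic arguments are elided in this hunk:

```rust
// Hypothetical usage sketch, not part of the patch.
use burn::backend::wgpu::Wgpu;

// Previously roughly: Wgpu<f32, i32>. With this patch a bool storage type is
// also selectable; u32 is assumed for the WGSL compiler path.
type Backend = Wgpu<f32, i32, u32>;

fn main() {
    // The alias is a type-level choice only; print it to show it resolves.
    println!("backend type: {}", core::any::type_name::<Backend>());
}
```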
-impl Backend for JitBackend { +impl Backend + for JitBackend +{ fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, diff --git a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs index a309ea5716..0c2201080e 100644 --- a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs +++ b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs @@ -71,7 +71,7 @@ fn autodiff(device: &B::Device) { } fn main() { - type MyBackend = burn::backend::wgpu::JitBackend; + type MyBackend = burn::backend::wgpu::JitBackend; type MyAutodiffBackend = burn::backend::Autodiff; let device = Default::default(); inference::(&device); diff --git a/examples/custom-wgpu-kernel/src/backward.rs b/examples/custom-wgpu-kernel/src/backward.rs index b9032413bc..eb374d6c10 100644 --- a/examples/custom-wgpu-kernel/src/backward.rs +++ b/examples/custom-wgpu-kernel/src/backward.rs @@ -9,12 +9,15 @@ use burn::{ ops::{broadcast_shape, Backward, Ops, OpsKind}, Autodiff, NodeID, }, - wgpu::{FloatElement, IntElement, JitBackend, WgpuRuntime}, + wgpu::{BoolElement, FloatElement, IntElement, JitBackend, WgpuRuntime}, }, tensor::{Shape, TensorMetadata}, }; -impl AutodiffBackend for Autodiff> {} +impl AutodiffBackend + for Autodiff> +{ +} // Implement our custom backend trait for any backend that also implements our custom backend trait. // diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index c8476230d2..e257e13bf0 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -3,8 +3,8 @@ use crate::FloatTensor; use super::Backend; use burn::{ backend::wgpu::{ - build_info, into_contiguous, kernel_source, FloatElement, IntElement, JitBackend, - JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, + build_info, into_contiguous, kernel_source, BoolElement, FloatElement, IntElement, + JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, }, tensor::Shape, }; @@ -41,7 +41,9 @@ impl KernelSource for FusedMatmulAddRelu { } /// Implement our custom backend trait for the existing backend `WgpuBackend`. 
-impl Backend for JitBackend { +impl Backend + for JitBackend +{ fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, From a5624c15957d94ead7bb6b351047b63c2de3909b Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 29 Nov 2024 20:20:21 +0100 Subject: [PATCH 3/4] [Optimization] Implicit gemm rewrite (#2545) --- Cargo.lock | 26 +- Cargo.toml | 4 +- backend-comparison/Cargo.toml | 2 +- backend-comparison/benches/conv2d.rs | 147 +++++- backend-comparison/benches/matmul.rs | 42 +- crates/burn-core/Cargo.toml | 6 +- .../burn-core/src/data/dataloader/batcher.rs | 1 + crates/burn-core/src/lib.rs | 2 + crates/burn-core/src/nn/rnn/lstm.rs | 1 - .../burn-fusion/src/stream/execution/tests.rs | 2 +- crates/burn-jit/src/kernel/contiguous.rs | 6 +- .../burn-jit/src/kernel/conv/conv2d/base.rs | 21 +- .../src/kernel/conv/conv2d/gemm/algorithm.rs | 124 +++++ .../src/kernel/conv/conv2d/gemm/base.rs | 140 ++++++ .../src/kernel/conv/conv2d/gemm/config.rs | 15 + .../conv/conv2d/gemm/homogeneous/base.rs | 435 ++++++++++++++++++ .../conv/conv2d/gemm/homogeneous/mod.rs | 1 + .../src/kernel/conv/conv2d/gemm/launch.rs | 269 +++++++++++ .../kernel/conv/conv2d/gemm/loader/bias.rs | 116 +++++ .../kernel/conv/conv2d/gemm/loader/im2col.rs | 147 ++++++ .../src/kernel/conv/conv2d/gemm/loader/mod.rs | 2 + .../src/kernel/conv/conv2d/gemm/mod.rs | 10 + .../kernel/conv/conv2d/gemm/reader/bias.rs | 38 ++ .../kernel/conv/conv2d/gemm/reader/im2col.rs | 112 +++++ .../src/kernel/conv/conv2d/gemm/reader/mod.rs | 2 + .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 12 +- .../src/kernel/conv/conv2d/implicit_gemm.rs | 43 +- .../src/kernel/conv/conv2d/layout_swap.rs | 2 +- crates/burn-jit/src/kernel/conv/conv2d/mod.rs | 3 + .../src/kernel/conv/conv2d/tune/conv2d.rs | 41 +- .../kernel/conv/deform_conv_transpose2d.rs | 54 ++- crates/burn-jit/src/kernel/matmul/base.rs | 37 +- .../burn-jit/src/kernel/matmul/tune/base.rs | 17 +- .../src/kernel/reduce/subcube/kernel.rs | 2 +- crates/burn-jit/src/ops/transaction.rs | 2 +- crates/burn-jit/src/template/base.rs | 2 +- crates/burn-jit/src/tests/conv2d.rs | 73 ++- crates/burn-ndarray/src/tensor.rs | 2 +- crates/burn-tensor/src/tests/module/conv3d.rs | 3 +- crates/burn-train/src/metric/base.rs | 1 + crates/burn/Cargo.toml | 6 +- 41 files changed, 1830 insertions(+), 141 deletions(-) create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/config.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/mod.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/mod.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/mod.rs diff --git a/Cargo.lock b/Cargo.lock index ace556a7e8..8897110b75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1666,7 +1666,7 @@ dependencies = [ [[package]] name = "cubecl" version = 
"0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1674,6 +1674,7 @@ dependencies = [ "cubecl-linalg", "cubecl-runtime 0.4.0", "cubecl-wgpu", + "half", ] [[package]] @@ -1697,7 +1698,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1714,7 +1715,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1732,7 +1733,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1746,7 +1747,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1762,7 +1763,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1773,6 +1774,7 @@ dependencies = [ "derive-new 0.6.0", "half", "log", + "paste", ] [[package]] @@ -1787,7 +1789,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-core", @@ -1798,7 +1800,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1813,7 +1815,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = 
"0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1850,7 +1852,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "async-channel", "async-lock", @@ -1871,7 +1873,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1885,7 +1887,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 5c29707122..ca45f967bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" } ### For local development. 
### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 9e199ba8c9..1d1a62cf49 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -17,8 +17,8 @@ candle-cuda = ["burn/candle-cuda"] candle-metal = ["burn/candle", "burn/metal"] cuda-jit = ["burn/cuda-jit"] cuda-jit-fusion = ["cuda-jit", "burn/fusion"] -hip-jit = ["burn/hip-jit"] default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"] +hip-jit = ["burn/hip-jit"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] diff --git a/backend-comparison/benches/conv2d.rs b/backend-comparison/benches/conv2d.rs index d9fdf47a8c..c2a46ad64f 100644 --- a/backend-comparison/benches/conv2d.rs +++ b/backend-comparison/benches/conv2d.rs @@ -1,3 +1,5 @@ +use std::hint::black_box; + use backend_comparison::persistence::save; use burn::tensor::{ backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor, @@ -5,6 +7,7 @@ use burn::tensor::{ use burn_common::benchmark::{run_benchmark, Benchmark}; pub struct Conv2dBenchmark { + suffix: &'static str, input_shape: Shape, weight_shape: Shape, bias_shape: Shape, @@ -16,7 +19,7 @@ impl Benchmark for Conv2dBenchmark { type Args = (Tensor, Tensor, Tensor); fn name(&self) -> String { - "conv2d".into() + format!("conv2d-{}", self.suffix) } fn shapes(&self) -> Vec> { @@ -50,6 +53,10 @@ impl Benchmark for Conv2dBenchmark { fn sync(&self) { B::sync(&self.device) } + + fn num_samples(&self) -> usize { + 40 + } } #[allow(dead_code)] @@ -75,6 +82,7 @@ fn bench( let groups = 1; let options = ConvOptions::new(strides, padding, dilations, groups); let benchmark = Conv2dBenchmark:: { + suffix: "input_16x512x512_weight_16x3x3_stride_1", input_shape: [batch_size, channels_in, height_in, width_in].into(), weight_shape: [ channels_out, @@ -88,14 +96,135 @@ fn bench( device: device.clone(), }; - save::( - vec![run_benchmark(benchmark)], - device, - feature_name, - url, - token, - ) - .unwrap(); + let conv1 = Conv2dBenchmark:: { + suffix: "input_3x227x227_weight_96x11x11_stride_4", + input_shape: [batch_size, 3, 227, 227].into(), + weight_shape: [96, 3, 11, 11].into(), + bias_shape: [96].into(), + options: ConvOptions::new([4, 4], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv2 = Conv2dBenchmark:: { + suffix: "input_3x231x231_weight_96x11x11_stride_4", + input_shape: [batch_size, 3, 231, 231].into(), + weight_shape: [96, 3, 11, 11].into(), + bias_shape: [96].into(), + options: ConvOptions::new([4, 4], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv3 = Conv2dBenchmark:: { + suffix: "input_3x227x227_weight_64x7x7_stride_2", + input_shape: [batch_size, 3, 227, 227].into(), + weight_shape: [64, 3, 7, 7].into(), + bias_shape: [64].into(), + options: ConvOptions::new([2, 2], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv4 = Conv2dBenchmark:: { + suffix: "input_64x224x224_weight_64x7x7_stride_2", + input_shape: [batch_size, 64, 224, 224].into(), + weight_shape: [64, 64, 7, 7].into(), + bias_shape: [64].into(), + options: ConvOptions::new([2, 2], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv5 = Conv2dBenchmark:: { + suffix: "input_96x24x24_weight_256x5x5_stride_1", + input_shape: [batch_size, 96, 24, 24].into(), + weight_shape: [256, 
96, 5, 5].into(), + bias_shape: [256].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv6 = Conv2dBenchmark:: { + suffix: "input_256x12x12_weight_512x3x3_stride_1", + input_shape: [batch_size, 256, 12, 12].into(), + weight_shape: [512, 256, 3, 3].into(), + bias_shape: [512].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv7 = Conv2dBenchmark:: { + suffix: "input_3x224x224_weight_64x3x3_stride_1", + input_shape: [batch_size, 3, 224, 224].into(), + weight_shape: [64, 3, 3, 3].into(), + bias_shape: [64].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv8 = Conv2dBenchmark:: { + suffix: "input_64x112x112_weight_128x3x3_stride_1", + input_shape: [batch_size, 64, 112, 112].into(), + weight_shape: [128, 64, 3, 3].into(), + bias_shape: [128].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv9 = Conv2dBenchmark:: { + suffix: "input_64x56x56_weight_64x3x3_stride_1", + input_shape: [batch_size, 64, 56, 56].into(), + weight_shape: [64, 64, 3, 3].into(), + bias_shape: [64].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv10 = Conv2dBenchmark:: { + suffix: "input_128x28x28_weight_128x3x3_stride_1", + input_shape: [batch_size, 128, 28, 28].into(), + weight_shape: [128, 128, 3, 3].into(), + bias_shape: [128].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv11 = Conv2dBenchmark:: { + suffix: "input_256x14x14_weight_256x3x3_stride_1", + input_shape: [batch_size, 256, 14, 14].into(), + weight_shape: [256, 256, 3, 3].into(), + bias_shape: [256].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv12 = Conv2dBenchmark:: { + suffix: "input_512x7x7_weight_512x3x3_stride_1", + input_shape: [batch_size, 512, 7, 7].into(), + weight_shape: [512, 512, 3, 3].into(), + bias_shape: [512].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv13 = Conv2dBenchmark:: { + suffix: "input_96x224x224_weight_64x1x1_stride_1", + input_shape: [batch_size, 96, 224, 224].into(), + weight_shape: [64, 96, 1, 1].into(), + bias_shape: [64].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let benches = vec![ + benchmark, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8, conv9, conv10, conv11, + conv12, conv13, + ]; + let mut results = Vec::new(); + + for bench in benches { + let result = black_box(run_benchmark(bench)); + results.push(result); + } + + save::(results, device, feature_name, url, token).unwrap(); } fn main() { diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index e4766c3df5..d31e7cc954 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -1,5 +1,5 @@ use backend_comparison::persistence::save; -use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; +use burn::tensor::{backend::Backend, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; use derive_new::new; @@ -21,17 +21,13 @@ impl Benchmark for MatmulBenchmark { vec![self.shape_lhs.dims.clone(), self.shape_rhs.dims.clone()] } - fn num_samples(&self) -> usize { - 10 - } - fn execute(&self, (lhs, rhs): Self::Args) { - lhs.clone().matmul(rhs.clone()); + 
lhs.matmul(rhs); } fn prepare(&self) -> Self::Args { - let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default, &self.device); - let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default, &self.device); + let lhs = Tensor::zeros(self.shape_lhs.clone(), &self.device); + let rhs = Tensor::zeros(self.shape_rhs.clone(), &self.device); (lhs, rhs) } @@ -48,24 +44,18 @@ fn bench( url: Option<&str>, token: Option<&str>, ) { - const D: usize = 3; - let batch_size = 8; - let m = 2048; - let k = 2048; - let n = 2048; - let shape_lhs = [batch_size, m, k].into(); - let shape_rhs = [batch_size, k, n].into(); - - let benchmark = MatmulBenchmark::::new(shape_lhs, shape_rhs, device.clone()); - - save::( - vec![run_benchmark(benchmark)], - device, - feature_name, - url, - token, - ) - .unwrap(); + let benchmarks = [(2, 4096, 4096, 4096), (8, 2048, 2048, 2048)] + .into_iter() + .map(|(b, m, n, k)| { + let shape_lhs = [b, m, k].into(); + let shape_rhs = [b, k, n].into(); + + MatmulBenchmark::::new(shape_lhs, shape_rhs, device.clone()) + }) + .map(run_benchmark) + .collect(); + + save::(benchmarks, device, feature_name, url, token).unwrap(); } fn main() { diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index 1ef80fd4f1..68b47b3826 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -91,10 +91,10 @@ blas-netlib = ["burn-ndarray?/blas-netlib"] metal = ["burn-candle?/metal"] openblas = ["burn-ndarray?/blas-openblas"] openblas-system = ["burn-ndarray?/blas-openblas-system"] -template = ["burn-wgpu?/template"] remote = ["burn-remote/client"] router = ["burn-router"] server = ["burn-remote/server"] +template = ["burn-wgpu?/template"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] @@ -138,10 +138,10 @@ burn-candle = { path = "../burn-candle", version = "0.16.0", optional = true } burn-cuda = { path = "../burn-cuda", version = "0.16.0", optional = true, default-features = false } burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false } -burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true } -burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false } burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true } burn-router = { path = "../burn-router", version = "0.16.0", default-features = false, optional = true } +burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true } +burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false } data-encoding = { workspace = true } uuid = { workspace = true } diff --git a/crates/burn-core/src/data/dataloader/batcher.rs b/crates/burn-core/src/data/dataloader/batcher.rs index 2ab3b87255..b0c242952e 100644 --- a/crates/burn-core/src/data/dataloader/batcher.rs +++ b/crates/burn-core/src/data/dataloader/batcher.rs @@ -29,6 +29,7 @@ where } } +/// Test batcher #[cfg(test)] #[derive(new, Clone)] pub struct TestBatcher; diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index d1788d10cd..f554518430 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -48,6 +48,7 @@ pub use burn_remote::server; extern crate alloc; +/// Backend for test cases #[cfg(all( test, not(feature = "test-tch"), @@ -65,6 +66,7 @@ pub type TestBackend = burn_wgpu::Wgpu; 
#[cfg(all(test, feature = "test-cuda"))] pub type TestBackend = burn_cuda::Cuda; +/// Backend for autodiff test cases #[cfg(feature = "std")] #[cfg(test)] pub type TestAutodiffBackend = burn_autodiff::Autodiff; diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 802d7f4720..9a7c23399b 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -384,7 +384,6 @@ mod tests { /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725 /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723 /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937 - /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243 /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648 #[test] diff --git a/crates/burn-fusion/src/stream/execution/tests.rs b/crates/burn-fusion/src/stream/execution/tests.rs index 9dff06f5cd..98cea0e935 100644 --- a/crates/burn-fusion/src/stream/execution/tests.rs +++ b/crates/burn-fusion/src/stream/execution/tests.rs @@ -500,7 +500,7 @@ impl OptimizationBuilder for TestOptimizationBuilder { } } -impl<'i> StreamSegment for TestSegment<'i> { +impl StreamSegment for TestSegment<'_> { // The operations in the process. fn operations(&self) -> &[OperationDescription] { self.operations diff --git a/crates/burn-jit/src/kernel/contiguous.rs b/crates/burn-jit/src/kernel/contiguous.rs index 170f202e76..b21d032c78 100644 --- a/crates/burn-jit/src/kernel/contiguous.rs +++ b/crates/burn-jit/src/kernel/contiguous.rs @@ -7,8 +7,10 @@ pub fn into_contiguous(tensor: JitTensor) -> JitTensor { } execute_with_dtype!(tensor.dtype, E, { - let output = - cubecl::linalg::tensor::into_contiguous::(&tensor.client, tensor.as_handle_ref()); + let output = cubecl::linalg::tensor::into_contiguous::( + &tensor.client, + &tensor.as_handle_ref(), + ); JitTensor::new( tensor.client, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/base.rs index 9f07d36c55..0b3a35dc45 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/base.rs @@ -1,15 +1,12 @@ -use burn_tensor::{ - ops::{ConvOptions, ConvTransposeOptions}, - TensorData, -}; +use burn_tensor::ops::{ConvOptions, ConvTransposeOptions}; -use crate::{tensor::JitTensor, FloatElement, IntElement, JitElement, JitRuntime}; +use crate::{tensor::JitTensor, FloatElement, IntElement, JitRuntime}; #[cfg(feature = "autotune")] use super::{conv2d_autotune, conv_transpose2d_autotune}; use super::{ conv2d_direct, conv2d_im2col, conv_transpose2d_col2im, conv_transpose2d_direct, - implicit_gemm::conv2d_implicit_gemm, + gemm::launch::conv2d_gemm_cmma_large_m, implicit_gemm::conv2d_implicit_gemm, }; /// The strategy to be used when launching a convolution kernel. @@ -24,6 +21,9 @@ pub enum Conv2dStrategy { /// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and /// has constraints on tensor shape. ImplicitGemm, + /// Implicit GEMM implementation of convolution. Uses `cubecl` matmul components to provide + /// the flexibility needed to work well for varied problem sizes. 
+ ImplicitGemmComplex, } impl Default for Conv2dStrategy { @@ -82,6 +82,9 @@ pub fn conv2d( Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::(input, weight, bias, options), + Conv2dStrategy::ImplicitGemmComplex => { + conv2d_gemm_cmma_large_m::(input, weight, bias, options) + } } } @@ -113,9 +116,3 @@ pub fn conv_transpose2d( } } } - -#[allow(unused)] -pub(crate) fn debug_data(tensor: JitTensor) -> TensorData { - let bytes = tensor.client.read_one(tensor.handle.binding()); - TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) -} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs new file mode 100644 index 0000000000..7a210e7c70 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs @@ -0,0 +1,124 @@ +use std::marker::PhantomData; + +use cubecl::{ + linalg::matmul::{ + components::{ + stage::{self, StageSize}, + tile::{ + self, + accelerated::{Accelerated16x16x16, CmmaValid}, + Matmul as _, + }, + MatmulKernel, + }, + kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + }, + prelude::*, +}; + +use super::{ + base::{Convolution, ConvolutionKernel, ConvolutionLaunch, ConvolutionProblem}, + homogeneous::base::ImplicitGemmConvolution, +}; + +/// Specifications for a convolution algorithm +pub trait Algorithm { + const PLANE_DIM: u32; + + type EG: Numeric; + type ES: Numeric; + type EA: Numeric; + + type TileMatmul: tile::Matmul + MatmulKernel; + + type StageSize: StageSize; + type StageMatmul: stage::Matmul + MatmulKernel; + + type GlobalConvolution: Convolution + + ConvolutionLaunch; + + /// Cube dim for launch + fn cube_dim() -> CubeDim; + /// The cube count for a given convolution problem + fn cube_count(problem: &ConvolutionProblem) -> CubeCount; + + /// Make a convolution config from a convolution problem, and launch options + fn make_config( + problem: &ConvolutionProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> >::Config { + Self::GlobalConvolution::make_config(problem, cube_dim, cube_count, advanced_config) + } + + /// Check availability of the matmul algorithm + fn check_availability( + client: &ComputeClient, + ) -> Result<(), MatmulAvailabilityError> { + Self::GlobalConvolution::check_availability::(client) + } + + /// Determine whether the given convolution problem is valid to launch (within hardware limits) + fn can_launch( + client: &ComputeClient, + problem: &ConvolutionProblem, + ) -> bool { + if problem.options.groups > 1 || Self::check_availability::(client).is_err() { + return false; + } + + let cube_count = Self::cube_count(problem); + let (max_x, max_y, max_z) = R::max_cube_count(); + match cube_count { + CubeCount::Static(x, y, z) => x <= max_x && y <= max_y && z <= max_z, + _ => true, + } + } +} + +/// Cmma convolution +pub struct Cmma { + pub _eg: PhantomData, + pub _es: PhantomData, + pub _ea: PhantomData, + pub _stage: PhantomData, +} + +impl Algorithm + for Cmma +where + (ES, EA): CmmaValid, +{ + const PLANE_DIM: u32 = 32; + type EG = EG; + type ES = ES; + type EA = EA; + + type TileMatmul = Accelerated16x16x16; + + type StageSize = Stage; + type StageMatmul = stage::multi_buffer::Matmul< + Self::ES, + Self::EG, + Self::EA, + Self::TileMatmul, + Self::StageSize, + >; + + type GlobalConvolution = + ImplicitGemmConvolution; + + fn cube_dim() -> CubeDim 
{ + CubeDim::new(Self::PLANE_DIM, Self::StageSize::NUM_M, 1) + } + + fn cube_count(problem: &ConvolutionProblem) -> CubeCount { + let m_stage = Self::StageSize::NUM_M * Self::TileMatmul::M; + let n_stage = Self::StageSize::NUM_N * Self::TileMatmul::N; + let cubes_needed_m = (problem.m as u32).div_ceil(m_stage); + let cubes_needed_n = (problem.n as u32).div_ceil(n_stage); + + CubeCount::Static(cubes_needed_m, cubes_needed_n, 1) + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs new file mode 100644 index 0000000000..bc242107f9 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs @@ -0,0 +1,140 @@ +use burn_tensor::ops::ConvOptions; +use cubecl::linalg::matmul::{ + components::{ + global::{AccumulatorLoader, Unloader}, + stage, MatmulProblem, MatrixLayout, + }, + kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, +}; +use cubecl::prelude::*; + +use super::Config; + +#[cube] +pub trait Convolution>: + 'static + Send + Sync + ConvolutionKernel +{ + type LhsLoader: CubeType; + type RhsLoader: CubeType; + type AccumulatorLoader: AccumulatorLoader; + + type Out: Unloader; + type Accumulator: CubeType; + + /// Performs the convolution over data loaded by the + /// LHS and RHS loaders, over the range given for K, and stores with + /// using the output unloader. + /// + /// To compute the whole range of k values, use k_range=(0, K) where + /// K is the K dimension of LHS and RHS. + fn execute( + lhs_loader: Self::LhsLoader, + rhs_loader: Self::RhsLoader, + acc_loader: Self::AccumulatorLoader, + unloader: Self::Out, + acc: &mut Self::Accumulator, + k_range: (u32, u32), + #[comptime] config: Self::Config, + ); + + fn init_lhs_loader( + lhs: &Tensor>, + x_offset: u32, + y_offset: u32, + #[comptime] config: Self::Config, + ) -> Self::LhsLoader; + + fn init_rhs_loader( + rhs: &Tensor>, + x_offset: u32, + y_offset: u32, + #[comptime] config: Self::Config, + ) -> Self::RhsLoader; + + fn init_bias_loader( + rhs: &Tensor>, + n_offset: u32, + #[comptime] config: Self::Config, + #[comptime] has_bias: bool, + ) -> Self::AccumulatorLoader; + + fn init_unloader(out: &mut Tensor>, x_offset: u32, y_offset: u32) -> Self::Out; + + fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator; +} + +/// Provides configuration for a matmul kernel at any level +pub trait ConvolutionKernel { + /// Configuration tailored to the matmul implementation + type Config: Config; + + /// Asserts that the configuration for this matmul will lead to a valid computation + fn check_config(config: Self::Config); + + /// Checks if the client can handle the features used in this computation + fn check_availability( + client: &ComputeClient, + ) -> Result<(), MatmulAvailabilityError>; + + fn make_config( + problem: &ConvolutionProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config; +} + +/// Provides launch entry point to solve a matmul +pub trait ConvolutionLaunch: ConvolutionKernel { + /// Entry point + /// + /// # Safety + /// + /// Out-of-bounds can happen + #[allow(clippy::too_many_arguments)] + unsafe fn launch_unchecked( + client: &ComputeClient<::Server, ::Channel>, + cube_dim: CubeDim, + cube_count: CubeCount, + input: TensorArg<'_, R>, + weight: TensorArg<'_, R>, + bias: TensorArg<'_, R>, + out: TensorArg<'_, R>, + config: >::Config, + ); +} + +#[derive(Clone)] +/// Description of a matmul problem to solve, regardless of actual data +pub struct 
ConvolutionProblem { + pub m: usize, + pub n: usize, + pub k: usize, + pub lhs_layout: MatrixLayout, + pub rhs_layout: MatrixLayout, + pub lhs_line_size: u8, + pub rhs_line_size: u8, + pub out_line_size: u8, + + pub kernel_size: (u32, u32), + pub options: ConvOptions<2>, + pub out_shape_y: usize, + pub out_shape_x: usize, + pub has_bias: bool, +} + +impl ConvolutionProblem { + pub fn as_matmul_problem(&self) -> MatmulProblem { + MatmulProblem { + m: self.m, + n: self.n, + k: self.k, + batches: (vec![], vec![]), + lhs_layout: self.lhs_layout, + rhs_layout: self.rhs_layout, + lhs_line_size: self.lhs_line_size, + rhs_line_size: self.rhs_line_size, + out_line_size: self.out_line_size, + } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/config.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/config.rs new file mode 100644 index 0000000000..7895a5cc1a --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/config.rs @@ -0,0 +1,15 @@ +use cubecl::linalg::matmul::components::global; + +/// Convolution specific config, extends regular matmul [`Config`](global::Config) +pub trait Config: global::Config { + /// The shape of the output at `dim` + fn out_shape(&self, dim: u32) -> u32; + /// The size of the convolution kernel at `dim` + fn kernel_size(&self, dim: u32) -> u32; + /// The dilation of the kernel at `dim` + fn dilation(&self, dim: u32) -> u32; + /// The stride of the kernel at `dim` + fn stride(&self, dim: u32) -> u32; + /// The padding of the kernel at `dim` + fn padding(&self, dim: u32) -> i32; +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs new file mode 100644 index 0000000000..ca7399e72d --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs @@ -0,0 +1,435 @@ +use cubecl::{ + linalg::matmul::{ + components::{ + global::{ + self, + homogeneous::{self, CyclicLoading, RhsLoader}, + unloader::Unloader, + AccumulatorLoader, Config as _, Loader, + }, + stage::{ + self, + multi_buffer::{LhsReader, RhsReader}, + TilingOrderConfig, + }, + Ident, MatrixLayout, StageDim, + }, + kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + }, + prelude::*, +}; +use std::marker::PhantomData; + +use crate::kernel::conv::{ + conv2d::gemm::base::{Convolution, ConvolutionKernel, ConvolutionLaunch, ConvolutionProblem}, + loader::im2col::SimpleIm2colLoader, +}; +use crate::kernel::conv::{conv2d::gemm::Config as _, loader::bias::BiasLoader}; + +/// Performs matrix multiplication at the global level, with each plane sharing the same responsibilities +/// - All planes load data to the stage +/// - All planes are used in the stage matmul computation +pub struct ImplicitGemmConvolution< + EG: Numeric, + ES: Numeric, + Acc: Numeric, + SMM: stage::Matmul, +> { + _eg: PhantomData, + _es: PhantomData, + _acc: PhantomData, + _stage_matmul: PhantomData, +} + +#[cube] +impl Convolution + for ImplicitGemmConvolution +where + EG: Numeric, + ES: Numeric, + Acc: Numeric, + SMMConf: stage::Config, + SMM: stage::Matmul< + ES, + EG, + Acc, + LhsReader = LhsReader, + RhsReader = RhsReader, + Config = SMMConf, + >, +{ + type LhsLoader = SimpleIm2colLoader; + type RhsLoader = RhsLoader; + type AccumulatorLoader = BiasLoader; + + type Out = Unloader; + type Accumulator = SMM::Accumulator; + + fn execute( + mut lhs_loader: Self::LhsLoader, + mut rhs_loader: Self::RhsLoader, + mut acc_loader: Self::AccumulatorLoader, + mut out_unloader: Self::Out, + acc: &mut Self::Accumulator, + 
k_range: (u32, u32), + #[comptime] config: Self::Config, + ) { + let k_step = SMM::K; + let range = k_range.1 - k_range.0; + #[allow(clippy::manual_div_ceil)] + let num_loops = (range + k_step - 1) / k_step; + + Self::AccumulatorLoader::fill_stage(&mut acc_loader, config.to_smm_config()); + let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.to_smm_config()); + + sync_units(); + + SMM::fill_accumulator::( + &mut acc_loader, + acc, + config.to_smm_config(), + ); + + for _ in 0..num_loops { + sync_units(); + + let lhs_stage_reader = &Self::LhsLoader::fill_stage(&mut lhs_loader, config); + let rhs_stage_reader = + &Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config()); + + sync_units(); + + SMM::execute( + lhs_stage_reader, + rhs_stage_reader, + &mut lhs_tile, + &mut rhs_tile, + acc, + config.to_smm_config(), + ); + + Self::LhsLoader::advance_view(&mut lhs_loader, k_step); + Self::RhsLoader::advance_view(&mut rhs_loader, k_step); + } + + sync_units(); + + SMM::read_accumulator::( + acc, + &mut out_unloader, + config.to_smm_config(), + config, + ); + } + + fn init_lhs_loader( + lhs: &Tensor>, + x_offset: u32, + y_offset: u32, + #[comptime] config: Self::Config, + ) -> Self::LhsLoader { + Self::LhsLoader::new( + lhs, + config.out_shape(0), + config.out_shape(1), + x_offset, + y_offset, + config, + ) + } + + fn init_rhs_loader( + rhs: &Tensor>, + x_offset: u32, + y_offset: u32, + #[comptime] config: Self::Config, + ) -> Self::RhsLoader { + Self::RhsLoader::new::(rhs, x_offset, y_offset, 0, config) + } + + fn init_bias_loader( + bias: &Tensor>, + n_offset: u32, + #[comptime] config: Self::Config, + #[comptime] has_bias: bool, + ) -> Self::AccumulatorLoader { + Self::AccumulatorLoader::new(bias, n_offset, config.to_smm_config(), has_bias) + } + + fn init_unloader(out: &mut Tensor>, x_offset: u32, y_offset: u32) -> Self::Out { + Self::Out::new(out, x_offset, y_offset, 0) + } + + fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator { + SMM::init_accumulator(config.to_smm_config()) + } +} + +impl ConvolutionKernel for ImplicitGemmConvolution +where + EG: Numeric, + ES: Numeric, + Acc: Numeric, + SMM: stage::Matmul, +{ + type Config = config::Config>; + + fn check_config(config: Self::Config) { + SMM::check_config(config.to_smm_config()); + } + + fn check_availability( + client: &ComputeClient, + ) -> Result<(), MatmulAvailabilityError> { + SMM::check_availability::(client) + } + + fn make_config( + problem: &ConvolutionProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config { + let smm_config = SMM::make_config( + &problem.as_matmul_problem(), + cube_dim, + cube_count, + advanced_config, + ); + + config::Config::new( + homogeneous::Config::new( + smm_config, + problem.m as u32 % SMM::M != 0, + problem.n as u32 % SMM::N != 0, + problem.k as u32 % SMM::K != 0, + problem.lhs_layout, + problem.rhs_layout, + problem.lhs_line_size as u32, + problem.rhs_line_size as u32, + problem.out_line_size as u32, + ), + (problem.out_shape_y as u32, problem.out_shape_x as u32), + problem.kernel_size, + &problem.options, + problem.has_bias, + ) + } +} + +impl< + EG: Numeric, + ES: Numeric, + Acc: Numeric, + SMM: stage::Matmul, RhsReader = RhsReader>, + > ConvolutionLaunch for ImplicitGemmConvolution +{ + unsafe fn launch_unchecked( + client: &ComputeClient<::Server, ::Channel>, + cube_dim: CubeDim, + cube_count: CubeCount, + input: TensorArg<'_, R>, + weight: TensorArg<'_, R>, + bias: TensorArg<'_, R>, + out: 
TensorArg<'_, R>, + config: >::Config, + ) { + Self::check_config(config); + + implicit_conv::launch_unchecked::( + client, + cube_count, + cube_dim, + input, + weight, + bias, + out, + config, + config.has_bias, + ); + } +} + +#[cube(launch_unchecked)] +pub(crate) fn implicit_conv< + EG: Numeric, + ES: Numeric, + Acc: Numeric, + GMM: Convolution, + SMM: stage::Matmul, +>( + lhs: &Tensor>, + rhs: &Tensor>, + bias: &Tensor>, + out: &mut Tensor>, + #[comptime] config: GMM::Config, + #[comptime] has_bias: bool, +) { + let x_offset = CUBE_POS_X * config.stage_dim(Ident::Lhs).num_elements_x_dim(); + let y_offset = CUBE_POS_Y * config.stage_dim(Ident::Rhs).num_elements_y_dim(); + let k_range = (0, rhs.shape(0)); + + GMM::execute( + GMM::init_lhs_loader(lhs, x_offset, k_range.0, config), + GMM::init_rhs_loader(rhs, k_range.0, y_offset, config), + GMM::init_bias_loader(bias, y_offset, config, has_bias), + GMM::init_unloader(out, x_offset, y_offset), + &mut GMM::init_accumulator(config), + k_range, + config, + ); +} + +pub mod config { + use std::ops::Deref; + + use burn_tensor::ops::ConvOptions; + use cubecl::linalg::matmul::components::MatmulConfig; + + use crate::kernel::conv::conv2d::gemm::{self}; + + use super::*; + + #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] + pub struct Config { + matmul: M, + + out_shape: (u32, u32), + + kernel_size: (u32, u32), + stride: (u32, u32), + dilation: (u32, u32), + padding: (i32, i32), + + pub has_bias: bool, + } + + impl Deref for Config { + type Target = M; + + fn deref(&self) -> &Self::Target { + &self.matmul + } + } + + impl global::Config for Config { + type SmmConfig = M::SmmConfig; + + fn to_smm_config(&self) -> Self::SmmConfig { + self.matmul.to_smm_config() + } + + fn global_line_size(&self, ident: Ident) -> u32 { + self.matmul.global_line_size(ident) + } + + fn stage_line_size(&self, ident: Ident) -> u32 { + self.matmul.stage_line_size(ident) + } + + fn stage_dim(&self, ident: Ident) -> Box { + self.matmul.stage_dim(ident) + } + + fn layout(&self, ident: Ident) -> MatrixLayout { + self.matmul.layout(ident) + } + + fn num_planes(&self) -> u32 { + self.matmul.num_planes() + } + + fn plane_dim(&self) -> u32 { + self.matmul.plane_dim() + } + + fn tiling_order(&self, ident: Ident) -> TilingOrderConfig { + self.matmul.tiling_order(ident) + } + + fn check_m_bounds(&self) -> bool { + self.matmul.check_m_bounds() + } + + fn check_n_bounds(&self) -> bool { + self.matmul.check_n_bounds() + } + + fn check_k_bounds(&self) -> bool { + self.matmul.check_k_bounds() + } + + fn transpose_load(&self, ident: Ident) -> bool { + self.matmul.transpose_load(ident) + } + } + + impl gemm::Config for Config { + fn out_shape(&self, dim: u32) -> u32 { + match dim { + 0 => self.out_shape.0, + 1 => self.out_shape.1, + _ => unreachable!(), + } + } + + fn kernel_size(&self, dim: u32) -> u32 { + match dim { + 0 => self.kernel_size.0, + 1 => self.kernel_size.1, + _ => unreachable!(), + } + } + + fn dilation(&self, dim: u32) -> u32 { + match dim { + 0 => self.dilation.0, + 1 => self.dilation.1, + _ => unreachable!(), + } + } + + fn stride(&self, dim: u32) -> u32 { + match dim { + 0 => self.stride.0, + 1 => self.stride.1, + _ => unreachable!(), + } + } + + fn padding(&self, dim: u32) -> i32 { + match dim { + 0 => self.padding.0, + 1 => self.padding.1, + _ => unreachable!(), + } + } + } + + impl MatmulConfig for Config {} + + impl Config { + #[allow(clippy::too_many_arguments)] + pub fn new( + matmul: M, + out_shape: (u32, u32), + kernel_size: (u32, u32), + conv_args: 
&ConvOptions<2>, + has_bias: bool, + ) -> Self { + Self { + matmul, + out_shape, + kernel_size, + stride: (conv_args.stride[0] as u32, conv_args.stride[1] as u32), + dilation: (conv_args.dilation[0] as u32, conv_args.dilation[1] as u32), + padding: (conv_args.padding[0] as i32, conv_args.padding[1] as i32), + has_bias, + } + } + + pub fn to_matmul_config(self) -> M { + self.matmul + } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/mod.rs new file mode 100644 index 0000000000..6cf245d4db --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/mod.rs @@ -0,0 +1 @@ +pub mod base; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs new file mode 100644 index 0000000000..0ecf1880a6 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -0,0 +1,269 @@ +use burn_tensor::{ + ops::{conv::calculate_conv_output_size, ConvOptions}, + Shape, +}; +use cubecl::{ + ir::{Elem, FloatKind}, + linalg::matmul::{ + self, + components::{ + stage::{S4x2x4, S8x4x2}, + MatrixLayout, + }, + }, + tensor_line_size, tf32, Feature, +}; +use half::{bf16, f16}; + +use crate::{ + kernel::{ + conv::{ + conv2d::gemm::{ + algorithm::{Algorithm, Cmma}, + base::{ConvolutionLaunch, ConvolutionProblem}, + }, + nchw_to_nhwc, Conv2dAutotuneKey, + }, + into_contiguous, + }, + ops::{numeric::empty_device, permute, reshape}, + tensor::JitTensor, + FloatElement, JitRuntime, +}; + +/// Large m stage size for the usual case where `batch_size * out_h * out_w` is significantly larger +/// than `out_channels` +pub type CmmaLargeMAlgorithm = Cmma; +/// Balanced stage size for cases where `batch_size * out_h * out_w` is relatively small and `k` or +/// `out_channels` is relatively large +pub type CmmaBalancedAlgorithm = Cmma; + +macro_rules! select_launch_algo { + ($algo:tt, $float:ty, $input:expr) => { + match (<$float>::as_elem(), has_tf32(&$input)) { + (Elem::Float(FloatKind::F32), true) => { + conv2d_gemm_with_algo::> + } + (Elem::Float(FloatKind::F16), _) => { + conv2d_gemm_with_algo::> + } + (Elem::Float(FloatKind::BF16), _) => { + conv2d_gemm_with_algo::> + } + _ => conv2d_gemm_with_algo::>, + } + }; +} + +/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul +/// components. Uses [`CmmaLargeMAlgorithm`] for the stage size +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +pub fn conv2d_gemm_cmma_large_m( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let launch = select_launch_algo!(CmmaLargeMAlgorithm, F, input); + launch(input, weight, bias, options) +} + +/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul +/// components. 
Uses [`CmmaBalancedAlgorithm`] for the stage size +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// +/// +pub fn conv2d_gemm_cmma_balanced( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let launch = select_launch_algo!(CmmaBalancedAlgorithm, F, input); + launch(input, weight, bias, options) +} + +/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul +/// components, using the specified algorithm. +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// +/// +pub fn conv2d_gemm_with_algo>( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let [batch_size, in_channels, height, width] = input.shape.dims(); + let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); + + let out_h = calculate_conv_output_size( + kernel_h, + options.stride[0], + options.padding[0], + options.dilation[0], + height, + ); + let out_w = calculate_conv_output_size( + kernel_w, + options.stride[1], + options.padding[1], + options.dilation[1], + width, + ); + + let input = match input.is_contiguous() { + true => nchw_to_nhwc::(input), + false => into_contiguous(permute(input, &[0, 2, 3, 1])), + }; + let weight = into_contiguous(permute(weight, &[2, 3, 1, 0])); + + // Implicit GEMM matrix size + let gemm_m = batch_size * out_h * out_w; + let gemm_n = out_channels; + let gemm_k = kernel_h * kernel_w * in_channels; + + let weight = reshape(weight, Shape::new([gemm_k, gemm_n])); + + let out_shape = Shape::new([gemm_m, gemm_n]); + let out = empty_device::(input.client.clone(), input.device.clone(), out_shape); + + // Target 128 bit accesses + let available_vectorizations = R::supported_line_sizes() + .iter() + .copied() + .filter(|it| *it as usize * size_of::() <= 16) + .collect::>(); + let lhs_line_size = tensor_line_size( + &available_vectorizations, + &input.shape.dims, + &input.strides, + 3, + ); + let rhs_line_size = tensor_line_size( + &available_vectorizations, + &weight.shape.dims, + &weight.strides, + 1, + ); + let out_line_size = + tensor_line_size(&available_vectorizations, &out.shape.dims, &out.strides, 1); + + let problem = ConvolutionProblem { + m: gemm_m, + n: gemm_n, + k: gemm_k, + lhs_layout: matmul::components::MatrixLayout::RowMajor, + rhs_layout: matmul::components::MatrixLayout::RowMajor, + lhs_line_size, + rhs_line_size, + out_line_size, + + kernel_size: (kernel_h as u32, kernel_w as u32), + options, + out_shape_y: out_h, + out_shape_x: out_w, + + has_bias: bias.is_some(), + }; + + if !Alg::can_launch::(&input.client, &problem) { + panic!("Can't do implicit GEMM"); + } + + let cube_dim = Alg::cube_dim(); + let cube_count = Alg::cube_count(&problem); + + let advanced_config = Default::default(); + let config = Alg::make_config(&problem, &cube_dim, &cube_count, &advanced_config); + let bias = bias.unwrap_or_else(|| { + empty_device::(input.client.clone(), input.device.clone(), Shape::new([1])) + }); + + unsafe { + Alg::GlobalConvolution::launch_unchecked::( + &input.client, + cube_dim, + cube_count, + input.as_tensor_arg::(lhs_line_size), + weight.as_tensor_arg::(rhs_line_size), + bias.as_tensor_arg::(out_line_size), + 
out.as_tensor_arg::(out_line_size), + config, + ); + } + + // Reset to NCHW + let out = reshape(out, Shape::new([batch_size, out_h, out_w, out_channels])); + permute(out, &[0, 3, 1, 2]) +} + +pub fn problem_from_key( + key: &Conv2dAutotuneKey, + out_h: usize, + out_w: usize, +) -> ConvolutionProblem { + let in_stride_2 = key.in_channels; + let in_stride_1 = key.width * in_stride_2; + let in_stride_0 = key.height * in_stride_1; + + let m = key.batch_size * out_h * out_w; + let n = key.out_channels; + let k = key.kernel_size[0] * key.kernel_size[1] * key.in_channels; + + let options = ConvOptions { + stride: key.stride, + padding: key.padding, + dilation: key.dilation, + groups: key.groups, + }; + + // Target 128 bit accesses + let available_vectorizations = R::supported_line_sizes() + .iter() + .copied() + .filter(|it| *it as usize * size_of::() <= 16) + .collect::>(); + let lhs_line_size = tensor_line_size( + &available_vectorizations, + &[key.batch_size, key.height, key.width, key.in_channels], + &[in_stride_0, in_stride_1, in_stride_2, 1], + 3, + ); + let rhs_line_size = tensor_line_size(&available_vectorizations, &[k, n], &[n, 1], 1); + let out_line_size = tensor_line_size(&available_vectorizations, &[m, n], &[n, 1], 1); + + ConvolutionProblem { + m, + n, + k, + lhs_layout: MatrixLayout::RowMajor, + rhs_layout: MatrixLayout::RowMajor, + lhs_line_size, + rhs_line_size, + out_line_size, + kernel_size: (key.kernel_size[0] as u32, key.kernel_size[1] as u32), + options, + out_shape_y: out_h, + out_shape_x: out_w, + has_bias: key.has_bias, + } +} + +pub(crate) fn has_tf32(c: &JitTensor) -> bool { + c.client + .properties() + .feature_enabled(Feature::Type(Elem::Float(FloatKind::TF32))) +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs new file mode 100644 index 0000000000..bb4c5bb017 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs @@ -0,0 +1,116 @@ +use std::marker::PhantomData; + +use cubecl::{ + linalg::matmul::components::{ + global::AccumulatorLoader, + stage::{self, Stage}, + tile::{self, Config as _}, + Ident, + }, + prelude::*, +}; + +use crate::kernel::conv::reader::bias::BiasReader; + +/// Special loader to broadcast the 1D bias to the 2D accumulator matrix +#[derive(CubeType)] +pub struct BiasLoader { + pub tensor_view: BiasReader, + pub stage: Stage, + pub has_bias: bool, + _config: PhantomData, +} + +#[cube] +impl AccumulatorLoader + for BiasLoader +{ + fn fill_stage(this: &mut Self, #[comptime] config: G) { + if this.has_bias { + let stage_dim = config.stage_dim(Ident::Rhs); + let line_size = config.line_size(Ident::Out); + + let num_stage_elements = stage_dim.num_elements_y_dim(); + + let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X; + let unit_position_base = unit_id * line_size; + + let mut slice = this.stage.as_slice_mut(); + + if unit_position_base < num_stage_elements { + let read_line = this + .tensor_view + .load_simple::(unit_position_base, config); + slice[unit_id] = Line::cast_from(read_line); + } + } + } + + /// Load accumulator + fn load>( + this: &mut Self, + acc: &mut Tile::Accumulator, + tile_n: u32, + #[comptime] config: Tile::Config, + ) { + if this.has_bias { + let line_size = config.line_size(Ident::Out); + let tile_elems = Tile::N / line_size; + let start = tile_n * tile_elems; + let slice = this.stage.as_slice_mut().slice(start, start + tile_elems); + Tile::fill_accumulator(&slice, acc, 0, config); + } else { + 
Tile::zero_accumulator(acc, config); + } + } +} + +#[cube] +impl BiasLoader { + pub fn new( + tensor: &Tensor>, + n_offset: u32, + #[comptime] config: G, + #[comptime] has_bias: bool, + ) -> Self { + if has_bias { + let stage = { + let line_size = config.line_size(Ident::Out); + + let smem = SharedMemory::new_lined( + comptime!(config.stage_dim(Ident::Rhs).num_elements_y_dim() / line_size), + line_size, + ); + + Stage:: { smem } + }; + let tensor_view = BiasReader:: { + tensor, + n_offset, + shape_n: tensor.shape(0), + }; + + BiasLoader:: { + tensor_view, + stage, + has_bias, + _config: PhantomData::.runtime(), + } + } else { + let stage = Stage:: { + smem: SharedMemory::new(1), + }; + let tensor_view = BiasReader:: { + tensor, + n_offset: 0, + shape_n: 0, + }; + BiasLoader:: { + stage, + tensor_view, + has_bias, + _config: PhantomData::.runtime(), + } + } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs new file mode 100644 index 0000000000..0a1ed8728c --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs @@ -0,0 +1,147 @@ +use cubecl::{ + linalg::matmul::components::{ + global::Loader, + stage::{ + multi_buffer::LhsReader, ColMajorTiling, RowMajorTiling, Stage, TilingOrder as _, + TilingOrderConfig, + }, + Ident, + }, + prelude::*, +}; +use std::marker::PhantomData; + +use crate::kernel::conv::{reader::im2col::Im2colReader, Config}; + +/// Loader that translates matrix coordinates to input coordinates using the `im2col` algorithm +#[derive(CubeType)] +pub struct SimpleIm2colLoader { + pub tensor_view: Im2colReader, + pub stage: Stage, + _config: PhantomData, +} + +#[cube] +impl Loader for SimpleIm2colLoader { + type StageReader = LhsReader; + + fn fill_stage(this: &mut Self, #[comptime] config: G) -> Self::StageReader { + SimpleIm2col::load_to_slice::( + &this.tensor_view, + &mut this.stage.as_slice_mut(), + Ident::Lhs, + config, + ); + LhsReader::new(this.stage) + } + + fn advance_view(this: &mut Self, k_offset: u32) { + this.tensor_view.update_view(k_offset); + } +} + +#[cube] +impl SimpleIm2colLoader { + pub fn new( + tensor: &Tensor>, + shape_out_y: u32, + shape_out_x: u32, + x_offset: u32, + y_offset: u32, + #[comptime] config: G, + ) -> Self { + let stage = Stage::new::(Ident::Lhs, config.to_smm_config()); + let shape_batch = tensor.shape(0); + let shape_channel = tensor.shape(3); + + let shape_m = shape_batch * shape_out_y * shape_out_x; + let shape_k = shape_channel * config.kernel_size(0) * config.kernel_size(1); + + let tensor_view = Im2colReader:: { + tensor, + m_offset: x_offset, + k_offset: y_offset, + stride_batch: tensor.stride(0), + stride_y: tensor.stride(1), + stride_x: tensor.stride(2), + stride_channel: tensor.stride(3), + shape_y: tensor.shape(1), + shape_x: tensor.shape(2), + shape_channel, + shape_out_y, + shape_out_x, + + shape_m, + shape_k, + }; + + SimpleIm2colLoader:: { + tensor_view, + stage, + _config: PhantomData::.runtime(), + } + } +} + +#[derive(CubeType, Clone, Copy)] +/// Loads the content of all tiles in the tensor view using all planes, +/// iterating with steps determined by the plane's dimension. 
+pub struct SimpleIm2col; + +#[cube] +impl SimpleIm2col { + pub fn load_to_slice( + read_view: &Im2colReader, + slice: &mut SliceMut>, + #[comptime] ident: Ident, + #[comptime] config: G, + ) { + let stage_dim = config.stage_dim(ident); + let line_size = config.global_line_size(ident); + + let num_stage_elements = stage_dim.total_elements(); + let total_units = comptime!(config.num_planes() * config.plane_dim()); + let jump_length = comptime!(total_units * line_size); + let num_loads_per_unit = num_stage_elements / jump_length; + + #[allow(clippy::all)] + let _ = comptime!(check_jump_divides_well(num_stage_elements, jump_length)); + + let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X; + let unit_position_base = unit_id * line_size; + + for i in 0..num_loads_per_unit { + let unit_position = unit_position_base + i * jump_length; + + let tile_num_elements = stage_dim.tile_num_elements(); + let nth_tile = unit_position / tile_num_elements; + let pos_within_tile = unit_position % tile_num_elements; + + let (tile_x, tile_y) = match config.tiling_order(ident) { + TilingOrderConfig::RowMajor => RowMajorTiling::to_x_y( + nth_tile, + stage_dim.num_tiles_x_dim(), + stage_dim.num_tiles_y_dim(), + ), + TilingOrderConfig::ColMajor => ColMajorTiling::to_x_y( + nth_tile, + stage_dim.num_tiles_x_dim(), + stage_dim.num_tiles_y_dim(), + ), + }; + + let line_read = + read_view.load_simple::(tile_x, tile_y, pos_within_tile, ident, config); + + slice[unit_position / line_size] = Line::cast_from(line_read); + } + } +} + +pub fn check_jump_divides_well(num_stage_elements: u32, jump_length: u32) { + assert!( + num_stage_elements % jump_length == 0, + "Too many data will be loaded, resulting in out of bounds. + Try setting line size and number of planes so that jump_length divides num_stage_elements." + ); +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/mod.rs new file mode 100644 index 0000000000..13d3809513 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/mod.rs @@ -0,0 +1,2 @@ +pub mod bias; +pub mod im2col; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs new file mode 100644 index 0000000000..5fd4a309b9 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs @@ -0,0 +1,10 @@ +pub mod algorithm; +pub mod base; +mod config; +pub mod homogeneous; +pub mod launch; +pub mod loader; +pub mod reader; + +pub use config::*; +pub use launch::*; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs new file mode 100644 index 0000000000..67162a28a8 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs @@ -0,0 +1,38 @@ +use cubecl::{ + linalg::matmul::components::{stage, Ident}, + prelude::*, +}; + +#[derive(CubeType)] +/// A view of a tensor that starts reading data from a specified offset. +/// Ensures safe access by preventing out-of-bounds errors. +/// Includes pre-fetched shapes and strides for optimized performance. 
+pub struct BiasReader { + pub tensor: *const Tensor>, + pub n_offset: u32, + pub shape_n: u32, +} + +unsafe impl Sync for BiasReader {} +unsafe impl Send for BiasReader {} + +#[cube] +impl BiasReader { + /// Load the 1D bias into shared memory + pub fn load_simple(&self, unit_id: u32, #[comptime] config: G) -> Line { + let line_size = config.line_size(Ident::Out); + + let view_n = self.n_offset + unit_id; + let read_pos = view_n / line_size; + + select( + view_n < self.shape_n, + self.read(read_pos), + Line::empty(line_size).fill(E::from_int(0)), + ) + } + + fn read(&self, position: u32) -> Line { + unsafe { *(*self.tensor).index_unchecked(position) } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs new file mode 100644 index 0000000000..b278bb051b --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs @@ -0,0 +1,112 @@ +use cubecl::{linalg::matmul::components::Ident, prelude::*}; + +use crate::kernel::conv::Config; + +#[derive(CubeType)] +/// A view of a feature map tensor that starts reading data from a specified offset. +/// Ensures safe access by preventing out-of-bounds errors. +/// Includes pre-fetched shapes and strides for optimized performance. +pub struct Im2colReader { + pub tensor: *const Tensor>, + pub m_offset: u32, + pub k_offset: u32, + + pub stride_batch: u32, + pub stride_y: u32, + pub stride_x: u32, + pub stride_channel: u32, + + pub shape_y: u32, + pub shape_x: u32, + pub shape_channel: u32, + + pub shape_out_y: u32, + pub shape_out_x: u32, + + pub shape_m: u32, + pub shape_k: u32, +} + +unsafe impl Sync for Im2colReader {} +unsafe impl Send for Im2colReader {} + +#[cube] +impl Im2colReader { + /// Advance the view along the k dimension by a specified offset, `k_offset`. + pub fn update_view(&mut self, k_offset: u32) { + self.k_offset += k_offset; + } + + /// Reads data from the tensor view at the specified tile coordinates (tile_x, tile_y) using + /// the `im2col` algorithm to translate them to input coordinates. + /// + /// Each unit loads one line in a coalesced manner for improved efficiency. + /// For row-major tensors, subsequent units read lines horizontally within the tile, + /// while for column-major tensors, they read lines vertically. + /// + /// # Note + /// + /// Out-of-bounds reads will be translated to zeros. 
+ pub fn load_simple( + &self, + tile_x: u32, + tile_y: u32, + unit_id: u32, + #[comptime] ident: Ident, + #[comptime] config: G, + ) -> Line { + let line_size = config.global_line_size(ident); + let tile_size_x = config.stage_dim(ident).tile_size_x_dim(); + let tile_size_y = config.stage_dim(ident).tile_size_y_dim(); + + let view_tile_m = tile_x * tile_size_x + self.m_offset; + let view_tile_k = tile_y * tile_size_y + self.k_offset; + + let load_m = unit_id / tile_size_y; + let load_k = unit_id % tile_size_y; + + let view_m = view_tile_m + load_m; + let view_k = view_tile_k + load_k; + + let out_x = view_m % self.shape_out_x; + let rem = view_m / self.shape_out_x; + let out_y = rem % self.shape_out_y; + let batch = rem / self.shape_out_y; + + let kernel_w = config.kernel_size(1); + + let channel = view_k % self.shape_channel; + let rem = view_k / self.shape_channel; + let kernel_x = rem % kernel_w; + let kernel_y = rem / kernel_w; + + let y = + (out_y * config.stride(0) + kernel_y * config.dilation(0)) as i32 - config.padding(0); + let x = + (out_x * config.stride(1) + kernel_x * config.dilation(1)) as i32 - config.padding(1); + + let m_in_bounds = comptime!(!config.check_m_bounds()) || view_m < self.shape_m; + let k_in_bounds = comptime!(!config.check_k_bounds()) || view_k < self.shape_k; + let no_padding = comptime!(config.padding(0) == 0 && config.padding(1) == 0); + let hw_in_bounds = no_padding + || (y >= 0 && (y as u32) < self.shape_y && x >= 0 && (x as u32) < self.shape_x); + let in_bounds = m_in_bounds && k_in_bounds && hw_in_bounds; + let read_pos = batch * self.stride_batch + + y as u32 * self.stride_y + + x as u32 * self.stride_x + + channel * self.stride_channel; + + let read_pos = read_pos / line_size; + + let mut res = Line::empty(line_size).fill(F::from_int(0)); + if in_bounds { + res = self.read(read_pos); + } + + res + } + + fn read(&self, position: u32) -> Line { + unsafe { *(*self.tensor).index_unchecked(position) } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/mod.rs new file mode 100644 index 0000000000..13d3809513 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/mod.rs @@ -0,0 +1,2 @@ +pub mod bias; +pub mod im2col; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index abcb8488fb..3a7b23df76 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -6,7 +6,10 @@ use cubecl::{calculate_cube_count_elemwise, linalg::matmul, prelude::*}; use crate::{ kernel::{ - conv::index, into_contiguous, launch_binop, matmul::matmul, matmul::MatmulStrategy, AddOp, + conv::index, + into_contiguous, launch_binop, + matmul::{cube_strategy, matmul, MatmulStrategy}, + AddOp, }, ops::{numeric::empty_device, reshape, swap_dims}, tensor::JitTensor, @@ -300,9 +303,10 @@ fn execute( let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); matmul::launch_ref::( + &cube_strategy::(&client), &client, - weight.as_handle_ref(), - columns.as_handle_ref(), - out.as_handle_ref(), + &weight.as_handle_ref(), + &columns.as_handle_ref(), + &out.as_handle_ref(), ); } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 6771f2c5e2..a021f7c089 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ 
-318,14 +318,23 @@ fn implicit_gemm_kernel( let pos = calculate_positions(gemm_settings); + let in_vec = input.line_size(); + let weight_vec = weight.line_size(); + // Shared memory tiles, currently only holds enough data for // each warp to have its own tile for a single MMA op (8 * 16 * 16 elements) // conceptually a WARPS_PER_CUBE x (CMMA_M * CMMA_K) matrix - let mut smem_input_tile = SharedMemory::::new(cmma_input_tile_size * warps_per_cube); - let mut smem_weight_tile = SharedMemory::::new(cmma_filter_tile_size * warps_per_cube); + let mut smem_input_tile = SharedMemory::::new_lined( + comptime!(cmma_input_tile_size * warps_per_cube / in_vec), + in_vec, + ); + let mut smem_weight_tile = SharedMemory::::new_lined( + comptime!(cmma_filter_tile_size * warps_per_cube / weight_vec), + weight_vec, + ); - let input_tile_start = pos.cube_linear_warp_idx * cmma_input_tile_size; - let weight_tile_start = pos.cube_linear_warp_idx * cmma_filter_tile_size; + let input_tile_start = pos.cube_linear_warp_idx * (cmma_input_tile_size / in_vec); + let weight_tile_start = pos.cube_linear_warp_idx * (cmma_filter_tile_size / weight_vec); let mut input_tile = smem_input_tile.slice_mut(input_tile_start, input_tile_start + cmma_input_tile_size); let mut weight_tile = @@ -441,8 +450,8 @@ fn execute_gemm( weight: &Tensor>, bias: &Tensor, out: &mut SliceMut, - input_tile: &mut SliceMut, - weight_tile: &mut SliceMut, + input_tile: &mut SliceMut>, + weight_tile: &mut SliceMut>, dims: &Dimensions, pos: &Positions, args: &ConvArgs, @@ -484,7 +493,7 @@ fn execute_gemm( fn load_input_tile( input: &Tensor>, args: &ConvArgs, - tile: &mut SliceMut, + tile: &mut SliceMut>, dims: &Dimensions, pos: &Positions, k: u32, @@ -566,21 +575,18 @@ fn load_input_tile( + channel; let value = select( in_bounds && m_in_bounds && k_in_bounds, - FMat::cast_from(input[idx / vec]), - FMat::vectorized(0.0, vec), + Line::cast_from(input[idx / vec]), + Line::new(FMat::new(0.0)), ); - #[unroll] - for i in 0..vec { - tile[m + i] = value[i]; - } + tile[m / vec] = value; } } #[cube] fn load_weight_tile( weight: &Tensor>, - tile: &mut SliceMut, + tile: &mut SliceMut>, dims: &Dimensions, pos: &Positions, k: u32, @@ -629,13 +635,10 @@ fn load_weight_tile( let idx = k_idx + global_n; - let value = FMat::cast_from(weight[idx / vec]); - let value = select(k_in_bounds && n_in_bounds, value, FMat::new(0.0)); + let value = Line::cast_from(weight[idx / vec]); + let value = select(k_in_bounds && n_in_bounds, value, Line::new(FMat::new(0.0))); - #[unroll] - for i in 0..vec { - tile[n + i] = value[i]; - } + tile[n / vec] = value; } } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs index a998bea86d..62f0e56d8f 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs @@ -113,7 +113,7 @@ fn nchw_to_nhwc_kernel( let batch_offset = batch * input.stride(0); let warp_id = plane_broadcast(unit_pos / 32, 0); - let warp_id_x = warp_id / CUBE_DIM_Y; + let warp_id_x = warp_id % tiles_x; let tile_x = CUBE_POS_X * tiles_x + warp_id_x; let tile_y = ABSOLUTE_POS_Y; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/mod.rs index 13900acdc1..f48a490aa6 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/mod.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/mod.rs @@ -1,15 +1,18 @@ mod base; mod col2im; mod direct; +mod gemm; mod im2col; mod implicit_gemm; mod layout_swap; mod 
transpose_direct; + mod tune; pub use base::*; pub use col2im::*; pub use direct::*; +pub use gemm::*; pub use im2col::*; pub use implicit_gemm::*; pub use layout_swap::*; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index 4a8122a478..56fc73965e 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -3,15 +3,19 @@ use burn_tensor::{ ElementConversion, Shape, }; use cubecl::{ - tune, + ir::{Elem, FloatKind}, + tf32, tune, tune::{local_tuner, tune_with, LocalTuner}, }; +use half::{bf16, f16}; use crate::{ kernel::{ conv::{ - batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col, - conv2d_implicit_gemm, + algorithm::Algorithm, batches_per_run, can_do_implicit_gemm, conv2d_direct, + conv2d_gemm_cmma_balanced, conv2d_gemm_cmma_large_m, conv2d_im2col, + conv2d_implicit_gemm, has_tf32, problem_from_key, CmmaBalancedAlgorithm, + CmmaLargeMAlgorithm, }, prng::random_uniform, }, @@ -40,7 +44,13 @@ pub fn conv2d_autotune( } #[tune( - operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm), + operations( + conv2d_direct, + conv2d_im2col, + conv2d_implicit_gemm, + conv2d_gemm_cmma_large_m, + conv2d_gemm_cmma_balanced + ), create_key = create_key::, should_run = should_run )] @@ -72,6 +82,23 @@ pub fn conv2d_operations( tune_with!(input, weights, bias, options) } +macro_rules! check_algo { + ($algo:tt, $float:ty, $input:expr, $problem:expr) => { + match (<$float>::as_elem(), has_tf32(&$input)) { + (Elem::Float(FloatKind::F32), true) => { + $algo::<$float, tf32, f32>::can_launch::(&$input.client, &$problem) + } + (Elem::Float(FloatKind::F16), _) => { + $algo::<$float, f16, f16>::can_launch::(&$input.client, &$problem) + } + (Elem::Float(FloatKind::BF16), _) => { + $algo::<$float, bf16, f32>::can_launch::(&$input.client, &$problem) + } + _ => $algo::<$float, f16, f32>::can_launch::(&$input.client, &$problem), + } + }; +} + fn should_run( op: &Conv2dOperations, key: &JitAutotuneKey, @@ -97,6 +124,8 @@ fn should_run( key.width, ); + let conv_problem = problem_from_key::(key, out_h, out_w); + match index { // im2col 1 => batches_per_run(key.batch_size, out_h, out_w).is_some(), @@ -111,6 +140,10 @@ fn should_run( out_w, &op.input.client, ), + // GEMM large m + 3 => check_algo!(CmmaLargeMAlgorithm, F, op.input, conv_problem), + // GEMM balanced + 4 => check_algo!(CmmaBalancedAlgorithm, F, op.input, conv_problem), _ => true, } } diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index 907b5ef344..163e4796e4 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -211,29 +211,31 @@ fn compute_offset_and_mask_gradient( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elements_offset, cube_dim); - deform_col2img_coord_kernel::launch::( - &image.client, - cube_count, - cube_dim, - image.as_handle_ref().as_tensor_arg(1), - offset.as_handle_ref().as_tensor_arg(1), - mask.as_handle_ref().as_tensor_arg(1), - columns.as_handle_ref().as_tensor_arg(1), - grad_offset.as_handle_ref().as_tensor_arg(1), - grad_mask.as_handle_ref().as_tensor_arg(1), - DeformConv2dCol2ImgCoordArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - 
ScalarArg::new(options.dilation[1] as u32), - ScalarArg::new(E::from_elem(options.padding[0] as f32)), - ScalarArg::new(E::from_elem(options.padding[1] as f32)), - ScalarArg::new(options.offset_groups as u32), - ScalarArg::new(kernel_height as u32), - ScalarArg::new(kernel_width as u32), - ), - use_mask, - ); + unsafe { + deform_col2img_coord_kernel::launch_unchecked::( + &image.client, + cube_count, + cube_dim, + image.as_handle_ref().as_tensor_arg(1), + offset.as_handle_ref().as_tensor_arg(1), + mask.as_handle_ref().as_tensor_arg(1), + columns.as_handle_ref().as_tensor_arg(1), + grad_offset.as_handle_ref().as_tensor_arg(1), + grad_mask.as_handle_ref().as_tensor_arg(1), + DeformConv2dCol2ImgCoordArgsLaunch::new( + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(E::from_elem(options.padding[0] as f32)), + ScalarArg::new(E::from_elem(options.padding[1] as f32)), + ScalarArg::new(options.offset_groups as u32), + ScalarArg::new(kernel_height as u32), + ScalarArg::new(kernel_width as u32), + ), + use_mask, + ) + }; let mask_gradient = if use_mask { Some(grad_mask) } else { None }; (grad_offset, mask_gradient) @@ -253,7 +255,7 @@ struct DeformConv2dCol2ImgCoordArgs { } #[allow(clippy::collapsible_if)] -#[cube(launch)] +#[cube(launch_unchecked)] fn deform_col2img_coord_kernel( image: &Tensor, offset: &Tensor, @@ -267,6 +269,10 @@ fn deform_col2img_coord_kernel( // Position format: [batch, [offset_group, kernel_h, kernel_w, 2], out_h, out_w] // Alternatively : [batch, offset_channels, out_h, out_w] + if ABSOLUTE_POS >= grad_offset.len() { + return; + } + let offset_channels = offset.shape(1); let out_h = offset.shape(2); let out_w = offset.shape(3); diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 562197647f..60c7cbbd1c 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,7 +1,12 @@ use super::{init_matmul_output, matmul_simple}; use crate::{tensor::JitTensor, FloatElement, JitRuntime}; use burn_tensor::Shape; -use cubecl::prelude::*; +use cubecl::{ + ir::{Elem, FloatKind}, + linalg::matmul::Strategy, + prelude::*, + Feature, +}; #[cfg(feature = "autotune")] use super::matmul_autotune; @@ -42,16 +47,20 @@ pub fn matmul( match strategy { MatmulStrategy::Simple { grid_x, grid_y } => { let out = init_matmul_output::(&lhs, &rhs); + matmul_simple::(lhs, rhs, out, grid_x, grid_y) } MatmulStrategy::Cube => { let out = init_matmul_output::(&lhs, &rhs); + let client = &lhs.client; + cubecl::linalg::matmul::launch_ref::( + &cube_strategy::(client), client, - lhs.as_handle_ref(), - rhs.as_handle_ref(), - out.as_handle_ref(), + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), ); out } @@ -60,6 +69,26 @@ pub fn matmul( } } +pub(crate) fn cube_strategy( + client: &ComputeClient, +) -> Strategy { + // TODO: Replace with auto option once cubecl has one + let cmma_available = client.properties().feature_enabled(Feature::Cmma { + a: Elem::Float(FloatKind::F16), + b: Elem::Float(FloatKind::F16), + c: Elem::Float(FloatKind::F32), + m: 16, + k: 16, + n: 16, + }); + let plane_available = client.properties().feature_enabled(Feature::Plane); + match (cmma_available, plane_available) { + (true, _) => Strategy::Accelerated, + (false, true) => Strategy::PlaneMma, + _ => Strategy::Tiling2D(Default::default()), + } +} + pub(crate) fn 
simple_cube_count( lhs_shape: &Shape, rhs_shape: &Shape, diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 38df9f7fe1..e49ea3154c 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -5,7 +5,10 @@ use cubecl::tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTu use crate::{ element::FloatElement, - kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, + kernel::{ + matmul::{cube_strategy, utils::init_matmul_output}, + prng::random_like_uniform, + }, ops::numeric::empty_device, tensor::JitTensor, tune_key::JitAutotuneKey, @@ -87,10 +90,10 @@ pub fn matmul_autotune( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { - let client = lhs.client.clone(); - let output = init_matmul_output::(&lhs, &rhs); + let client = lhs.client.clone(); + static TUNER: LocalTuner = local_tuner!(); TUNER.execute( @@ -149,11 +152,13 @@ matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { matmul_tune_ops!( MatmulCube, |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { + let strategy = cube_strategy::(&lhs.client); cubecl::linalg::matmul::launch_ref::( + &strategy, &lhs.client, - lhs.as_handle_ref(), - rhs.as_handle_ref(), - out.as_handle_ref(), + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), ); } ); diff --git a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs index 4e783e74e9..4a32b5d641 100644 --- a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs @@ -28,7 +28,7 @@ pub fn reduce_dim_subcube_kernel< let should_unroll = elems_per_thread <= 8; - let warp_id = UNIT_POS / PLANE_DIM; + let warp_id = plane_broadcast(UNIT_POS / PLANE_DIM, 0); let mut shared_memory = RD::init_shared(subcube_size); diff --git a/crates/burn-jit/src/ops/transaction.rs b/crates/burn-jit/src/ops/transaction.rs index 7320186570..b67740bc97 100644 --- a/crates/burn-jit/src/ops/transaction.rs +++ b/crates/burn-jit/src/ops/transaction.rs @@ -60,7 +60,7 @@ where let client = client.unwrap(); async move { - let mut data = client + let mut data: Vec> = client .read_async(bindings) .await .into_iter() diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs index 9ff5f28247..54e50468fb 100644 --- a/crates/burn-jit/src/template/base.rs +++ b/crates/burn-jit/src/template/base.rs @@ -24,7 +24,7 @@ impl CubeTask for SourceKernel { let source = source_template.complete(); CompiledKernel { - entrypoint_name: "kernel".to_string(), + entrypoint_name: "main".to_string(), debug_name: Some(core::any::type_name::()), source, cube_dim: self.cube_dim, diff --git a/crates/burn-jit/src/tests/conv2d.rs b/crates/burn-jit/src/tests/conv2d.rs index f93adffe8f..061ab54e65 100644 --- a/crates/burn-jit/src/tests/conv2d.rs +++ b/crates/burn-jit/src/tests/conv2d.rs @@ -52,7 +52,78 @@ mod tests { output .into_data() - .assert_approx_eq(&output_ref.into_data(), 1); + .assert_approx_eq(&output_ref.into_data(), 2); + } + + /// Regression test for bias loader in new implicit GEMM + #[test] + fn conv2d_should_match_reference_backend_bias_regression() { + let test_device = Default::default(); + let input = + Tensor::::random([1, 1, 1, 1], Distribution::Default, &test_device); + let weight = + Tensor::::random([32, 1, 3, 3], Distribution::Default, &test_device); + let bias = Tensor::::random([32], Distribution::Default, &test_device); + let ref_device = 
Default::default(); + + let input_ref = Tensor::::from_data(input.to_data(), &ref_device); + let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); + let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); + + let options = burn_tensor::ops::ConvOptions::new([1, 1], [1, 1], [1, 1], 1); + + let output = + module::conv2d(input, weight, Some(bias), options.clone()).permute([0, 2, 3, 1]); + let output_ref = + module::conv2d(input_ref, weight_ref, Some(bias_ref), options).permute([0, 2, 3, 1]); + + output + .into_data() + .assert_approx_eq(&output_ref.into_data(), 2); + } + + #[test] + fn nchw_to_nhwc_should_match_into_contiguous() { + let test_device = Default::default(); + let input = + Tensor::::random([4, 72, 53, 56], Distribution::Default, &test_device); + + type Float = ::FloatElem; + + let output = nchw_to_nhwc::(input.clone().into_primitive().tensor()); + let output_ref = into_contiguous( + input + .clone() + .permute([0, 2, 3, 1]) + .into_primitive() + .tensor(), + ); + + into_data_sync::(output) + .assert_approx_eq(&into_data_sync::(output_ref), 4); + } + + /// Regression test for transpose kernel that was causing corruption with 17-64 in channels and + /// at least 17 hw + #[test] + fn nchw_to_nhwc_should_match_into_contiguous_regression() { + let test_device = Default::default(); + let input = + Tensor::::random([1, 18, 17, 1], Distribution::Default, &test_device); + + type Float = ::FloatElem; + + let output = nchw_to_nhwc::(input.clone().into_primitive().tensor()); + let output_ref = into_contiguous( + input + .clone() + .permute([0, 2, 3, 1]) + .into_primitive() + .tensor(), + ); + + into_data_sync::(output) + .assert_approx_eq(&into_data_sync::(output_ref), 4); } #[test] diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index a69faf73de..64e8037c91 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -463,7 +463,7 @@ mod tests { scale: B::float_from_data(TensorData::from([0.009_019_608]), &device), offset: Some(B::int_from_data(TensorData::from([72]), &device)), }; - let qtensor: NdArrayQTensor = B::quantize(tensor.into(), &scheme, qparams); + let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams); assert_eq!(qtensor.scheme(), &scheme); assert_eq!( diff --git a/crates/burn-tensor/src/tests/module/conv3d.rs b/crates/burn-tensor/src/tests/module/conv3d.rs index b7a12b374c..77c827d928 100644 --- a/crates/burn-tensor/src/tests/module/conv3d.rs +++ b/crates/burn-tensor/src/tests/module/conv3d.rs @@ -293,7 +293,8 @@ mod tests { ), ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data() + .assert_approx_eq_diff(&output.into_data(), 0.002); } } } diff --git a/crates/burn-train/src/metric/base.rs b/crates/burn-train/src/metric/base.rs index e0eafe649d..db58c15886 100644 --- a/crates/burn-train/src/metric/base.rs +++ b/crates/burn-train/src/metric/base.rs @@ -19,6 +19,7 @@ pub struct MetricMetadata { } impl MetricMetadata { + /// Fake metric metadata #[cfg(test)] pub fn fake() -> Self { Self { diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index d6287382d9..583bfc1ddf 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -52,12 +52,12 @@ candle = ["burn-core/candle"] cuda-jit = ["burn-core/cuda-jit"] hip-jit = ["burn-core/hip-jit"] ndarray = ["burn-core/ndarray"] +remote = ["burn-core/remote"] +router = ["burn-core/router"] +server = ["burn-core/server"] tch = ["burn-core/tch"] wgpu = ["burn-core/wgpu"] wgpu-spirv = ["burn-core/wgpu-spirv"] 
-remote = ["burn-core/remote"] -server = ["burn-core/server"] -router = ["burn-core/router"] # Network utils network = ["burn-core/network"] From 3dc4b43e92ec5683ca39c194acf274b131d44dfa Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Fri, 29 Nov 2024 15:08:46 -0500 Subject: [PATCH 4/4] Matmul + CubeCL Update (#2551) --- Cargo.lock | 26 ++-- Cargo.toml | 4 +- backend-comparison/Cargo.toml | 2 + backend-comparison/benches/matmul.rs | 24 +-- backend-comparison/src/lib.rs | 27 ++++ crates/burn-jit/Cargo.toml | 18 +-- .../burn-jit/src/kernel/conv/conv2d/col2im.rs | 2 +- .../conv/conv2d/gemm/homogeneous/base.rs | 14 +- .../kernel/conv/conv2d/gemm/loader/im2col.rs | 7 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 15 +- .../burn-jit/src/kernel/conv/deform_conv2d.rs | 2 +- .../kernel/conv/deform_conv_transpose2d.rs | 4 +- crates/burn-jit/src/kernel/matmul/base.rs | 69 +-------- crates/burn-jit/src/kernel/matmul/mod.rs | 2 - crates/burn-jit/src/kernel/matmul/simple.rs | 143 ------------------ .../burn-jit/src/kernel/matmul/tune/base.rs | 63 +++++--- crates/burn-jit/src/ops/float_ops.rs | 2 +- crates/burn-wgpu/Cargo.toml | 4 +- 18 files changed, 137 insertions(+), 291 deletions(-) delete mode 100644 crates/burn-jit/src/kernel/matmul/simple.rs diff --git a/Cargo.lock b/Cargo.lock index 8897110b75..e7f9b7ea4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -424,6 +424,7 @@ dependencies = [ "github-device-flow", "half", "indicatif", + "log", "os_info", "percent-encoding", "rand", @@ -435,6 +436,7 @@ dependencies = [ "strum", "strum_macros", "sysinfo 0.32.1", + "tracing-subscriber", "wgpu", "wsl", ] @@ -1666,7 +1668,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1698,7 +1700,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1715,7 +1717,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1733,7 +1735,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1747,7 +1749,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1763,7 +1765,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1789,7 +1791,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "bytemuck", "cubecl-core", @@ -1800,7 +1802,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1815,7 +1817,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1852,7 +1854,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "async-channel", "async-lock", @@ -1873,7 +1875,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1887,7 +1889,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" +source = "git+https://github.com/tracel-ai/cubecl?rev=a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6#a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index ca45f967bb..1a22a867a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. 
### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 1d1a62cf49..39ddd0c6f5 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -54,6 +54,8 @@ strum_macros = { workspace = true } sysinfo = { workspace = true, features = ["serde"] } wgpu = { workspace = true } wsl = { workspace = true } +tracing-subscriber = { workspace = true } +log = { workspace = true } [dev-dependencies] rstest = { workspace = true } diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index d31e7cc954..0e9f3622b1 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -44,16 +44,20 @@ fn bench( url: Option<&str>, token: Option<&str>, ) { - let benchmarks = [(2, 4096, 4096, 4096), (8, 2048, 2048, 2048)] - .into_iter() - .map(|(b, m, n, k)| { - let shape_lhs = [b, m, k].into(); - let shape_rhs = [b, k, n].into(); - - MatmulBenchmark::::new(shape_lhs, shape_rhs, device.clone()) - }) - .map(run_benchmark) - .collect(); + let benchmarks = [ + (3, 4096, 4096, 4096), + (8, 2048, 2048, 2048), + (2, 4096, 4096, 512), + ] + .into_iter() + .map(|(b, m, n, k)| { + let shape_lhs = [b, m, k].into(); + let shape_rhs = [b, k, n].into(); + + MatmulBenchmark::::new(shape_lhs, shape_rhs, device.clone()) + }) + .map(run_benchmark) + .collect(); save::(benchmarks, device, feature_name, url, token).unwrap(); } diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index a1eb8f16a1..03e2d70444 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -1,3 +1,7 @@ +use std::error::Error; + +use tracing_subscriber::filter::LevelFilter; + pub mod burnbenchapp; pub mod persistence; @@ -26,10 +30,33 @@ pub fn get_sharing_url(args: &[String]) -> Option<&str> { get_argument(args, "--sharing-url") } +pub fn init_log() -> Result<(), Box> { + let result = tracing_subscriber::fmt() + .with_max_level(LevelFilter::DEBUG) + .without_time() + .try_init(); + + if result.is_ok() { + update_panic_hook(); + } + result +} + +fn update_panic_hook() { + let hook = std::panic::take_hook(); + + std::panic::set_hook(Box::new(move |info| { + log::error!("PANIC => {}", info.to_string()); + hook(info); + })); +} + #[macro_export] macro_rules! 
bench_on_backend { () => { use std::env; + backend_comparison::init_log().unwrap(); + let args: Vec = env::args().collect(); let url = backend_comparison::get_sharing_url(&args); let token = backend_comparison::get_sharing_token(&args); diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index ce5b22b8ac..0811374fd1 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -16,13 +16,13 @@ autotune = [] default = ["autotune", "std", "fusion", "cubecl/default"] doc = ["default"] export_tests = [ - "burn-tensor-testgen", - "serial_test", - "burn-autodiff/export_tests", - "burn-tensor/export_tests", - "burn-ndarray", - "fusion", - "paste", + "burn-tensor-testgen", + "serial_test", + "burn-autodiff/export_tests", + "burn-tensor/export_tests", + "burn-ndarray", + "fusion", + "paste", ] fusion = ["burn-fusion"] std = ["cubecl/std"] @@ -32,8 +32,8 @@ template = [] burn-common = { path = "../burn-common", version = "0.16.0" } burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ - "cubecl", - "repr", + "cubecl", + "repr", ] } cubecl = { workspace = true, features = ["linalg"] } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 0659561805..0d9c48dc30 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -145,7 +145,7 @@ fn execute( let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]); let input = reshape(input, input_shape); - let columns = matmul::(weight, input, MatmulStrategy::default()); + let columns = matmul::(weight, input, None, MatmulStrategy::default()); let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); col2im::( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs index ca7399e72d..582b1e59af 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs @@ -3,7 +3,7 @@ use cubecl::{ components::{ global::{ self, - homogeneous::{self, CyclicLoading, RhsLoader}, + full_load::{self, CyclicLoading, RhsLoader}, unloader::Unloader, AccumulatorLoader, Config as _, Loader, }, @@ -93,9 +93,11 @@ where for _ in 0..num_loops { sync_units(); - let lhs_stage_reader = &Self::LhsLoader::fill_stage(&mut lhs_loader, config); - let rhs_stage_reader = - &Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config()); + Self::LhsLoader::fill_stage(&mut lhs_loader, config); + Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config()); + + let lhs_stage_reader = &Self::LhsLoader::as_stage_reader(&lhs_loader); + let rhs_stage_reader = &Self::RhsLoader::as_stage_reader(&rhs_loader); sync_units(); @@ -172,7 +174,7 @@ where Acc: Numeric, SMM: stage::Matmul, { - type Config = config::Config>; + type Config = config::Config>; fn check_config(config: Self::Config) { SMM::check_config(config.to_smm_config()); @@ -198,7 +200,7 @@ where ); config::Config::new( - homogeneous::Config::new( + full_load::Config::new( smm_config, problem.m as u32 % SMM::M != 0, problem.n as u32 % SMM::N != 0, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs index 0a1ed8728c..11ee03d83e 100644 --- 
a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs @@ -25,19 +25,22 @@ pub struct SimpleIm2colLoader { impl Loader for SimpleIm2colLoader { type StageReader = LhsReader; - fn fill_stage(this: &mut Self, #[comptime] config: G) -> Self::StageReader { + fn fill_stage(this: &mut Self, #[comptime] config: G) { SimpleIm2col::load_to_slice::( &this.tensor_view, &mut this.stage.as_slice_mut(), Ident::Lhs, config, ); - LhsReader::new(this.stage) } fn advance_view(this: &mut Self, k_offset: u32) { this.tensor_view.update_view(k_offset); } + + fn as_stage_reader(this: &Self) -> Self::StageReader { + LhsReader::new(this.stage) + } } #[cube] diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 3a7b23df76..a65c29466c 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -2,13 +2,13 @@ use burn_tensor::{ ops::{conv::calculate_conv_output_size, ConvOptions}, Shape, }; -use cubecl::{calculate_cube_count_elemwise, linalg::matmul, prelude::*}; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ kernel::{ conv::index, into_contiguous, launch_binop, - matmul::{cube_strategy, matmul, MatmulStrategy}, + matmul::{matmul, MatmulStrategy}, AddOp, }, ops::{numeric::empty_device, reshape, swap_dims}, @@ -271,7 +271,7 @@ fn execute_1x1_kernel( let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp])); let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]); let input = reshape(input, in_shape); - let out = matmul::(weight, input, MatmulStrategy::default()); + let out = matmul::(weight, input, None, MatmulStrategy::default()); let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width])); if let Some(bias) = bias { @@ -290,7 +290,6 @@ fn execute( out_h: usize, out_w: usize, ) { - let client = input.client.clone(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -302,11 +301,5 @@ fn execute( let columns = reshape(columns, Shape::new([groups, col_shape_0, col_shape_1])); let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); - matmul::launch_ref::( - &cube_strategy::(&client), - &client, - &weight.as_handle_ref(), - &columns.as_handle_ref(), - &out.as_handle_ref(), - ); + matmul::(weight, columns, Some(out), Default::default()); } diff --git a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs index 438850fe72..b22821aef1 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs @@ -298,7 +298,7 @@ pub(crate) fn deform_conv2d( let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0])); let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); - let out = matmul::(weight, columns, MatmulStrategy::default()); + let out = matmul::(weight, columns, None, MatmulStrategy::default()); let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); let out = swap_dims(out, 0, 1); diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index 163e4796e4..b75ac43182 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -108,7 +108,7 @@ fn 
compute_weight_grad( let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); let columns = swap_dims(columns, 1, 2); - let grad_weight = matmul::(out_grad, columns, MatmulStrategy::default()); + let grad_weight = matmul::(out_grad, columns, None, MatmulStrategy::default()); reshape( grad_weight, @@ -150,7 +150,7 @@ fn backward_gradient_inputs( for group in 0..groups { let weight = swap_dims(index::(weight.clone(), group), 0, 1); let out_grad = index::(out_grad.clone(), group); - let values = matmul::(weight, out_grad, MatmulStrategy::default()); + let values = matmul::(weight, out_grad, None, MatmulStrategy::default()); let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1])); columns = slice_assign::( columns, diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 60c7cbbd1c..e0b87e8931 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,25 +1,11 @@ -use super::{init_matmul_output, matmul_simple}; +use super::init_matmul_output; use crate::{tensor::JitTensor, FloatElement, JitRuntime}; -use burn_tensor::Shape; -use cubecl::{ - ir::{Elem, FloatKind}, - linalg::matmul::Strategy, - prelude::*, - Feature, -}; #[cfg(feature = "autotune")] use super::matmul_autotune; /// The strategy to be used when launching a matmul kernel. pub enum MatmulStrategy { - /// A simple kernel will be used with memory coalescing optimization. - Simple { - /// Number of invocations in x - grid_x: usize, - /// Number of invocations in y - grid_y: usize, - }, #[cfg(feature = "autotune")] /// Using autotune to choose the best kernel based on runtime information. Autotune, @@ -42,21 +28,17 @@ impl Default for MatmulStrategy { pub fn matmul( lhs: JitTensor, rhs: JitTensor, + out: Option>, strategy: MatmulStrategy, ) -> JitTensor { match strategy { - MatmulStrategy::Simple { grid_x, grid_y } => { - let out = init_matmul_output::(&lhs, &rhs); - - matmul_simple::(lhs, rhs, out, grid_x, grid_y) - } MatmulStrategy::Cube => { - let out = init_matmul_output::(&lhs, &rhs); + let out = out.unwrap_or_else(|| init_matmul_output::(&lhs, &rhs)); let client = &lhs.client; cubecl::linalg::matmul::launch_ref::( - &cube_strategy::(client), + &Default::default(), client, &lhs.as_handle_ref(), &rhs.as_handle_ref(), @@ -65,47 +47,6 @@ pub fn matmul( out } #[cfg(feature = "autotune")] - MatmulStrategy::Autotune => matmul_autotune::(lhs, rhs), - } -} - -pub(crate) fn cube_strategy( - client: &ComputeClient, -) -> Strategy { - // TODO: Replace with auto option once cubecl has one - let cmma_available = client.properties().feature_enabled(Feature::Cmma { - a: Elem::Float(FloatKind::F16), - b: Elem::Float(FloatKind::F16), - c: Elem::Float(FloatKind::F32), - m: 16, - k: 16, - n: 16, - }); - let plane_available = client.properties().feature_enabled(Feature::Plane); - match (cmma_available, plane_available) { - (true, _) => Strategy::Accelerated, - (false, true) => Strategy::PlaneMma, - _ => Strategy::Tiling2D(Default::default()), + MatmulStrategy::Autotune => matmul_autotune::(lhs, rhs, out), } } - -pub(crate) fn simple_cube_count( - lhs_shape: &Shape, - rhs_shape: &Shape, - output_shape: &Shape, - cube_dim_x: usize, - cube_dim_y: usize, -) -> CubeCount { - let ndims = lhs_shape.num_dims(); - let num_rows = lhs_shape.dims[ndims - 2]; - let num_cols = rhs_shape.dims[ndims - 1]; - - let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32; - let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as 
f32) as u32; - let mut num_iter = 1; - for i in 0..ndims - 2 { - num_iter *= output_shape.dims[i]; - } - - CubeCount::Static(cubes_x, cubes_y, num_iter as u32) -} diff --git a/crates/burn-jit/src/kernel/matmul/mod.rs b/crates/burn-jit/src/kernel/matmul/mod.rs index 633743564b..80fa8ed82c 100644 --- a/crates/burn-jit/src/kernel/matmul/mod.rs +++ b/crates/burn-jit/src/kernel/matmul/mod.rs @@ -1,11 +1,9 @@ mod base; -mod simple; mod tune; /// Contains utilitary for matmul operation pub mod utils; pub use base::*; -pub use simple::*; pub use tune::*; pub use utils::*; diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs deleted file mode 100644 index 7d75b30395..0000000000 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ /dev/null @@ -1,143 +0,0 @@ -//! Naive matmul kernel implementation -//! -//! Each local unit will compute a single element of the output matrix. -use crate::{ - kernel::{into_contiguous, PLANE_DIM_APPROX}, - ops::swap_dims, - tensor::JitTensor, - FloatElement, JitRuntime, -}; - -use super::simple_cube_count; -use cubecl::prelude::*; - -#[cube(launch_unchecked)] -fn matmul_kernel( - lhs: &Tensor, - rhs: &Tensor, - out: &mut Tensor, - // number of dimensions not involved in the matmul - #[comptime] num_batches: Option, -) { - let rank = out.rank(); - let end = num_batches.unwrap_or_else(|| rank - 2); - let unroll = num_batches.is_some(); - - let n_rows = lhs.shape(rank - 2); - let n_cols = rhs.shape(rank - 1); - let mut k = rhs.shape(rank - 2); - - let batch_pos = ABSOLUTE_POS_Z; - let row = CUBE_DIM_X * CUBE_POS_X + UNIT_POS_X; - let col = CUBE_DIM_Y * CUBE_POS_Y + UNIT_POS_Y; - - if row >= n_rows || col >= n_cols { - return; - } - - let vectorization_factor = vectorization_of(lhs); - - let mut offset_lhs = 0; - let mut offset_rhs = 0; - let offset_out = n_rows * n_cols * batch_pos; - - #[unroll(unroll)] - for i in 0..end { - let ogwl = offset_out / out.stride(i); - - offset_lhs += ogwl % lhs.shape(i) * lhs.stride(i); - offset_rhs += ogwl % rhs.shape(i) * rhs.stride(i); - } - - offset_lhs /= vectorization_factor; - offset_rhs /= vectorization_factor; - - let mut sum = F::vectorized(0., vectorization_factor); - - k /= vectorization_factor; - - for i in 0..k { - let lhs_index = row * k + i + offset_lhs; - let rhs_index = col * k + i + offset_rhs; - - sum += lhs[lhs_index] * rhs[rhs_index]; - } - - let mut out_index = row * n_cols + col; - out_index += offset_out; - - let unroll_sum = vectorization_factor != 1; - if unroll_sum { - let mut accum = F::new(0.); - // we unroll the loop to sum `vectorization_factor` elements at once, which lets us - // use SIMD instructions to speed up the computation - #[unroll] - for v in 0..vectorization_factor { - accum += sum[v]; - } - - out[out_index] = accum; - } else { - out[out_index] = sum; - } -} - -/// Matrix multiplication using memory coalescing algorithm with cube dimensions of size 16 -pub fn matmul_mem_coalescing_default( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, -) -> JitTensor { - matmul_simple::(lhs, rhs, out, PLANE_DIM_APPROX, PLANE_DIM_APPROX) -} - -/// Matrix multiplication using memory coalescing algorithm with custom cube dimensions -pub fn matmul_simple( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, - cube_dim_x: usize, - cube_dim_y: usize, -) -> JitTensor { - lhs.assert_is_on_same_device(&rhs); - let ndims = lhs.shape.num_dims(); - let lhs = into_contiguous(lhs); - - let rhs_original_shape = rhs.shape.clone(); - // we swap the dimensions to achieve 
memory-coalescing: - // consecutive elements of a column in the original rhs tensor will now be stored - // consecutively in memory, which allows to fetch them with fewer memory instructions - let rhs = into_contiguous(swap_dims(rhs, ndims - 1, ndims - 2)); - - let cube_count = simple_cube_count( - &lhs.shape, - &rhs_original_shape, - &out.shape, - cube_dim_x, - cube_dim_y, - ); - - let vectorization_factor = match lhs.shape.dims[ndims - 1] % 4 == 0 { - true => 4, - false => 1, - }; - - unsafe { - matmul_kernel::launch_unchecked::( - &lhs.client, - cube_count, - CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1), - lhs.as_tensor_arg::(vectorization_factor), - TensorArg::from_raw_parts::( - &rhs.handle, - &rhs.strides, - &rhs_original_shape.dims, // We need the original shape. - vectorization_factor, - ), - out.as_tensor_arg::(1), - Some(ndims as u32 - 2), - ); - }; - - out -} diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index e49ea3154c..2b8050dc07 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -1,14 +1,14 @@ use core::marker::PhantomData; use burn_tensor::{Element, ElementConversion}; -use cubecl::tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTuner}; +use cubecl::{ + linalg::matmul::{kernels::tiling2d::Tiling2dConfig, Strategy}, + tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTuner}, +}; use crate::{ element::FloatElement, - kernel::{ - matmul::{cube_strategy, utils::init_matmul_output}, - prng::random_like_uniform, - }, + kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, ops::numeric::empty_device, tensor::JitTensor, tune_key::JitAutotuneKey, @@ -57,17 +57,17 @@ impl AutotuneOperationSet ); vec![ - Box::new(SimpleMatmul::::new( + Box::new(MatmulTiling2d::::new( lhs.clone(), rhs.clone(), out.clone(), )), - Box::new(SimpleMatmul16x16::::new( + Box::new(MatmulAccelerated::::new( lhs.clone(), rhs.clone(), out.clone(), )), - Box::new(MatmulCube::::new( + Box::new(MatmulSimple::::new( lhs.clone(), rhs.clone(), out.clone(), @@ -77,9 +77,9 @@ impl AutotuneOperationSet fn fastest(self: Box, fastest_index: usize) -> Box { match fastest_index { - 0 => Box::new(SimpleMatmul::::new(self.lhs, self.rhs, self.out)), - 1 => Box::new(SimpleMatmul16x16::::new(self.lhs, self.rhs, self.out)), - 2 => Box::new(MatmulCube::::new(self.lhs, self.rhs, self.out)), + 0 => Box::new(MatmulTiling2d::::new(self.lhs, self.rhs, self.out)), + 1 => Box::new(MatmulAccelerated::::new(self.lhs, self.rhs, self.out)), + 2 => Box::new(MatmulSimple::::new(self.lhs, self.rhs, self.out)), _ => panic!("Fastest index is out of bound"), } } @@ -89,8 +89,9 @@ impl AutotuneOperationSet pub fn matmul_autotune( lhs: JitTensor, rhs: JitTensor, + out: Option>, ) -> JitTensor { - let output = init_matmul_output::(&lhs, &rhs); + let output = out.unwrap_or_else(|| init_matmul_output::(&lhs, &rhs)); let client = lhs.client.clone(); @@ -137,24 +138,40 @@ macro_rules! matmul_tune_ops { }; } -// Potentially better for small matrices. +// Probably the fastest in the general case. matmul_tune_ops!( - SimpleMatmul, - crate::kernel::matmul::matmul_mem_coalescing_default:: + MatmulAccelerated, + |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { + cubecl::linalg::matmul::launch_ref::( + &Strategy::Accelerated, + &lhs.client, + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), + ); + } ); -// Potentially better for small matrices. 
-matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { - crate::kernel::matmul::matmul_simple::(lhs, rhs, out, 16, 16) -}); +// Probably the fastest when tensor cores are not available. +matmul_tune_ops!( + MatmulTiling2d, + |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { + cubecl::linalg::matmul::launch_ref::( + &Strategy::Tiling2D(Tiling2dConfig::default()), + &lhs.client, + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), + ); + } +); -// Probably the fastest in the general case, without loop unrolling +// Probably the fastest for small matrices. matmul_tune_ops!( - MatmulCube, + MatmulSimple, |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { - let strategy = cube_strategy::(&lhs.client); cubecl::linalg::matmul::launch_ref::( - &strategy, + &Strategy::Simple, &lhs.client, &lhs.as_handle_ref(), &rhs.as_handle_ref(), diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index f97b1609ff..2dc8a4a6f2 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -162,7 +162,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - matmul::(lhs, rhs, MatmulStrategy::default()) + matmul::(lhs, rhs, None, MatmulStrategy::default()) ) } diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index ea629af3f5..d3975faad3 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -27,13 +27,13 @@ cubecl = { workspace = true, features = ["wgpu"] } burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ - "cubecl-wgpu", + "cubecl-wgpu", ] } [dev-dependencies] burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ - "export_tests", + "export_tests", ] } half = { workspace = true } paste = { workspace = true }
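For readers following the matmul changes above: the `cube_strategy` helper (introduced earlier in this series and removed again by the final "Matmul + CubeCL Update" patch) picks a cubecl matmul strategy from two hardware capabilities queried on the compute client. The sketch below is a minimal, standalone restatement of that selection rule, not burn-jit or cubecl code: the `MatmulStrategy` enum and the two boolean flags are simplified stand-ins for cubecl's `Strategy` type and its `Feature::Cmma` / `Feature::Plane` queries.

// Standalone sketch, not part of the patch: the selection rule that
// `cube_strategy` encoded, reduced to plain Rust so the dispatch order
// is easy to see. Names here are illustrative stand-ins only.
#[derive(Debug)]
enum MatmulStrategy {
    Accelerated, // tensor-core (CMMA) path
    PlaneMma,    // plane / subgroup MMA path
    Tiling2d,    // portable fallback
}

fn pick_strategy(cmma_available: bool, plane_available: bool) -> MatmulStrategy {
    // CMMA support wins outright; plane ops are the next best;
    // otherwise fall back to the tiled kernel that runs everywhere.
    match (cmma_available, plane_available) {
        (true, _) => MatmulStrategy::Accelerated,
        (false, true) => MatmulStrategy::PlaneMma,
        (false, false) => MatmulStrategy::Tiling2d,
    }
}

fn main() {
    println!("{:?}", pick_strategy(true, false));  // Accelerated
    println!("{:?}", pick_strategy(false, true));  // PlaneMma
    println!("{:?}", pick_strategy(false, false)); // Tiling2d
}

The final patch in the series drops this helper entirely: `matmul` now passes `&Default::default()` to `cubecl::linalg::matmul::launch_ref` and relies on the autotune set (Accelerated, Tiling2D, Simple) to pick the fastest kernel at runtime.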