Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Ops.batch #535

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
docs: setup batching tutorial
  • Loading branch information
avik-pal committed Jan 22, 2025
commit dd222194957a57e5c2fa5aab823b5f6cb1609827
12 changes: 10 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -24,8 +24,11 @@ examples = [
pages = [
"Reactant.jl" => "index.md",
"Introduction" => ["Getting Started" => "introduction/index.md"],
"Tutorials" =>
["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"],
"Tutorials" => [
"Overview" => "tutorials/index.md",
"Profiling" => "tutorials/profiling.md",
"Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md",
],
"API Reference" => [
"Reactant API" => "api/api.md",
"Ops" => "api/ops.md",
@@ -38,6 +41,11 @@ pages = [
"Func" => "api/func.md",
"StableHLO" => "api/stablehlo.md",
"VHLO" => "api/vhlo.md",
"GPU" => "api/gpu.md",
"LLVM" => "api/llvm.md",
"NVVM" => "api/nvvm.md",
"TPU" => "api/tpu.md",
"Triton" => "api/triton.md",
],
"MLIR API" => "api/mlirc.md",
"XLA" => "api/xla.md",
10 changes: 9 additions & 1 deletion docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
@@ -56,8 +56,12 @@ export default defineConfig({
{
text: "Tutorials",
items: [
{text: "Overview", link: "/tutorials/"},
{ text: "Overview", link: "/tutorials/" },
{text: "Profiling", link: "/tutorials/profiling"},
{
text: "Batching Functions with `Reactant.Ops.batch`",
link: "/tutorials/batching"
},
],
},
{
@@ -112,6 +116,10 @@ export default defineConfig({
items: [
{ text: "Overview", link: "/tutorials/" },
{ text: "Profiling", link: "/tutorials/profiling" },
{
text: "Batching Functions with `Reactant.Ops.batch`",
link: "/tutorials/batching",
},
],
},
"/api/": {
3 changes: 3 additions & 0 deletions docs/src/tutorials/batching.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial)


1 change: 1 addition & 0 deletions docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Tutorials

- [Profiling](@ref profiling).
- [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial)

We are currently working on adding more tutorials to Reactant!! Please check back soon!
4 changes: 2 additions & 2 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
@@ -795,7 +795,7 @@ function codegen_unflatten!(
paths = (
(
p for p in Reactant.TracedUtils.get_paths(result) if
length(p) ≥ 1 && (p[1] == :result || p[1] == :resargs)
length(p) > 0 && (p[1] == :result || p[1] == :resargs)
)...,
)
for path in paths
@@ -865,7 +865,7 @@ function codegen_unflatten!(
paths = (
(
p for p in Reactant.TracedUtils.get_paths(result) if
length(p) ≥ 1 && (p[1] == :result || p[1] == :resargs || p[1] == :args)
length(p) > 0 && (p[1] == :result || p[1] == :resargs || p[1] == :args)
)...,
)

22 changes: 20 additions & 2 deletions src/Ops.jl
Original file line number Diff line number Diff line change
@@ -2013,8 +2013,24 @@ end
# This function assumes that the last dimension of each element is the batch dimension by
# default. This is the standard Julia ordering for batching. We permutedims the ordering to
# make sure the first dimension is the batch dimension when calling `batch_internal` below.
# XXX: Mutation inside a batched function is not supported yet (need to set the results
# correctly)
"""
batch(f, args...; batch_dims=nothing, result_dims=nothing)

Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users
familiar with `jax`, this operation corresponds to `jax.vmap`.)

If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension.

To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`.

## Examples

For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial.

!!! danger

Mutation inside a batched function is not supported yet and will lead to unexpected results.
"""
@noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing)
batch_sizes = Int64[]
batching_dims = if batch_dims === nothing
@@ -2060,6 +2076,8 @@ end
end

return fmap(results, result_dims) do result, dim
@assert dim !== nothing "Result batch dimension cannot be `nothing`"

order = collect(Int64, 1:ndims(result))
order[dim] = 1
order[1] = dim
2 changes: 0 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
@@ -9,8 +9,6 @@ using Functors: @leaf
using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

using Functors: @leaf

export @allowscalar # re-exported from GPUArraysCore

# auxiliary types and functions
2 changes: 2 additions & 0 deletions test/batching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
using Reactant, Test

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
@safetestset "Control Flow" include("control_flow.jl")
@safetestset "Sorting" include("sorting.jl")
@safetestset "Batching" include("batching.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"