Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JLL related fixups #706

Merged
merged 5 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.5"
Reactant_jll = "0.0.62"
Reactant_jll = "0.0.64"
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
Expand Down
19 changes: 11 additions & 8 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -852,14 +852,7 @@ extern "C" void RegisterDialects(MlirContext cctx) {
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"

extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::DialectRegistry &registry = *unwrap(creg);
prepareRegistry(registry);

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);
mlir::LLVM::registerInlinerInterface(registry);

extern "C" void InitializePasses(MlirDialectRegistry creg) {
mlir::registerenzymePasses();
enzyme::registerenzymexlaPasses();

Expand Down Expand Up @@ -901,6 +894,16 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
xla::sdy::registerSdyRoundTripImportPipeline();
}

extern "C" void InitializeRegistry(MlirDialectRegistry creg) {
mlir::DialectRegistry &registry = *unwrap(creg);
prepareRegistry(registry);

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);
mlir::LLVM::registerInlinerInterface(registry);

}

/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
/// suffix in `lastUsedID`.
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
Expand Down
3 changes: 2 additions & 1 deletion deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,8 @@ cc_library(
"-Wl,-exported_symbol,_FutureAwait",
"-Wl,-exported_symbol,_XLAExecute",
"-Wl,-exported_symbol,_RegisterDialects",
"-Wl,-exported_symbol,_InitializeRegistryAndPasses",
"-Wl,-exported_symbol,_InitializeRegistry",
"-Wl,-exported_symbol,_InitializePasses",
"-Wl,-exported_symbol,_ifrt_*",
"-Wl,-exported_symbol,_RegisterCustomCallTarget",
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "e2cc5276372199a5b291b8140bd55c46e8e1538a"
ENZYMEXLA_COMMIT = "e3b0a810763eab1fdab9a8231088160cd3c42e0c"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down
17 changes: 15 additions & 2 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ function create_result(
end

# Optimization passes via transform dialect
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Bool=true)
transform_passes_list = [
"patterns=compare_op_canon<16>",
"transpose_transpose<16>",
Expand Down Expand Up @@ -389,6 +389,8 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
"common_compare_expression_rewrite",
"compare_select_simplify",
"while_simplify<1>",
"scatter_update_computation_const_prop",
"if_remove_unused",
]
if no_nan
append!(
Expand All @@ -407,7 +409,10 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
",",
)
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
passes = ["inline{default-pipeline=canonicalize max-iterations=4}"]
passes = String[]
if inline
push!(passes, "inline{default-pipeline=canonicalize max-iterations=4}")
end
if sroa
push!(passes, "propagate-constant-bounds")
if DUMP_LLVMIR[]
Expand Down Expand Up @@ -703,6 +708,14 @@ function compile_mlir!(
run_pass_pipeline!(
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
)
elseif optimize === :canonicalize
run_pass_pipeline!(
mod, "canonicalize"
)
Comment on lines +712 to +714
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
run_pass_pipeline!(
mod, "canonicalize"
)
run_pass_pipeline!(mod, "canonicalize")

elseif optimize === :just_batch
run_pass_pipeline!(
mod, "enzyme-batch"
)
Comment on lines +716 to +718
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
run_pass_pipeline!(
mod, "enzyme-batch"
)
run_pass_pipeline!(mod, "enzyme-batch")

elseif optimize !== :none
error("Invalid optimize option: $(Meta.quot(optimize))")
end
Expand Down
4 changes: 2 additions & 2 deletions src/Precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ end
@compile_workload begin
@static if precompilation_supported()
x = ConcreteRNumber(2.0; client)
Reactant.compile(sin, (x,); client)
Reactant.compile(sin, (x,); client, optimize=:all)

y = ConcreteRArray([2.0]; client)
Reactant.compile(Base.sum, (y,); client)
Reactant.compile(Base.sum, (y,); client, optimize=:all)
end
end
XLA.free_client(client)
Expand Down
13 changes: 9 additions & 4 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,23 @@ export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace, withi

const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()

const initialize_dialect_first_run = Ref{Bool}(true)

const passes_initialized = Ref(false)
function initialize_dialect()
registry[] = MLIR.IR.DialectRegistry()
@ccall MLIR.API.mlir_c.InitializeRegistryAndPasses(
@ccall MLIR.API.mlir_c.InitializeRegistry(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
initialize_dialect_first_run[] = false
if !passes_initialized[]
@ccall MLIR.API.mlir_c.InitializePasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
passes_initialized[] = true
Comment on lines +163 to +166
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@ccall MLIR.API.mlir_c.InitializePasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
passes_initialized[] = true
@ccall MLIR.API.mlir_c.InitializePasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
passes_initialized[] = true

end
return nothing
end

function deinitialize_dialect()
passes_initialized[] = false
return registry[] = nothing
end

Expand Down
4 changes: 3 additions & 1 deletion src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,15 @@ function __init__()
println(stdout, e)
end
else
if !Reactant.precompiling()
try
gpu = GPUClient()
backends["gpu"] = gpu
default_backend[] = gpu
catch e
println(stdout, e)
println(stdout, e)
Comment on lines +220 to +226
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if !Reactant.precompiling()
try
gpu = GPUClient()
backends["gpu"] = gpu
default_backend[] = gpu
catch e
println(stdout, e)
println(stdout, e)
if !Reactant.precompiling()
try
gpu = GPUClient()
backends["gpu"] = gpu
default_backend[] = gpu
catch e
println(stdout, e)
end

end
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end

end
end

Expand Down
Loading