Skip to content

Commit

Permalink
JLL related fixups (#706)
Browse files Browse the repository at this point in the history
* JLL related fixups

* Bump enzymexla

* Update WORKSPACE

* bump

* add new opts
  • Loading branch information
wsmoses authored Feb 8, 2025
1 parent 611b800 commit c20b142
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.5"
Reactant_jll = "0.0.62"
Reactant_jll = "0.0.64"
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
Expand Down
19 changes: 11 additions & 8 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -852,14 +852,7 @@ extern "C" void RegisterDialects(MlirContext cctx) {
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"

extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::DialectRegistry &registry = *unwrap(creg);
prepareRegistry(registry);

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);
mlir::LLVM::registerInlinerInterface(registry);

extern "C" void InitializePasses(MlirDialectRegistry creg) {
mlir::registerenzymePasses();
enzyme::registerenzymexlaPasses();

Expand Down Expand Up @@ -901,6 +894,16 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
xla::sdy::registerSdyRoundTripImportPipeline();
}

extern "C" void InitializeRegistry(MlirDialectRegistry creg) {
mlir::DialectRegistry &registry = *unwrap(creg);
prepareRegistry(registry);

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);
mlir::LLVM::registerInlinerInterface(registry);

}

/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
/// suffix in `lastUsedID`.
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
Expand Down
3 changes: 2 additions & 1 deletion deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,8 @@ cc_library(
"-Wl,-exported_symbol,_FutureAwait",
"-Wl,-exported_symbol,_XLAExecute",
"-Wl,-exported_symbol,_RegisterDialects",
"-Wl,-exported_symbol,_InitializeRegistryAndPasses",
"-Wl,-exported_symbol,_InitializeRegistry",
"-Wl,-exported_symbol,_InitializePasses",
"-Wl,-exported_symbol,_ifrt_*",
"-Wl,-exported_symbol,_RegisterCustomCallTarget",
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "e2cc5276372199a5b291b8140bd55c46e8e1538a"
ENZYMEXLA_COMMIT = "e3b0a810763eab1fdab9a8231088160cd3c42e0c"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down
17 changes: 15 additions & 2 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ function create_result(
end

# Optimization passes via transform dialect
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Bool=true)
transform_passes_list = [
"patterns=compare_op_canon<16>",
"transpose_transpose<16>",
Expand Down Expand Up @@ -389,6 +389,8 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
"common_compare_expression_rewrite",
"compare_select_simplify",
"while_simplify<1>",
"scatter_update_computation_const_prop",
"if_remove_unused",
]
if no_nan
append!(
Expand All @@ -407,7 +409,10 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
",",
)
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
passes = ["inline{default-pipeline=canonicalize max-iterations=4}"]
passes = String[]
if inline
push!(passes, "inline{default-pipeline=canonicalize max-iterations=4}")
end
if sroa
push!(passes, "propagate-constant-bounds")
if DUMP_LLVMIR[]
Expand Down Expand Up @@ -703,6 +708,14 @@ function compile_mlir!(
run_pass_pipeline!(
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
)
elseif optimize === :canonicalize
run_pass_pipeline!(
mod, "canonicalize"
)
elseif optimize === :just_batch
run_pass_pipeline!(
mod, "enzyme-batch"
)
elseif optimize !== :none
error("Invalid optimize option: $(Meta.quot(optimize))")
end
Expand Down
4 changes: 2 additions & 2 deletions src/Precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ end
@compile_workload begin
@static if precompilation_supported()
x = ConcreteRNumber(2.0; client)
Reactant.compile(sin, (x,); client)
Reactant.compile(sin, (x,); client, optimize=:all)

y = ConcreteRArray([2.0]; client)
Reactant.compile(Base.sum, (y,); client)
Reactant.compile(Base.sum, (y,); client, optimize=:all)
end
end
XLA.free_client(client)
Expand Down
13 changes: 9 additions & 4 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,23 @@ export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace, withi

const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()

const initialize_dialect_first_run = Ref{Bool}(true)

const passes_initialized = Ref(false)
function initialize_dialect()
registry[] = MLIR.IR.DialectRegistry()
@ccall MLIR.API.mlir_c.InitializeRegistryAndPasses(
@ccall MLIR.API.mlir_c.InitializeRegistry(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
initialize_dialect_first_run[] = false
if !passes_initialized[]
@ccall MLIR.API.mlir_c.InitializePasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
passes_initialized[] = true
end
return nothing
end

function deinitialize_dialect()
passes_initialized[] = false
return registry[] = nothing
end

Expand Down
4 changes: 3 additions & 1 deletion src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,15 @@ function __init__()
println(stdout, e)
end
else
if !Reactant.precompiling()
try
gpu = GPUClient()
backends["gpu"] = gpu
default_backend[] = gpu
catch e
println(stdout, e)
println(stdout, e)
end
end
end
end

Expand Down

0 comments on commit c20b142

Please sign in to comment.