Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segfault on convert_simplify optimization with complex numbers #695

Closed
Todorbsc opened this issue Feb 6, 2025 · 4 comments · Fixed by #706
Closed

Segfault on convert_simplify optimization with complex numbers #695

Todorbsc opened this issue Feb 6, 2025 · 4 comments · Fixed by #706

Comments

@Todorbsc
Copy link

Todorbsc commented Feb 6, 2025

CC @mofeing

We found the following unexpected behavior when performing a dot_general operation on two vectors with different element types.

julia> using Reactant
AssertionError("Could not find registered platform with name: \"cuda\". Available platform names are: ")

julia> a = ones(ComplexF64, 2)
2-element Vector{ComplexF64}:
 1.0 + 0.0im
 1.0 + 0.0im

julia> b = ones(Int, 2)
2-element Vector{Int64}:
 1
 1

julia> a_re = ConcreteRArray(a)
2-element ConcreteRArray{ComplexF64, 1}:
 1.0 + 0.0im
 1.0 + 0.0im


julia> function f(x)
           bconst = Reactant.TracedRArray{ComplexF64, 1}(Reactant.Ops.constant(b))
           Reactant.Ops.dot_general(x, bconst; contracting_dimensions=([1],[1]), batching_dimensions=(Int[], Int[]))
       end
f (generic function with 1 method)

julia> @code_hlo optimize=false f(a_re)
module {
  func.func @main(%arg0: tensor<2xcomplex<f64>>) -> (tensor<complex<f64>>, tensor<2xcomplex<f64>>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<2xcomplex<f64>>) -> tensor<2xcomplex<f64>>
    %c = stablehlo.constant dense<1> : tensor<2xi64>
    %1 = stablehlo.convert %c : (tensor<2xi64>) -> tensor<2xcomplex<f64>>
    %2 = stablehlo.dot_general %0, %1, contracting_dims = [0] x [0] : (tensor<2xcomplex<f64>>, tensor<2xcomplex<f64>>) -> tensor<complex<f64>>
    %3 = stablehlo.transpose %0, dims = [0] : (tensor<2xcomplex<f64>>) -> tensor<2xcomplex<f64>>
    return %2, %3 : tensor<complex<f64>>, tensor<2xcomplex<f64>>
  }
}

julia> @code_hlo optimize=true f(a_re)

[3108287] signal 11 (1): Segmentation fault
in expression starting at none:0
__memcpy_avx512_unaligned_erms at /lib64/libc.so.6 (unknown line)
_ZL8readBitsPKcmm at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZNK4mlir17DenseElementsAttr25ComplexIntElementIteratordeEv at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4llvm12function_refIFvjEE11callback_fnIZN4mlir10AsmPrinter4Impl29printDenseIntOrFPElementsAttrENS4_24DenseIntOrFPElementsAttrEbEUljE0_EEvlj at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZL26printDenseElementsAttrImplbN4mlir10ShapedTypeERN4llvm11raw_ostreamENS1_12function_refIFvjEEE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir10AsmPrinter4Impl29printDenseIntOrFPElementsAttrENS_24DenseIntOrFPElementsAttrEb at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir10AsmPrinter4Impl18printAttributeImplENS_9AttributeENS1_15AttrTypeElisionE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir3hlo15printConstantOpERNS_12OpAsmPrinterEPNS_9OperationENS_12ElementsAttrE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4llvm6detail18UniqueFunctionBaseIvJPN4mlir9OperationERNS2_12OpAsmPrinterENS_9StringRefEEE8CallImplIKZNS2_2OpINS2_9stablehlo10ConstantOpEJNS2_7OpTrait11ZeroRegionsENSD_9OneResultENSD_14OneTypedResultINS2_16RankedTensorTypeEE4ImplENSD_14ZeroSuccessorsENSD_12ZeroOperandsENSD_12OpInvariantsENS2_19BytecodeOpInterface5TraitENSD_12ConstantLikeENS2_25ConditionallySpeculatable5TraitENSD_27AlwaysSpeculatableImplTraitENS2_23MemoryEffectOpInterface5TraitENS2_20InferTypeOpInterface5TraitENS2_16OpAsmOpInterface5TraitEEE18getPrintAssemblyFnEvEUlS4_S6_S7_E_EEvPvS4_S6_S7_ at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir23RegisteredOperationName5ModelINS_9stablehlo10ConstantOpEE13printAssemblyEPNS_9OperationERNS_12OpAsmPrinterEN4llvm9StringRefE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_116OperationPrinter27printFullOpWithIndentAndLocEPN4mlir9OperationE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_116OperationPrinter5printEPN4mlir5BlockEbb at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_116OperationPrinter11printRegionERN4mlir6RegionEbbb at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir23function_interface_impl15printFunctionOpERNS_12OpAsmPrinterENS_19FunctionOpInterfaceEbN4llvm9StringRefENS_10StringAttrES6_ at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir4func6FuncOp5printERNS_12OpAsmPrinterE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4llvm6detail18UniqueFunctionBaseIvJPN4mlir9OperationERNS2_12OpAsmPrinterENS_9StringRefEEE8CallImplIKZNS2_2OpINS2_4func6FuncOpEJNS2_7OpTrait9OneRegionENSD_11ZeroResultsENSD_14ZeroSuccessorsENSD_12ZeroOperandsENSD_12OpInvariantsENS2_19BytecodeOpInterface5TraitENSD_11AffineScopeENSD_24AutomaticAllocationScopeENS2_17SymbolOpInterface5TraitENS2_19CallableOpInterface5TraitENS2_19FunctionOpInterface5TraitENSD_19IsIsolatedFromAboveENS2_16OpAsmOpInterface5TraitEEE18getPrintAssemblyFnEvEUlS4_S6_S7_E_EEvPvS4_S6_S7_ at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir23RegisteredOperationName5ModelINS_4func6FuncOpEE13printAssemblyEPNS_9OperationERNS_12OpAsmPrinterEN4llvm9StringRefE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_116OperationPrinter27printFullOpWithIndentAndLocEPN4mlir9OperationE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_116OperationPrinter5printEPN4mlir5BlockEbb at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_116OperationPrinter11printRegionERN4mlir6RegionEbbb at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir8ModuleOp5printERNS_12OpAsmPrinterE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4llvm6detail18UniqueFunctionBaseIvJPN4mlir9OperationERNS2_12OpAsmPrinterENS_9StringRefEEE8CallImplIKZNS2_2OpINS2_8ModuleOpEJNS2_7OpTrait9OneRegionENSC_11ZeroResultsENSC_14ZeroSuccessorsENSC_12ZeroOperandsENSC_17NoRegionArgumentsENSC_12NoTerminatorENSC_11SingleBlockENSC_12OpInvariantsENS2_19BytecodeOpInterface5TraitENSC_11AffineScopeENSC_19IsIsolatedFromAboveENSC_11SymbolTableENS2_17SymbolOpInterface5TraitENS2_16OpAsmOpInterface5TraitENS2_19RegionKindInterface5TraitENSC_18HasOnlyGraphRegionEEE18getPrintAssemblyFnEvEUlS4_S6_S7_E_EEvPvS4_S6_S7_ at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir23RegisteredOperationName5ModelINS_8ModuleOpEE13printAssemblyEPNS_9OperationERNS_12OpAsmPrinterEN4llvm9StringRefE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_116OperationPrinter27printFullOpWithIndentAndLocEPN4mlir9OperationE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir9Operation5printERN4llvm11raw_ostreamERNS_8AsmStateE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
_ZN4mlir9Operation5printERN4llvm11raw_ostreamERKNS_15OpPrintingFlagsE at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
mlirOperationPrintWithFlags at /home/bsc/bsc021386/.julia/artifacts/5eee33a67e693bfd3663d3c63eba2d8e1f48712c/lib/libReactantExtra.so (unknown line)
mlirOperationPrintWithFlags at /home/bsc/bsc021386/.julia/packages/Reactant/7y9bj/src/mlir/libMLIR_h.jl:1408 [inlined]
show at /home/bsc/bsc021386/.julia/packages/Reactant/7y9bj/src/mlir/IR/Operation.jl:232
show at /home/bsc/bsc021386/.julia/packages/Reactant/7y9bj/src/mlir/IR/Module.jl:58 [inlined]
show at ./multimedia.jl:47
unknown function (ip: 0x7ef883223456)
#68 at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:348
jfptr_YY.68_10156 at /gpfs/apps/MN5/GPP/JULIA/1.11.0/INTEL/share/julia/compiled/v1.11/REPL/u0gqU_vklEZ.so (unknown line)
with_repl_linfo at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:646
jfptr_with_repl_linfo_10298 at /gpfs/apps/MN5/GPP/JULIA/1.11.0/INTEL/share/julia/compiled/v1.11/REPL/u0gqU_vklEZ.so (unknown line)
display at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:334
display at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:353 [inlined]
display at ./multimedia.jl:340
jfptr_display_13645 at /gpfs/apps/MN5/GPP/JULIA/1.11.0/INTEL/share/julia/compiled/v1.11/REPL/u0gqU_vklEZ.so (unknown line)
jl_apply at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/src/julia.h:2157 [inlined]
jl_f__call_latest at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1054 [inlined]
invokelatest at ./essentials.jl:1051 [inlined]
print_response at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:390
unknown function (ip: 0x7ef883188cab)
#70 at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:359
unknown function (ip: 0x7ef883187972)
with_repl_linfo at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:646
jfptr_with_repl_linfo_10298 at /gpfs/apps/MN5/GPP/JULIA/1.11.0/INTEL/share/julia/compiled/v1.11/REPL/u0gqU_vklEZ.so (unknown line)
print_response at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:357
do_respond at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:988
unknown function (ip: 0x7ef88318718b)
jl_apply at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/src/julia.h:2157 [inlined]
jl_f__call_latest at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1054 [inlined]
invokelatest at ./essentials.jl:1051 [inlined]
run_interface at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/LineEdit.jl:2749
jfptr_run_interface_8813 at /gpfs/apps/MN5/GPP/JULIA/1.11.0/INTEL/share/julia/compiled/v1.11/REPL/u0gqU_vklEZ.so (unknown line)
run_frontend at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:1456
#75 at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:461
jfptr_YY.75_10252 at /gpfs/apps/MN5/GPP/JULIA/1.11.0/INTEL/share/julia/compiled/v1.11/REPL/u0gqU_vklEZ.so (unknown line)
jl_apply at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/src/julia.h:2157 [inlined]
start_task at /gpfs/apps/MN5/GPP/JULIA/SRC/julia-1.11.0/src/task.c:1202
Allocations: 27219278 (Pool: 27218853; Big: 425); GC: 41
Segmentation fault
@mofeing
Copy link
Collaborator

mofeing commented Feb 6, 2025

It seems the convert_simplify pass does not handle complex numbers correctly. As @Pangoraw observed, it fails with complex<f64>: the following IR segfaults

julia> Reactant.Compiler.run_pass_pipeline_on_source("""
       module {
         func.func @main(%arg0: tensor<2xcomplex<f64>>) -> (tensor<complex<f64>>, tensor<2xcomplex<f64>>) {
           %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<2xcomplex<f64>>) -> tensor<2xcomplex<f64>>
           %c = stablehlo.constant dense<1> : tensor<2xi64>
           %1 = stablehlo.convert %c : (tensor<2xi64>) -> tensor<2xcomplex<f64>>
           %2 = stablehlo.dot_general %0, %1, contracting_dims = [0] x [0] : (tensor<2xcomplex<f64>>, tensor<2xcomplex<f64>>) -> tensor<complex<f64>>
           %3 = stablehlo.transpose %0, dims = [0] : (tensor<2xcomplex<f64>>) -> tensor<2xcomplex<f64>>
           return %2, %3 : tensor<complex<f64>>, tensor<2xcomplex<f64>>
         }
       }""", "enzyme-hlo-generate-td{patterns=convert_simplify},transform-interpreter")

but with f64 it works fine:

julia> Reactant.Compiler.run_pass_pipeline_on_source("""
              module {
                func.func @main(%arg0: tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>) {
                  %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
                  %c = stablehlo.constant dense<1> : tensor<2xi64>
                  %1 = stablehlo.convert %c : (tensor<2xi64>) -> tensor<2xf64>
                  %2 = stablehlo.dot_general %0, %1, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
                  %3 = stablehlo.transpose %0, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
                  return %2, %3 : tensor<f64>, tensor<2xf64>
                }
              }""", "enzyme-hlo-generate-td{patterns=convert_simplify},transform-interpreter")
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.enzyme_hlo.convert_simplify
    } : !transform.any_op
    transform.yield
  }
  func.func @main(%arg0: tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<2xf64>
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
    %1 = stablehlo.dot_general %0, %cst, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
    %2 = stablehlo.transpose %0, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
    return %1, %2 : tensor<f64>, tensor<2xf64>
  }
}

@mofeing mofeing changed the title Segmentation fault when optimizing HLO code for different element types Segfault on convert_simplify optimization with complex numbers Feb 6, 2025
@giordano
Copy link
Member

giordano commented Feb 6, 2025

The smaller reproducer provided by Sergio hits an assertion in a debug build of Reactant:

julia: external/enzyme_ad/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp:1343: DenseElementsAttr (anonymous namespace)::fromTensor(stablehlo::Tensor): Assertion `bitWidth != -1 && "expect integer or float"' failed.

Thread 1 "julia" received signal SIGABRT, Aborted.
Download failed: Invalid argument.  Continuing without source file ./nptl/./nptl/pthread_kill.c.
__pthread_kill_implementation (no_tid=0, signo=6, threadid=<optimized out>) at ./nptl/pthread_kill.c:44
warning: 44     ./nptl/pthread_kill.c: No such file or directory
(gdb) bt
#0  __pthread_kill_implementation (no_tid=0, signo=6, threadid=<optimized out>) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (signo=6, threadid=<optimized out>) at ./nptl/pthread_kill.c:78
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:89
#3  0x00007ffff7c4526e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007ffff7c288ff in __GI_abort () at ./stdlib/abort.c:79
#5  0x00007ffff7c2881b in __assert_fail_base (fmt=0x7ffff7dd01e8 "%s%s%s:%u: %s%sAssertion `%s' failed.\n%n", assertion=assertion@entry=0x7ffc6d323b8f "bitWidth != -1 && \"expect integer or float\"",
    file=file@entry=0x7ffc6e1eca48 "external/enzyme_ad/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp", line=line@entry=1343, function=function@entry=0x7ffc6ea42908 "DenseElementsAttr (anonymous namespace)::fromTensor(stablehlo::Tensor)")
    at ./assert/assert.c:94
#6  0x00007ffff7c3b507 in __assert_fail (assertion=0x7ffc6d323b8f "bitWidth != -1 && \"expect integer or float\"", file=0x7ffc6e1eca48 "external/enzyme_ad/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp", line=1343,
    function=0x7ffc6ea42908 "DenseElementsAttr (anonymous namespace)::fromTensor(stablehlo::Tensor)") at ./assert/assert.c:103
#7  0x00007ffc760a99b7 in (anonymous namespace)::fromTensor (inp=...) at external/enzyme_ad/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp:1343
#8  0x00007ffc760b17f0 in (anonymous namespace)::ConvertSimplify::matchAndRewrite (this=0x168b030, op=..., rewriter=...) at external/enzyme_ad/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp:3559
#9  0x00007ffc7617da5b in mlir::detail::OpOrInterfaceRewritePatternBase<mlir::stablehlo::ConvertOp>::matchAndRewrite (this=0x168b030, op=0xb6ec20, rewriter=...) at external/llvm-project/mlir/include/mlir/IR/PatternMatch.h:331
#10 0x00007ffc8715eb7f in mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::$_0::operator()() const (this=0x7fffffff6838) at external/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:212
#11 0x00007ffc8715e9f5 in llvm::function_ref<void ()>::callback_fn<mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::$_0>(long) (callable=140737488316472) at external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46
#12 0x00007ffc833d1889 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffff66d0) at external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69
#13 0x00007ffc8716023b in mlir::MLIRContext::executeAction<mlir::ApplyPatternAction, mlir::Pattern const&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pattern const&) (this=0xed99a0, actionFn=..., irUnits=..., args=...)
    at external/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280
#14 0x00007ffc8715cf95 in mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) (this=0x7fffffff7560, op=0xb6ec20, rewriter=..., canApply=..., onFailure=..., onSuccess=...) at external/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:195
#15 0x00007ffc870d00d0 in (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist (this=0x7fffffff7468) at external/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:615
#16 0x00007ffc870cf151 in (anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) &&::$_2::operator()() const (this=0x7fffffff72b8) at external/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:874
#17 0x00007ffc870cf125 in llvm::function_ref<void ()>::callback_fn<(anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) &&::$_2>(long) (callable=140737488319160) at external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46
#18 0x00007ffc833d1889 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffff7250) at external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69
#19 0x00007ffc870ce8ab in mlir::MLIRContext::executeAction<(anonymous namespace)::GreedyPatternRewriteIteration, long&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, long&) (this=0xed99a0, actionFn=..., irUnits=...,
    args=@0x7fffffff73b8: 1) at external/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280
#20 0x00007ffc870cca33 in (anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) && (this=0x7fffffff7468, changed=0x7fffffff7657) at external/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:872
#21 0x00007ffc870cc60e in mlir::applyPatternsGreedily (region=..., patterns=..., config=..., changed=0x7fffffff7657) at external/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:919
#22 0x00007ffc76072850 in mlir::applyPatternsGreedily (op=0x1481e90, patterns=..., config=..., changed=0x0) at external/llvm-project/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h:174
#23 0x00007ffc79948e09 in mlir::transform::ApplyPatternsOp::applyToOne (this=0x7fffffff7ea0, rewriter=..., target=0x1481e90, results=..., state=...) at external/llvm-project/mlir/lib/Dialect/Transform/IR/TransformOps.cpp:420
#24 0x00007ffc79899e52 in mlir::transform::detail::applyTransformToEach<mlir::transform::ApplyPatternsOp, llvm::iterator_range<llvm::filter_iterator_impl<mlir::Operation* const*, mlir::transform::TransformState::getPayloadOps(mlir::Value) const::{lambda(mlir::Operation*)#1}, std::bidirectional_iterator_tag> >&>(mlir::transform::ApplyPatternsOp, mlir::transform::TransformRewriter&, llvm::iterator_range<llvm::filter_iterator_impl<mlir::Operation* const*, mlir::transform::TransformState::getPayloadOps(mlir::Value) const::{lambda(mlir::Operation*)#1}, std::bidirectional_iterator_tag> >&, llvm::SmallVectorImpl<mlir::transform::ApplyToEachResultList>&, mlir::transform::TransformState&) (transformOp=..., rewriter=..., targets=...,
    results=..., state=...) at external/llvm-project/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h:1512
#25 0x00007ffc798999d9 in mlir::transform::TransformEachOpTrait<mlir::transform::ApplyPatternsOp>::apply (this=0x7fffffff8518, rewriter=..., transformResults=..., state=...)
    at external/llvm-project/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h:1583
#26 0x00007ffc798993a1 in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model<mlir::transform::ApplyPatternsOp>::apply (impl=0x166cb40, tablegen_opaque_val=0xf60200, rewriter=..., transformResults=..., state=...)
    at bazel-out/k8-dbg/bin/external/llvm-project/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477
#27 0x00007ffc8582502e in mlir::transform::TransformOpInterface::apply (this=0x7fffffff8db8, rewriter=..., transformResults=..., state=...)
    at bazel-out/k8-dbg/bin/external/llvm-project/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61
#28 0x00007ffc858243e4 in mlir::transform::TransformState::applyTransform (this=0x7fffffffaaf0, transform=...) at external/llvm-project/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:951
#29 0x00007ffc79953e30 in applySequenceBlock (block=..., mode=mlir::transform::FailurePropagationMode::Propagate, state=..., results=...) at external/llvm-project/mlir/lib/Dialect/Transform/IR/TransformOps.cpp:1786
#30 0x00007ffc79957af4 in mlir::transform::NamedSequenceOp::apply (this=0x7fffffff9888, rewriter=..., results=..., state=...) at external/llvm-project/mlir/lib/Dialect/Transform/IR/TransformOps.cpp:2153
#31 0x00007ffc798d52c1 in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model<mlir::transform::NamedSequenceOp>::apply (impl=0xa11ac0, tablegen_opaque_val=0x1398600, rewriter=..., transformResults=..., state=...)
    at bazel-out/k8-dbg/bin/external/llvm-project/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477
#32 0x00007ffc8582502e in mlir::transform::TransformOpInterface::apply (this=0x7fffffffa128, rewriter=..., transformResults=..., state=...)
    at bazel-out/k8-dbg/bin/external/llvm-project/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61
#33 0x00007ffc858243e4 in mlir::transform::TransformState::applyTransform (this=0x7fffffffaaf0, transform=...) at external/llvm-project/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:951
#34 0x00007ffc8582d531 in mlir::transform::applyTransforms(mlir::Operation*, mlir::transform::TransformOpInterface, mlir::RaggedArray<llvm::PointerUnion<mlir::Operation*, mlir::Attribute, mlir::Value> > const&, mlir::transform::TransformOptions const&, bool, llvm::function_ref<void (mlir::transform::TransformState&)>, llvm::function_ref<llvm::LogicalResult (mlir::transform::TransformState&)>) (payloadRoot=0x8503a0, transform=..., extraMapping=..., options=...,
    enforceToplevelTransformOp=false, stateInitializer=..., stateExporter=...) at external/llvm-project/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:2018
#35 0x00007ffc798681cb in mlir::transform::applyTransformNamedSequence (bindings=..., transformRoot=..., transformModule=..., options=...) at external/llvm-project/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp:234
#36 0x00007ffc7985f26c in (anonymous namespace)::InterpreterPass::runOnOperation (this=0x8a52a0) at external/llvm-project/mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp:147
#37 0x00007ffc8766dfb4 in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1::operator()() const (this=0x7fffffffb780) at external/llvm-project/mlir/lib/Pass/Pass.cpp:526
#38 0x00007ffc8766df55 in llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) (callable=140737488336768)
    at external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46
#39 0x00007ffc833d1889 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffffb690) at external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69
#40 0x00007ffc87670bab in mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) (this=0xed99a0, actionFn=..., irUnits=..., args=...)
    at external/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280
#41 0x00007ffc87668d19 in mlir::detail::OpToOpPassAdaptor::run (pass=0x8a52a0, op=0x8503a0, am=..., verifyPasses=true, parentInitGeneration=1) at external/llvm-project/mlir/lib/Pass/Pass.cpp:520
#42 0x00007ffc8766942b in mlir::detail::OpToOpPassAdaptor::runPipeline (pm=..., op=0x8503a0, am=..., verifyPasses=true, parentInitGeneration=1, instrumentor=0x0, parentInfo=0x0) at external/llvm-project/mlir/lib/Pass/Pass.cpp:592
#43 0x00007ffc8766b6f8 in mlir::PassManager::runPasses (this=0x1074bb0, op=0x8503a0, am=...) at external/llvm-project/mlir/lib/Pass/Pass.cpp:905
#44 0x00007ffc8766b591 in mlir::PassManager::run (this=0x1074bb0, op=0x8503a0) at external/llvm-project/mlir/lib/Pass/Pass.cpp:885
#45 0x00007ffc84da7be2 in mlirPassManagerRunOnOp (passManager=..., op=...) at external/llvm-project/mlir/lib/CAPI/IR/Pass.cpp:44
#46 0x00007ffc50f077df in mlirPassManagerRunOnOp () at /home/giordano/.julia/dev/Reactant/src/mlir/libMLIR_h.jl:8174
#47 run! () at /home/giordano/.julia/dev/Reactant/src/mlir/IR/Pass.jl:74
#48 julia_#run_pass_pipeline!#2_20554 () at /home/giordano/.julia/dev/Reactant/src/Compiler.jl:352

The failing assertion is at https://github.com/EnzymeAD/Enzyme-JAX/blob/d694d09a33814a83abaef9d1ef924614ab59e4c5/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp#L1343

  if (auto floatType = dyn_cast<FloatType>(elemType))
    bitWidth = floatType.getWidth();
  assert(bitWidth != -1 && "expect integer or float");

@wsmoses
Copy link
Member

wsmoses commented Feb 6, 2025

Ironically, this is related to (cross-ref) openxla/stablehlo#2709. cc @GleasonK

@wsmoses
Copy link
Member

wsmoses commented Feb 6, 2025

Also, #699 is equivalent to this issue. cc @ptiede

@giordano giordano linked a pull request Feb 7, 2025 that will close this issue
@giordano giordano removed a link to a pull request Feb 8, 2025
@giordano giordano linked a pull request Feb 8, 2025 that will close this issue
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

4 participants