# Segfault on `convert_simplify` optimization with complex numbers (#695)
Seems like this segfaults:

```julia
julia> Reactant.Compiler.run_pass_pipeline_on_source("""
       module {
         func.func @main(%arg0: tensor<2xcomplex<f64>>) -> (tensor<complex<f64>>, tensor<2xcomplex<f64>>) {
           %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<2xcomplex<f64>>) -> tensor<2xcomplex<f64>>
           %c = stablehlo.constant dense<1> : tensor<2xi64>
           %1 = stablehlo.convert %c : (tensor<2xi64>) -> tensor<2xcomplex<f64>>
           %2 = stablehlo.dot_general %0, %1, contracting_dims = [0] x [0] : (tensor<2xcomplex<f64>>, tensor<2xcomplex<f64>>) -> tensor<complex<f64>>
           %3 = stablehlo.transpose %0, dims = [0] : (tensor<2xcomplex<f64>>) -> tensor<2xcomplex<f64>>
           return %2, %3 : tensor<complex<f64>>, tensor<2xcomplex<f64>>
         }
       }""", "enzyme-hlo-generate-td{patterns=convert_simplify},transform-interpreter")
```

but the same pipeline with `f64` instead of `complex<f64>` succeeds and prints the simplified module:

```julia
julia> Reactant.Compiler.run_pass_pipeline_on_source("""
       module {
         func.func @main(%arg0: tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>) {
           %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
           %c = stablehlo.constant dense<1> : tensor<2xi64>
           %1 = stablehlo.convert %c : (tensor<2xi64>) -> tensor<2xf64>
           %2 = stablehlo.dot_general %0, %1, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
           %3 = stablehlo.transpose %0, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
           return %2, %3 : tensor<f64>, tensor<2xf64>
         }
       }""", "enzyme-hlo-generate-td{patterns=convert_simplify},transform-interpreter")
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.enzyme_hlo.convert_simplify
    } : !transform.any_op
    transform.yield
  }
  func.func @main(%arg0: tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<2xf64>
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
    %1 = stablehlo.dot_general %0, %cst, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
    %2 = stablehlo.transpose %0, dims = [0] : (tensor<2xf64>) -> tensor<2xf64>
    return %1, %2 : tensor<f64>, tensor<2xf64>
  }
}
```
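For what it's worth, the fold itself is numerically well defined for complex targets too. A minimal NumPy sketch (illustrative only, not Reactant/Enzyme-JAX code) of what `convert_simplify` effectively computes for the constant operand:

```python
import numpy as np

# The constant operand: dense<1> : tensor<2xi64>
c = np.array([1, 1], dtype=np.int64)

# The working f64 case: folding convert(c) yields the
# %cst = dense<1.000000e+00> : tensor<2xf64> seen in the output module.
folded_f64 = c.astype(np.float64)

# The complex analogue is equally well defined (1 -> 1+0i), so the
# pattern could in principle fold it instead of crashing.
folded_c128 = c.astype(np.complex128)

print(folded_f64)   # [1. 1.]
print(folded_c128)  # [1.+0.j 1.+0.j]
```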
The smaller reproducer provided by Sergio hits an assertion in a debug build of Reactant. The assertion is at https://github.com/EnzymeAD/Enzyme-JAX/blob/d694d09a33814a83abaef9d1ef924614ab59e4c5/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp#L1343:

```cpp
if (auto floatType = dyn_cast<FloatType>(elemType))
  bitWidth = floatType.getWidth();
assert(bitWidth != -1 && "expect integer or float");
```
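The failure mode is easy to mimic: the pattern computes an element bit width only for integer and float types, and a complex element type matches neither branch, so the `-1` sentinel survives to the assert. A hypothetical Python sketch of that dispatch (`bit_width_of` is a made-up name for illustration, not the actual EnzymeHLOOpt.cpp code):

```python
def bit_width_of(elem_type: str) -> int:
    """Sketch of the bit-width dispatch in convert_simplify."""
    bit_width = -1
    if elem_type.startswith("i"):       # IntegerType branch
        bit_width = int(elem_type[1:])
    elif elem_type.startswith("f"):     # FloatType branch
        bit_width = int(elem_type[1:])
    # ComplexType matches neither branch, so the sentinel survives.
    assert bit_width != -1, "expect integer or float"
    return bit_width

bit_width_of("f64")             # 64 -- fine
# bit_width_of("complex<f64>") # AssertionError, mirroring the debug-build crash
```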
Ironically, x/ref openxla/stablehlo#2709. cc @GleasonK
CC @mofeing
We found the following unexpected behavior when performing a `dot_general` operation over two vectors of different types.