diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 0da2d5ad85..ee92aa2ea7 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "aa65f93b71dee8cacb22be1957673c8be6a3ec24" +LLVM_COMMIT = "956c0707d9098499a2682297b71f46b0a562eed9" -LLVM_SHA256 = "0a6046edb6a9834d5b912ec0e705dec91d39ee1b7b2fbb5930955d83d2090ff5" +LLVM_SHA256 = "f90b866908daa3c65b74454943e52b59f40ab448f42a13b23e9823045f017066" http_archive( name = "llvm-raw", diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 775886a14b..b3842bfc2e 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -5c24847e7dba01dde230e18b39a3074022279c89 +956c0707d9098499a2682297b71f46b0a562eed9 diff --git a/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/conversions/tosa/tests/unary.mlir index a735c337e5..3ab3501d96 100644 --- a/stablehlo/conversions/tosa/tests/unary.mlir +++ b/stablehlo/conversions/tosa/tests/unary.mlir @@ -79,7 +79,9 @@ func.func @negate(%arg : tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: @slice func.func @slice(%arg : tensor<4x3xf32>) -> tensor<2x2xf32> { - // CHECK: tosa.slice %arg0 {size = array, start = array} + // CHECK: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: %[[START:.*]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: tosa.slice %arg0, %[[SIZE]], %[[START]] %0 = "stablehlo.slice"(%arg) { start_indices = array, limit_indices = array, diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp index b4430e7c65..ec16ac3b92 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinAttributes.h" @@ -435,8 +436,8 @@ struct ConvertStablehloSliceOp : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, op.getType(), op.getOperand(), - rewriter.getDenseI64ArrayAttr(startIndicesI64), - rewriter.getDenseI64ArrayAttr(size)); + getTosaConstShape(rewriter, op.getLoc(), startIndicesI64), + getTosaConstShape(rewriter, op.getLoc(), size)); return success(); } }; diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll index f6b2e3ef29..c2eb32cc2a 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll @@ -15,15 +15,33 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.td" #include "stablehlo/dialect/StablehloOps.td" -Rewrite zeroConst() -> Op [{ - auto type = rewriter.getI8Type(); - auto attr = mlir::DenseElementsAttr::get( - llvm::cast(type), rewriter.getZeroAttr(type)); +// Helper functions. +Rewrite changeElementTypeToI1(type: Type) -> Type [{ + auto tensorType = llvm::cast(type); + return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); +}]; + +Rewrite changeElementTypeToI8(type: Type) -> Type [{ + auto tensorType = llvm::cast(type); + return RankedTensorType::get(tensorType.getShape(), rewriter.getI8Type()); +}]; + +Rewrite zerosLike(op: Op, type: Type) -> Op [{ + auto elementType = llvm::cast(type).getElementType(); + llvm::SmallVector outputValue; + + if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) { + outputValue.push_back(rewriter.getFloatAttr(elementType, 0)); + } else { + outputValue.push_back(rewriter.getIntegerAttr(elementType, 0)); + } + return rewriter.create( - rewriter.getUnknownLoc(), type, attr); + op->getLoc(), type, + mlir::DenseElementsAttr::get( + llvm::cast(type), outputValue)); }]; -// Helper functions. Rewrite onesLike(op: Op, type: Type) -> Op [{ auto elementType = llvm::cast(type).getElementType(); llvm::SmallVector outputValue; @@ -55,11 +73,6 @@ Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{ llvm::cast(type), outputValue)); }]; -Rewrite changeElementTypeToI1(type: Type) -> Type [{ - auto tensorType = llvm::cast(type); - return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); -}]; - // Nullary ops. Pattern => replace op {value = input: Attr<_: Tosa_Tensor>} @@ -142,10 +155,16 @@ Pattern => replace op(input0 : Value<_: Tosa_Tensor>, input1 : Value<_: Tosa_Tensor>) with op(input0, input1); -Pattern => - replace op(input0 : Value<_: Tosa_Tensor>, - input1 : Value<_: Tosa_Tensor>) - with op(input0, input1, zeroConst()); +Pattern { + let root = op(input0 : Value, + input1 : Value<_: Tosa_Tensor>); + rewrite root with { + let typei8 = changeElementTypeToI8(inputType); + let zeros = zerosLike(root, typei8); + let mulResult = op(input0, input1, zeros) -> (inputType); + replace root with mulResult; + }; +} Pattern => replace op(input0 : Value<_: Tosa_Tensor>, input1 : Value<_: Tosa_Tensor>)