Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
abhigunj authored Feb 1, 2025
1 parent bde336e commit 7775e3e
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 21 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "aa65f93b71dee8cacb22be1957673c8be6a3ec24"
LLVM_COMMIT = "956c0707d9098499a2682297b71f46b0a562eed9"

LLVM_SHA256 = "0a6046edb6a9834d5b912ec0e705dec91d39ee1b7b2fbb5930955d83d2090ff5"
LLVM_SHA256 = "f90b866908daa3c65b74454943e52b59f40ab448f42a13b23e9823045f017066"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
5c24847e7dba01dde230e18b39a3074022279c89
956c0707d9098499a2682297b71f46b0a562eed9
4 changes: 3 additions & 1 deletion stablehlo/conversions/tosa/tests/unary.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ func.func @negate(%arg : tensor<10xf32>) -> tensor<10xf32> {

// CHECK-LABEL: @slice
func.func @slice(%arg : tensor<4x3xf32>) -> tensor<2x2xf32> {
// CHECK: tosa.slice %arg0 {size = array<i64: 2, 2>, start = array<i64: 2, 1>}
// CHECK: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[START:.*]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: tosa.slice %arg0, %[[SIZE]], %[[START]]
%0 = "stablehlo.slice"(%arg) {
start_indices = array<i64: 2, 1>,
limit_indices = array<i64: 4, 3>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand Down Expand Up @@ -435,8 +436,8 @@ struct ConvertStablehloSliceOp : public OpRewritePattern<stablehlo::SliceOp> {

rewriter.replaceOpWithNewOp<tosa::SliceOp>(
op, op.getType(), op.getOperand(),
rewriter.getDenseI64ArrayAttr(startIndicesI64),
rewriter.getDenseI64ArrayAttr(size));
getTosaConstShape(rewriter, op.getLoc(), startIndicesI64),
getTosaConstShape(rewriter, op.getLoc(), size));
return success();
}
};
Expand Down
49 changes: 34 additions & 15 deletions stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,33 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.td"
#include "stablehlo/dialect/StablehloOps.td"

Rewrite zeroConst() -> Op [{
auto type = rewriter.getI8Type();
auto attr = mlir::DenseElementsAttr::get(
llvm::cast<mlir::ShapedType>(type), rewriter.getZeroAttr(type));
// Helper functions.
Rewrite changeElementTypeToI1(type: Type) -> Type [{
auto tensorType = llvm::cast<mlir::RankedTensorType>(type);
return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type());
}];

Rewrite changeElementTypeToI8(type: Type) -> Type [{
auto tensorType = llvm::cast<mlir::RankedTensorType>(type);
return RankedTensorType::get(tensorType.getShape(), rewriter.getI8Type());
}];

Rewrite zerosLike(op: Op, type: Type) -> Op [{
auto elementType = llvm::cast<mlir::TensorType>(type).getElementType();
llvm::SmallVector<mlir::Attribute, 4> outputValue;

if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) {
outputValue.push_back(rewriter.getFloatAttr(elementType, 0));
} else {
outputValue.push_back(rewriter.getIntegerAttr(elementType, 0));
}

return rewriter.create<mlir::tosa::ConstOp>(
rewriter.getUnknownLoc(), type, attr);
op->getLoc(), type,
mlir::DenseElementsAttr::get(
llvm::cast<mlir::ShapedType>(type), outputValue));
}];

// Helper functions.
Rewrite onesLike(op: Op, type: Type) -> Op [{
auto elementType = llvm::cast<mlir::TensorType>(type).getElementType();
llvm::SmallVector<mlir::Attribute, 4> outputValue;
Expand Down Expand Up @@ -55,11 +73,6 @@ Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{
llvm::cast<mlir::ShapedType>(type), outputValue));
}];

Rewrite changeElementTypeToI1(type: Type) -> Type [{
auto tensorType = llvm::cast<mlir::RankedTensorType>(type);
return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type());
}];

// Nullary ops.
Pattern =>
replace op<stablehlo.constant> {value = input: Attr<_: Tosa_Tensor>}
Expand Down Expand Up @@ -142,10 +155,16 @@ Pattern =>
replace op<stablehlo.minimum>(input0 : Value<_: Tosa_Tensor>,
input1 : Value<_: Tosa_Tensor>)
with op<tosa.minimum>(input0, input1);
Pattern =>
replace op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>,
input1 : Value<_: Tosa_Tensor>)
with op<tosa.mul>(input0, input1, zeroConst());
Pattern {
let root = op<stablehlo.multiply>(input0 : Value<inputType: Tosa_Tensor>,
input1 : Value<_: Tosa_Tensor>);
rewrite root with {
let typei8 = changeElementTypeToI8(inputType);
let zeros = zerosLike(root, typei8);
let mulResult = op<tosa.mul>(input0, input1, zeros) -> (inputType);
replace root with mulResult;
};
}
Pattern =>
replace op<stablehlo.or>(input0 : Value<_: Tosa_Tensor>,
input1 : Value<_: Tosa_Tensor>)
Expand Down

0 comments on commit 7775e3e

Please sign in to comment.