Skip to content

Commit

Permalink
use zeroshiftconst
Browse files Browse the repository at this point in the history
  • Loading branch information
abhigunj committed Feb 1, 2025
1 parent 87be1a1 commit 4dff7a0
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,20 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.td"
#include "stablehlo/dialect/StablehloOps.td"

Rewrite zeroShiftConst() -> Op [{
auto type = rewriter.getI8Type();
auto attr = mlir::DenseElementsAttr::get(
llvm::cast<mlir::ShapedType>(type), rewriter.getZeroAttr(type));
return rewriter.create<mlir::tosa::ConstOp>(
rewriter.getUnknownLoc(), type, attr);
}];

// Helper functions.
Rewrite changeElementTypeToI1(type: Type) -> Type [{
auto tensorType = llvm::cast<mlir::RankedTensorType>(type);
return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type());
}];

Rewrite changeElementTypeToI8(type: Type) -> Type [{
auto tensorType = llvm::cast<mlir::RankedTensorType>(type);
return RankedTensorType::get(tensorType.getShape(), rewriter.getI8Type());
}];

Rewrite zerosLike(op: Op, type: Type) -> Op [{
auto elementType = llvm::cast<mlir::TensorType>(type).getElementType();
llvm::SmallVector<mlir::Attribute, 4> outputValue;

if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) {
outputValue.push_back(rewriter.getFloatAttr(elementType, 0));
} else {
outputValue.push_back(rewriter.getIntegerAttr(elementType, 0));
}

return rewriter.create<mlir::tosa::ConstOp>(
op->getLoc(), type,
mlir::DenseElementsAttr::get(
llvm::cast<mlir::ShapedType>(type), outputValue));
}];

Rewrite onesLike(op: Op, type: Type) -> Op [{
auto elementType = llvm::cast<mlir::TensorType>(type).getElementType();
llvm::SmallVector<mlir::Attribute, 4> outputValue;
Expand Down Expand Up @@ -155,16 +142,10 @@ Pattern =>
replace op<stablehlo.minimum>(input0 : Value<_: Tosa_Tensor>,
input1 : Value<_: Tosa_Tensor>)
with op<tosa.minimum>(input0, input1);
Pattern {
let root = op<stablehlo.multiply>(input0 : Value<inputType: Tosa_Tensor>,
input1 : Value<_: Tosa_Tensor>);
rewrite root with {
let typei8 = changeElementTypeToI8(inputType);
let zeros = zerosLike(root, typei8);
let mulResult = op<tosa.mul>(input0, input1, zeros) -> (inputType);
replace root with mulResult;
};
}
Pattern =>
replace op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>,
input1 : Value<_: Tosa_Tensor>)
with op<tosa.mul>(input0, input1, zeroShiftConst());
Pattern =>
replace op<stablehlo.or>(input0 : Value<_: Tosa_Tensor>,
input1 : Value<_: Tosa_Tensor>)
Expand Down

0 comments on commit 4dff7a0

Please sign in to comment.