-
Notifications
You must be signed in to change notification settings - Fork 611
[TorchToLinalg]Lower torch.gcd to linalg and scf #3732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,8 @@ | |
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" | ||
#include "torch-mlir/Conversion/Utils/Utils.h" | ||
|
@@ -213,6 +215,83 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> { | |
}; | ||
} // namespace | ||
|
||
namespace { | ||
class ConvertAtenGcdOp : public OpConversionPattern<torch::Torch::AtenGcdOp> { | ||
public: | ||
using OpConversionPattern::OpConversionPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(torch::Torch::AtenGcdOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
auto self = adaptor.getSelf(); // tensor A | ||
auto other = adaptor.getOther(); // tensor B of the same size | ||
auto loc = op.getLoc(); | ||
|
||
TensorType resultType = | ||
cast<TensorType>(getTypeConverter()->convertType(op.getType())); | ||
|
||
auto gcdPayloadBody = [&](OpBuilder &b, Location loc, | ||
ValueRange genericInstructionArgs) { | ||
auto A = genericInstructionArgs[0]; | ||
A = b.create<mlir::math::AbsIOp>(loc, A); | ||
auto B = genericInstructionArgs[1]; | ||
B = b.create<mlir::math::AbsIOp>(loc, B); | ||
auto zero = b.create<mlir::arith::ConstantIntOp>(loc, 0, A.getType()); | ||
|
||
Value AtrailingZerosCount = | ||
b.create<mlir::math::CountTrailingZerosOp>(loc, A); | ||
Value BtrailingZerosCount = | ||
b.create<mlir::math::CountTrailingZerosOp>(loc, B); | ||
auto smalerZerosCount = b.create<mlir::arith::MinSIOp>( | ||
loc, AtrailingZerosCount, BtrailingZerosCount); | ||
auto shiftedA = b.create<mlir::arith::ShRSIOp>(loc, A, smalerZerosCount); | ||
auto shiftedB = b.create<mlir::arith::ShRSIOp>(loc, B, smalerZerosCount); | ||
|
||
auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc, | ||
mlir::ValueRange innerLoopArgs) { | ||
Value min = b.create<mlir::arith::MinSIOp>(loc, innerLoopArgs[0], | ||
innerLoopArgs[1]); | ||
Value max = b.create<mlir::arith::MaxSIOp>(loc, innerLoopArgs[0], | ||
innerLoopArgs[1]); | ||
|
||
auto cmp = b.create<mlir::arith::CmpIOp>( | ||
loc, mlir::arith::CmpIPredicate::ne, min, zero); | ||
b.create<mlir::scf::ConditionOp>(loc, cmp, ValueRange{min, max}); | ||
}; | ||
auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc, | ||
mlir::ValueRange innerLoopArgs) { | ||
Value min = innerLoopArgs[0]; | ||
Value max = innerLoopArgs[1]; | ||
max = b.create<mlir::arith::SubIOp>(loc, max, min); | ||
|
||
Value maxTrailingZerosCount = | ||
b.create<mlir::math::CountTrailingZerosOp>(loc, max); | ||
max = b.create<mlir::arith::ShRSIOp>(loc, max, maxTrailingZerosCount); | ||
b.create<mlir::scf::YieldOp>(loc, ValueRange{min, max}); | ||
}; | ||
|
||
auto findGcdWhileOp = b.create<mlir::scf::WhileOp>( | ||
loc, TypeRange{shiftedA.getType(), shiftedB.getType()}, | ||
ValueRange{shiftedA, shiftedB}, findGcdConditionBlock, | ||
findGcdBodyBlock); | ||
|
||
Value gcdResult = findGcdWhileOp.getResult(1); | ||
gcdResult = | ||
b.create<mlir::arith::ShLIOp>(loc, gcdResult, smalerZerosCount); | ||
|
||
b.create<linalg::YieldOp>(loc, gcdResult); | ||
}; | ||
Comment on lines
+250
to
+283
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't need |
||
|
||
other = torch_to_linalg::createElementwiseLinalgGeneric( | ||
rewriter, loc, ValueRange{self, other}, | ||
cast<TensorType>(self.getType()).getElementType(), gcdPayloadBody); | ||
|
||
Comment on lines
+233
to
+288
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing code comments for all the major blocks and a high-level description on the top for the lowering. |
||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, other); | ||
return success(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
namespace { | ||
class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> { | ||
public: | ||
|
@@ -1387,4 +1466,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( | |
patterns.add<ConvertAtenBmmOp>(typeConverter, context); | ||
target.addIllegalOp<AtenConvolutionOp>(); | ||
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context); | ||
target.addIllegalOp<AtenGcdOp>(); | ||
patterns.add<ConvertAtenGcdOp>(typeConverter, context); | ||
} |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -6639,6 +6639,72 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { | |||
" }\n" | ||||
" return %8 : !torch.tuple<list<int>, list<int>>\n" | ||||
" }\n" | ||||
" func.func @\"__torch_mlir_shape_fn.aten.gcd\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" | ||||
" %none = torch.constant.none\n" | ||||
" %str = torch.constant.str \"AssertionError: Shapes must be the same or 'other' must be a single element tensor.\"\n" | ||||
" %false = torch.constant.bool false\n" | ||||
" %true = torch.constant.bool true\n" | ||||
" %int1 = torch.constant.int 1\n" | ||||
" %int0 = torch.constant.int 0\n" | ||||
" %0 = torch.aten.eq.int_list %arg0, %arg1 : !torch.list<int>, !torch.list<int> -> !torch.bool\n" | ||||
" %1 = torch.prim.If %0 -> (!torch.bool) {\n" | ||||
" torch.prim.If.yield %true : !torch.bool\n" | ||||
" } else {\n" | ||||
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n" | ||||
" %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" | ||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n" | ||||
" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n" | ||||
" %6 = torch.aten.eq.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n" | ||||
" torch.prim.If.yield %6 : !torch.bool\n" | ||||
" } else {\n" | ||||
" torch.prim.If.yield %false : !torch.bool\n" | ||||
" }\n" | ||||
" torch.prim.If.yield %4 : !torch.bool\n" | ||||
" }\n" | ||||
" torch.prim.If %1 -> () {\n" | ||||
" torch.prim.If.yield\n" | ||||
" } else {\n" | ||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" | ||||
" torch.prim.If.yield\n" | ||||
" }\n" | ||||
" return %arg0 : !torch.list<int>\n" | ||||
" }\n" | ||||
" func.func @\"__torch_mlir_dtype_fn.aten.gcd\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" | ||||
" %none = torch.constant.none\n" | ||||
" %str = torch.constant.str \"AssertionError: aten.gcd works only with integer types\"\n" | ||||
" %false = torch.constant.bool false\n" | ||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" | ||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" | ||||
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" | ||||
" %3 = torch.prim.If %2 -> (!torch.bool) {\n" | ||||
" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" | ||||
" torch.prim.If.yield %4 : !torch.bool\n" | ||||
" } else {\n" | ||||
" torch.prim.If.yield %false : !torch.bool\n" | ||||
" }\n" | ||||
" torch.prim.If %3 -> () {\n" | ||||
" torch.prim.If.yield\n" | ||||
" } else {\n" | ||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" | ||||
" torch.prim.If.yield\n" | ||||
" }\n" | ||||
" return %0#1 : !torch.int\n" | ||||
" }\n" | ||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this part moved? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The file is auto-generated so I think it's a quirk of torch-mlir/.github/workflows/ci.yml Line 77 in 99115dc
|
||||
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n" | ||||
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n" | ||||
" return %1 : !torch.bool\n" | ||||
" }\n" | ||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n" | ||||
" %int4 = torch.constant.int 4\n" | ||||
" %int3 = torch.constant.int 3\n" | ||||
" %int2 = torch.constant.int 2\n" | ||||
" %int1 = torch.constant.int 1\n" | ||||
" %int0 = torch.constant.int 0\n" | ||||
" %int11 = torch.constant.int 11\n" | ||||
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n" | ||||
" return %0 : !torch.list<int>\n" | ||||
" }\n" | ||||
" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" | ||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" | ||||
" return %0 : !torch.list<int>\n" | ||||
|
@@ -11253,21 +11319,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { | |||
" %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" | ||||
" return %3 : !torch.int\n" | ||||
" }\n" | ||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" | ||||
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n" | ||||
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n" | ||||
" return %1 : !torch.bool\n" | ||||
" }\n" | ||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n" | ||||
" %int4 = torch.constant.int 4\n" | ||||
" %int3 = torch.constant.int 3\n" | ||||
" %int2 = torch.constant.int 2\n" | ||||
" %int1 = torch.constant.int 1\n" | ||||
" %int0 = torch.constant.int 0\n" | ||||
" %int11 = torch.constant.int 11\n" | ||||
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n" | ||||
" return %0 : !torch.list<int>\n" | ||||
" }\n" | ||||
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" | ||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" | ||||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto