[TorchToLinalg]Lower torch.gcd to linalg and scf #3732

Status: Open · wants to merge 2 commits into base: main
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13228,6 +13228,31 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [
}];
}

def Torch_AtenGcdOp : Torch_Op<"aten.gcd", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::gcd : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGcdOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenGcdOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}

def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement,
HasValueSemantics,
81 changes: 81 additions & 0 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -13,6 +13,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@@ -213,6 +215,83 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
};
} // namespace

namespace {
class ConvertAtenGcdOp : public OpConversionPattern<torch::Torch::AtenGcdOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(torch::Torch::AtenGcdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto self = adaptor.getSelf(); // tensor A
auto other = adaptor.getOther(); // tensor B of the same size
auto loc = op.getLoc();

TensorType resultType =
cast<TensorType>(getTypeConverter()->convertType(op.getType()));

Comment on this line: use `auto`.
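// High-level description of the lowering: build a linalg.generic over
// (self, other) whose scalar payload computes gcd(|a|, |b|) using the binary
// (Stein) algorithm. The payload factors out the common power of two, runs an
// scf.while that repeatedly subtracts the smaller value from the larger and
// strips the trailing zeros of the difference, and finally shifts the common
// power of two back in.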

auto gcdPayloadBody = [&](OpBuilder &b, Location loc,
ValueRange genericInstructionArgs) {
auto A = genericInstructionArgs[0];
A = b.create<mlir::math::AbsIOp>(loc, A);
auto B = genericInstructionArgs[1];
B = b.create<mlir::math::AbsIOp>(loc, B);
auto zero = b.create<mlir::arith::ConstantIntOp>(loc, 0, A.getType());

Value AtrailingZerosCount =
b.create<mlir::math::CountTrailingZerosOp>(loc, A);
Value BtrailingZerosCount =
b.create<mlir::math::CountTrailingZerosOp>(loc, B);
auto smallerZerosCount = b.create<mlir::arith::MinSIOp>(
loc, AtrailingZerosCount, BtrailingZerosCount);
auto shiftedA = b.create<mlir::arith::ShRSIOp>(loc, A, smallerZerosCount);
auto shiftedB = b.create<mlir::arith::ShRSIOp>(loc, B, smallerZerosCount);

auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
mlir::ValueRange innerLoopArgs) {
Value min = b.create<mlir::arith::MinSIOp>(loc, innerLoopArgs[0],
innerLoopArgs[1]);
Value max = b.create<mlir::arith::MaxSIOp>(loc, innerLoopArgs[0],
innerLoopArgs[1]);

auto cmp = b.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ne, min, zero);
b.create<mlir::scf::ConditionOp>(loc, cmp, ValueRange{min, max});
};
auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
mlir::ValueRange innerLoopArgs) {
Value min = innerLoopArgs[0];
Value max = innerLoopArgs[1];
max = b.create<mlir::arith::SubIOp>(loc, max, min);

Value maxTrailingZerosCount =
b.create<mlir::math::CountTrailingZerosOp>(loc, max);
max = b.create<mlir::arith::ShRSIOp>(loc, max, maxTrailingZerosCount);
b.create<mlir::scf::YieldOp>(loc, ValueRange{min, max});
};

auto findGcdWhileOp = b.create<mlir::scf::WhileOp>(
loc, TypeRange{shiftedA.getType(), shiftedB.getType()},
ValueRange{shiftedA, shiftedB}, findGcdConditionBlock,
findGcdBodyBlock);

Value gcdResult = findGcdWhileOp.getResult(1);
gcdResult =
b.create<mlir::arith::ShLIOp>(loc, gcdResult, smallerZerosCount);

b.create<linalg::YieldOp>(loc, gcdResult);
};
Comment on lines +250 to +283: You don't need `mlir::` anywhere here since there is a `using namespace` for it at the top.


other = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, ValueRange{self, other},
cast<TensorType>(self.getType()).getElementType(), gcdPayloadBody);

Comment on lines +233 to +288: Missing code comments for all the major blocks and a high-level description at the top for the lowering.

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, other);
return success();
}
};
} // namespace

namespace {
class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
public:
@@ -1387,4 +1466,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenConvolutionOp>();
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
target.addIllegalOp<AtenGcdOp>();
patterns.add<ConvertAtenGcdOp>(typeConverter, context);
}
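For reference, the scalar payload in ConvertAtenGcdOp above follows the binary (Stein) GCD algorithm: take absolute values, factor out the common power of two, then repeatedly subtract the smaller operand from the larger and strip the trailing zeros of the difference until the smaller operand reaches zero. A minimal Python sketch of that recurrence (illustrative only, not part of this patch; cttz stands in for the trailing-zero-count op):

def binary_gcd(a: int, b: int) -> int:
    # Trailing-zero count; defined as 0 for x == 0 to keep the sketch total.
    def cttz(x: int) -> int:
        return (x & -x).bit_length() - 1 if x else 0

    a, b = abs(a), abs(b)              # absolute values of both operands
    shift = min(cttz(a), cttz(b))      # common power of two
    a, b = a >> shift, b >> shift
    lo, hi = min(a, b), max(a, b)
    while lo != 0:                     # scf.while condition: min != 0
        hi -= lo                       # subtract the smaller operand
        hi >>= cttz(hi)                # strip trailing zeros of the difference
        lo, hi = min(lo, hi), max(lo, hi)
    return hi << shift                 # restore the common power of two

assert binary_gcd(12, 18) == 6 and binary_gcd(-4, 6) == 2 and binary_gcd(0, 7) == 7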
34 changes: 34 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5527,3 +5527,37 @@ LogicalResult AtenRot90Op::verify() {

return success();
}

LogicalResult AtenGcdOp::verify() {

auto selfType = cast<BaseTensorType>(getSelf().getType());
auto otherType = cast<BaseTensorType>(getOther().getType());

if (!selfType.hasDtype() || !selfType.hasSizes() || !otherType.hasDtype() ||
!otherType.hasSizes())
return success();

auto selfShape = selfType.getSizes();
auto otherShape = otherType.getSizes();
int64_t selfRank = selfShape.size();
int64_t otherRank = otherShape.size();
auto selfDtype = selfType.getDtype();

if (!isa<mlir::IntegerType>(selfDtype))
return emitOpError("expected an integer type for input tensor, but got ")
<< selfDtype;

if (otherRank == 1 && otherShape[0] == 1)
return success();

if (selfRank != otherRank)
return emitOpError("Tensors must be of same rank or second tensor must be "
"a single element tensor");

for (int i = 0; i < selfRank; i++) {
if (selfShape[i] != otherShape[i])
return emitOpError("Dimensions od tensors font match in dim ") << i;
}

return success();
}
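The two cases the verifier accepts (matching shapes, or a single-element other) mirror valid eager PyTorch usage, where torch.gcd also applies a one-element tensor against every element. A small illustration with hand-picked values:

import torch

a = torch.tensor([[12, 18], [27, 8]], dtype=torch.int32)

# Same shape: elementwise gcd.
torch.gcd(a, torch.tensor([[8, 12], [9, 20]], dtype=torch.int32))
# -> tensor([[4, 6], [9, 4]], dtype=torch.int32)

# Single-element other: applied against every element of a.
torch.gcd(a, torch.tensor([6], dtype=torch.int32))
# -> tensor([[6, 6], [3, 2]], dtype=torch.int32)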
81 changes: 66 additions & 15 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6639,6 +6639,72 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %8 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.gcd\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: Shapes must be the same or 'other' must be a single element tensor.\"\n"
" %false = torch.constant.bool false\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.eq.int_list %arg0, %arg1 : !torch.list<int>, !torch.list<int> -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.eq.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %6 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.gcd\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: aten.gcd works only with integer types\"\n"
" %false = torch.constant.bool false\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"

Comment on this line: Why is this part moved?

Reply (Contributor): The file is auto-generated, so I think it's a quirk of update_abstract_interp_lib.sh. It's also checked by CI here:

bash build_tools/ci/check_generated_sources.sh

" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" return %1 : !torch.bool\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int11 = torch.constant.int 11\n"
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -11253,21 +11319,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" return %1 : !torch.bool\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int11 = torch.constant.int 11\n"
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
6 changes: 6 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -916,6 +916,9 @@
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"GCDBatchedModule_I32",
"GCDDynamicModule_I32",
"GCDModule_I32",
"Unfold_Module_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
@@ -3166,6 +3169,9 @@
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"UnfoldModule_basic",
"GCDBatchedModule_I32",
"GCDDynamicModule_I32",
"GCDModule_I32",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
@@ -265,6 +265,17 @@ def aten〇linalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]:
shape = upstream_shape_functions.zero_dim_tensor(A)
return shape, shape

def aten〇gcd〡shape(self: List[int], other: List[int]) -> List[int]:
assert self == other or (len(other) == 1 and other[0]==0), "Shapes must be the same or 'other' must be a single element tensor."
return self

def aten〇gcd〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
other_rank, other_dtype = other_rank_dtype
assert is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype), "aten.gcd works only with integer types"
return self_dtype


def aten〇detach〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -971,6 +971,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)"
)
emit("aten::gcd : (Tensor, Tensor) -> (Tensor)", has_verifier=True)

# Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)")
49 changes: 49 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -6872,3 +6872,52 @@ def forward(self):
@register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule())
def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils):
module.forward()


# ==============================================================================


class GCDModule(torch.nn.Module):
@export
@annotate_args([None, [(4, 4), torch.int32, True], [(4, 4), torch.int32, True]])
def forward(self, A, B):
return torch.gcd(A, B)


@register_test_case(module_factory=lambda: GCDModule())
def GCDModule_I32(module, tu: TestUtils):
A = tu.rand(4, 4, low=-100, high=100).to(dtype=torch.int32)
B = tu.rand(4, 4, low=-100, high=100).to(dtype=torch.int32)
module.forward(A, B)


class GCDBatchedModule(torch.nn.Module):
@export
@annotate_args(
[None, [(4, 4, 4), torch.int32, True], [(4, 4, 4), torch.int32, True]]
)
def forward(self, A, B):
return torch.gcd(A, B)


@register_test_case(module_factory=lambda: GCDBatchedModule())
def GCDBatchedModule_I32(module, tu: TestUtils):
A = tu.rand(4, 4, 4, low=-100, high=100).to(dtype=torch.int32)
B = tu.rand(4, 4, 4, low=-100, high=100).to(dtype=torch.int32)
module.forward(A, B)


class GCDDynamicModule(torch.nn.Module):
@export
@annotate_args(
[None, [(-1, -1, -1), torch.int32, True], [(-1, -1, -1), torch.int32, True]]
)
def forward(self, A, B):
return torch.gcd(A, B)


@register_test_case(module_factory=lambda: GCDDynamicModule())
def GCDDynamicModule_I32(module, tu: TestUtils):
A = tu.rand(3, 4, 4, low=-100, high=100).to(dtype=torch.int32)
B = tu.rand(3, 4, 4, low=-100, high=100).to(dtype=torch.int32)
module.forward(A, B)
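As a quick sanity reference for what these cases exercise, torch.gcd is elementwise on integer tensors and handles negatives and zeros via absolute values; the values below are hand-picked for illustration, independent of the random inputs used above:

import torch

a = torch.tensor([-12, 0, 27, 8], dtype=torch.int32)
b = torch.tensor([18, 7, -9, 20], dtype=torch.int32)
print(torch.gcd(a, b))  # tensor([6, 7, 9, 4], dtype=torch.int32)

Individual cases can be run in isolation through the e2e harness, assuming the usual torch-mlir setup, by passing a test filter that matches "GCD".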
2 changes: 0 additions & 2 deletions projects/pt1/tools/e2e_test.sh
@@ -8,6 +8,4 @@ cd "$src_dir"

# Ensure PYTHONPATH is set for export to child processes, even if empty.
export PYTHONPATH=${PYTHONPATH-}
source $project_dir/.env

python -m e2e_testing.main "$@"