
Commit 48005db

[Torch] Add support for aten.round.decimals op
* Added lowering to Linalg-on-Tensors
* Added test to projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
1 parent 60379d7 commit 48005db

8 files changed: +147 −7 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

+48
@@ -4562,6 +4562,54 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [
   }];
 }
 
+def Torch_AtenRoundDecimalsOp : Torch_Op<"aten.round.decimals", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::round.decimals : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$decimals
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRoundDecimalsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRoundDecimalsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
+def Torch_AtenRound_DecimalsOp : Torch_Op<"aten.round_.decimals", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::round_.decimals : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_IntType:$decimals
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRound_DecimalsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRound_DecimalsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [
     AllowsTypeRefinement,
     HasValueSemantics,
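
Editor's note: the two generated definitions mirror PyTorch's out-of-place and in-place overloads. A minimal eager-mode sketch of the behavior they model (assuming a PyTorch build where torch.round accepts the decimals keyword, as current releases do):

import torch

x = torch.tensor([1.2345, -0.6789, 2.5001])

# Out-of-place overload, modeled by torch.aten.round.decimals.
y = torch.round(x, decimals=2)   # ≈ tensor([ 1.2300, -0.6800,  2.5000])

# In-place variant, modeled by torch.aten.round_.decimals.
x.round_(decimals=2)
assert torch.equal(x, y)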

lib/Conversion/TorchToLinalg/Uncategorized.cpp

+45 −7
@@ -691,6 +691,43 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
     return b.create<math::RoundEvenOp>(loc, payloadArgs[0]);
   }
+  if (auto round = dyn_cast<AtenRoundDecimalsOp>(op)) {
+    // When decimals is non-zero, AtenRoundDecimalsOp is lowered as follows:
+    //   scale = 10 ** decimals
+    //   return round(x * scale) / scale
+    if (!isa<mlir::FloatType>(
+            cast<ValueTensorType>(round.getType()).getDtype())) {
+      round.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    int64_t decimals;
+    if (!matchPattern(op->getOperand(1), m_TorchConstantInt(&decimals))) {
+      round.emitError("non-constant decimals value is not supported");
+      return nullptr;
+    }
+
+    Value newOp = payloadArgs[0];
+    Value scale;
+    if (decimals) {
+      auto elementType =
+          cast<RankedTensorType>(
+              converter->convertType(op->getOperand(0).getType()))
+              .getElementType();
+
+      auto scaleVal = static_cast<float>(pow(10, decimals));
+      scale = b.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, scaleVal));
+      newOp = b.create<arith::MulFOp>(loc, newOp, scale);
+    }
+
+    newOp = b.create<math::RoundEvenOp>(loc, newOp);
+
+    if (decimals) {
+      newOp = b.create<arith::DivFOp>(loc, newOp, scale);
+    }
+
+    return newOp;
+  }
   if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
     if (!isa<mlir::FloatType>(
             cast<ValueTensorType>(prelu.getType()).getDtype())) {
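
Editor's note: the scale/round/unscale sequence emitted above matches eager PyTorch's own rounding, since both math::RoundEvenOp and torch.round round halves to even. A quick sanity-check sketch in Python (illustrative only, not part of the commit):

import torch

x = torch.tensor([1.23456, -2.34567, 0.125], dtype=torch.float32)
decimals = 2
scale = 10.0 ** decimals

expected = torch.round(x, decimals=decimals)
# Same computation as the Linalg payload: scale, round half-to-even, unscale.
emulated = torch.round(x * scale) / scale

assert torch.allclose(expected, emulated)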
@@ -1635,10 +1672,11 @@ class ConvertElementwiseOp : public ConversionPattern {
           AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
           AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
           AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
-          AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
-          AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
-          AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
-          AtenQuantizePerTensorOp, AtenIscloseOp>(op))
+          AtenRoundDecimalsOp, AtenFillScalarOp, AtenFillTensorOp,
+          AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp, AtenAsinOp,
+          AtenAsinhOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
+          AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>(
+          op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
   if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3988,9 +4026,9 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
       AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
       AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
-      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
-      AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
-      AtenQuantizePerTensorOp, AtenIscloseOp>();
+      AtenBitwiseNotOp, AtenRoundOp, AtenRoundDecimalsOp, AtenFillScalarOp,
+      AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
+      AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);

lib/Dialect/Torch/IR/TorchOps.cpp

+13
@@ -1992,6 +1992,19 @@ OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// AtenRoundDecimalsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenRoundDecimalsOp::fold(FoldAdaptor adaptor) {
+  auto resultType = dyn_cast<ValueTensorType>(getType());
+  if (resultType && resultType.hasDtype() &&
+      isa<mlir::IntegerType>(resultType.getDtype())) {
+    return getSelf();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // AtenRoundOp
 //===----------------------------------------------------------------------===//
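
Editor's note: the folder returns the input whenever the result dtype is integral, since rounding an already-integer tensor cannot change any element. The same identity in eager PyTorch (a sketch using plain torch.round; the fold itself happens at the torch-mlir IR level):

import torch

x = torch.randint(-10, 10, (4,))
# Rounding integral values is a no-op, which is what the fold exploits.
assert torch.equal(torch.round(x), x)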

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

+8
@@ -6750,6 +6750,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.round.decimals\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.glu\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 " %none = torch.constant.none\n"
 " %str = torch.constant.str \"AssertionError: glu's dim size must be multiply of 2\"\n"
@@ -12896,6 +12900,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.round.decimals\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" return %0#1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.glu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"

projects/pt1/e2e_testing/xfail_sets.py

+2
@@ -547,6 +547,7 @@
     "AtenKthvalueModule_basic",
     "AtenPolarDoubleModule_basic",
     "AtenPolarFloatModule_basic",
+    "AtenRoundFloatDecimalsModule_basic",
     "DiagonalWithStaticShapeModule_basic",
     "EinsumStaticDiagonalDimensionModule_basic",
     "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
@@ -3408,6 +3409,7 @@
     "AtenSymConstrainRange_basic",
     "Aten_AssertScalar_basic",
     "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
+    "AtenRoundFloatDecimalsModule_basic",
     "ScatterAddDynamicModule_basic",
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

+8
@@ -349,6 +349,9 @@ def aten〇relu6〡shape(self: List[int]) -> List[int]:
 def aten〇round〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇round〇decimals〡shape(self: List[int], decimals: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
     if dim < 0:
         dim += len(self)
@@ -3616,6 +3619,11 @@ def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, decimals=0))
+def aten〇round〇decimals〡dtype(self_rank_dtype: Tuple[int, int], decimals: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(100,)], dim=0))
 def aten〇glu〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1) -> int:
     self_rank, self_dtype = self_rank_dtype
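
Editor's note: both helpers encode that aten.round.decimals is a unary elementwise op, so the result keeps the shape and dtype of its input. A quick eager-mode check (sketch, not part of the commit):

import torch

x = torch.rand(3, 4, dtype=torch.float32)
y = torch.round(x, decimals=1)
# Shape and dtype are preserved, matching the shape/dtype functions above.
assert y.shape == x.shape and y.dtype == x.dtype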

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

+3
@@ -451,6 +451,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
+    emit_with_mutating_variants(
+        "aten::round.decimals : (Tensor, int) -> (Tensor)", has_folder=True
+    )
     emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::special_expm1 : (Tensor) -> (Tensor)")
     emit_with_mutating_variants(

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

+20
@@ -6503,6 +6503,26 @@ def AtenRoundIntModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(5, 5, low=-10))
 
 
+class AtenRoundFloatDecimalsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.round(x, decimals=2)
+
+
+@register_test_case(module_factory=lambda: AtenRoundFloatDecimalsModule())
+def AtenRoundFloatDecimalsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5, low=-3.0, high=3.0))
+
+
 # ==============================================================================
 
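Editor's note: the e2e comparison against eager PyTorch relies on both sides rounding ties to even (math::RoundEvenOp in the lowering, torch.round in eager mode). A small sketch of that behavior for the overload the new test exercises (printed values are approximate float32 output):

import torch

x = torch.tensor([[0.125, 0.375, -2.71828]])
# Ties go to even: 0.125 -> 0.12 and 0.375 -> 0.38.
print(torch.ops.aten.round(x, decimals=2))  # ≈ tensor([[ 0.1200,  0.3800, -2.7200]])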
