@@ -691,6 +691,43 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
691
691
}
692
692
return b.create <math::RoundEvenOp>(loc, payloadArgs[0 ]);
693
693
}
694
+ if (auto round = dyn_cast<AtenRoundDecimalsOp>(op)) {
695
+ // AtenRoundDecimalsOp is decomposed, if decimals is non-zero, as follow.
696
+ // scale = 10 ** decimals
697
+ // return round(x * scale) / scale
698
+ if (!isa<mlir::FloatType>(
699
+ cast<ValueTensorType>(round .getType ()).getDtype ())) {
700
+ round .emitError (" unimplemented: non-floating point dtype" );
701
+ return nullptr ;
702
+ }
703
+ int64_t decimals;
704
+ if (!matchPattern (op->getOperand (1 ), m_TorchConstantInt (&decimals))) {
705
+ round .emitError (" non-constant decimal point is not supported." );
706
+ return nullptr ;
707
+ }
708
+
709
+ Value newOp = payloadArgs[0 ];
710
+ Value scale;
711
+ if (decimals) {
712
+ auto elementType =
713
+ cast<RankedTensorType>(
714
+ converter->convertType (op->getOperand (0 ).getType ()))
715
+ .getElementType ();
716
+
717
+ auto scalaVal = static_cast <float >(pow (10 , decimals));
718
+ scale = b.create <arith::ConstantOp>(
719
+ loc, FloatAttr::get (elementType, scalaVal));
720
+ newOp = b.create <arith::MulFOp>(loc, newOp, scale);
721
+ }
722
+
723
+ newOp = b.create <math::RoundEvenOp>(loc, newOp);
724
+
725
+ if (decimals) {
726
+ newOp = b.create <arith::DivFOp>(loc, newOp, scale);
727
+ }
728
+
729
+ return newOp;
730
+ }
694
731
if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
695
732
if (!isa<mlir::FloatType>(
696
733
cast<ValueTensorType>(prelu.getType ()).getDtype ())) {
@@ -1635,10 +1672,11 @@ class ConvertElementwiseOp : public ConversionPattern {
1635
1672
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
1636
1673
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
1637
1674
AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
1638
- AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
1639
- AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
1640
- AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
1641
- AtenQuantizePerTensorOp, AtenIscloseOp>(op))
1675
+ AtenRoundDecimalsOp, AtenFillScalarOp, AtenFillTensorOp,
1676
+ AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp, AtenAsinOp,
1677
+ AtenAsinhOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
1678
+ AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>(
1679
+ op))
1642
1680
return rewriter.notifyMatchFailure (op, " not a supported elementwise op" );
1643
1681
1644
1682
if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
@@ -3988,9 +4026,9 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
3988
4026
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
3989
4027
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
3990
4028
AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
3991
- AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp ,
3992
- AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp ,
3993
- AtenQuantizePerTensorOp, AtenIscloseOp>();
4029
+ AtenBitwiseNotOp, AtenRoundOp, AtenRoundDecimalsOp, AtenFillScalarOp ,
4030
+ AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
4031
+ AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
3994
4032
patterns.add <ConvertElementwiseOp>(typeConverter, context);
3995
4033
target.addIllegalOp <AtenNllLossForwardOp>();
3996
4034
patterns.add <ConvertAtenDetachOp>(typeConverter, context);
0 commit comments