[Torch] Add support for aten.round.decimals op #4166
Conversation
mmanzoorTT commented on Apr 28, 2025 (edited):
- Added decomposition to aten.round
- Added test to projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
I have some comments:
1.) The PR title was incorrect. It said [TORCH] Add support for aten.reduce.decimals op, while the op added is aten.round.decimals. Also, the title says support is added, but only a folder is added.
2.) The shape and dtype functions have not been added. Please do that.
3.) Add a lowering for the op to Linalg.
4.) Add an e2e test for the op.
Force-pushed from bc8dad7 to 48005db
@vivekkhandelwal1 Thanks a lot for the feedback. I have updated the PR as suggested. Please have a look.
```cpp
if (auto round = dyn_cast<AtenRoundDecimalsOp>(op)) {
  // AtenRoundDecimalsOp is decomposed, if decimals is non-zero, as follows:
  //   scale = 10 ** decimals
  //   return round(x * scale) / scale
  if (!isa<mlir::FloatType>(
          cast<ValueTensorType>(round.getType()).getDtype())) {
    round.emitError("unimplemented: non-floating point dtype");
    return nullptr;
  }
  int64_t decimals;
  if (!matchPattern(op->getOperand(1), m_TorchConstantInt(&decimals))) {
    round.emitError("non-constant decimal point is not supported.");
    return nullptr;
  }

  Value newOp = payloadArgs[0];
  Value scale;
  if (decimals) {
    auto elementType =
        cast<RankedTensorType>(
            converter->convertType(op->getOperand(0).getType()))
            .getElementType();

    auto scalaVal = static_cast<float>(pow(10, decimals));
    scale = b.create<arith::ConstantOp>(
        loc, FloatAttr::get(elementType, scalaVal));
    newOp = b.create<arith::MulFOp>(loc, newOp, scale);
  }

  newOp = b.create<math::RoundEvenOp>(loc, newOp);

  if (decimals) {
    newOp = b.create<arith::DivFOp>(loc, newOp, scale);
  }

  return newOp;
}
```
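The scale/round/unscale decomposition used in this lowering can be illustrated in plain Python. This is a hypothetical sketch, not a torch-mlir API: `round_decimals` is an illustrative helper, and Python's built-in `round()` rounds ties to even, matching the behavior of `math::RoundEvenOp`.

```python
def round_decimals(x: float, decimals: int) -> float:
    """Illustrative helper (not a torch-mlir API): round x to the given
    number of decimal places via scale -> round-half-even -> unscale,
    mirroring the decomposition in the lowering above."""
    if decimals == 0:
        # Python's round() rounds ties to even, like math::RoundEvenOp.
        return float(round(x))
    scale = 10.0 ** decimals
    return round(x * scale) / scale

print(round_decimals(3.14159, 2))  # 3.14
print(round_decimals(2.5, 0))      # 2.0 (tie rounds to even)
print(round_decimals(3.5, 0))      # 4.0 (tie rounds to even)
```

Note that the lowering only emits the multiply and divide when `decimals` is non-zero, since in that case the op reduces to a plain round-to-even.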
It's better to add it to DecomposeComplexOps.cpp.
Thanks for the suggestion, @vivekkhandelwal1. Updated.
Force-pushed from 1684186 to f460ca2
- Added decomposition to aten.round
- Added test to projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Force-pushed from f460ca2 to fbc9352