[Torch] Add support for aten.round.decimals op #4166


Open · mmanzoorTT wants to merge 1 commit into main from tenstorrent/torch-round-decimals

Conversation

@mmanzoorTT (Contributor) commented Apr 28, 2025

  • Added a decomposition of aten.round.decimals to aten.round (illustrated by the sketch below)
  • Added a test to projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
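For reference, here is a minimal standalone C++ sketch of the identity the decomposition relies on (plain C++ for illustration only, not part of the PR; the values and names are illustrative):

#include <cmath>
#include <cstdio>

int main() {
  // round(x, decimals) == round_ties_to_even(x * 10^decimals) / 10^decimals
  double x = 0.125;
  int decimals = 2;
  double scale = std::pow(10.0, decimals); // 100.0
  // std::nearbyint rounds ties to even under the default FE_TONEAREST mode,
  // matching math::RoundEvenOp (and torch.round) semantics.
  double rounded = std::nearbyint(x * scale) / scale;
  std::printf("%g\n", rounded); // 12.5 ties to even -> 12 -> prints 0.12
  return 0;
}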

@mmanzoorTT (Contributor, Author) commented:

@vivekkhandelwal1

@vivekkhandelwal1 changed the title from [Torch] Add support for aten.reduce.decimals op to [Torch] Add support for aten.round.decimals op on Apr 30, 2025
@vivekkhandelwal1 (Collaborator) left a comment

I have some comments:

1.) The PR title was incorrect: it said [TORCH] Add support for aten.reduce.decimals op, while the op added is aten.round.decimals. Also, it says that support is added, but only a folder is added.
2.) The shape and dtype functions have not been added. Please do that.
3.) Add the lowering of the op to Linalg.
4.) Add an e2e test for the op.

@mmanzoorTT force-pushed the tenstorrent/torch-round-decimals branch 4 times, most recently from bc8dad7 to 48005db on May 14, 2025 02:32
@mmanzoorTT (Contributor, Author) commented:

@vivekkhandelwal1 Thanks a lot for the feedback. I have updated the PR as suggested. Please have a look.

Comment on lines 694 to 730
if (auto round = dyn_cast<AtenRoundDecimalsOp>(op)) {
  // AtenRoundDecimalsOp is decomposed, when decimals is non-zero, as follows:
  //   scale = 10 ** decimals
  //   return round(x * scale) / scale
  if (!isa<mlir::FloatType>(
          cast<ValueTensorType>(round.getType()).getDtype())) {
    round.emitError("unimplemented: non-floating point dtype");
    return nullptr;
  }
  int64_t decimals;
  if (!matchPattern(op->getOperand(1), m_TorchConstantInt(&decimals))) {
    round.emitError("non-constant decimal point is not supported.");
    return nullptr;
  }

  Value newOp = payloadArgs[0];
  Value scale;
  if (decimals) {
    auto elementType =
        cast<RankedTensorType>(
            converter->convertType(op->getOperand(0).getType()))
            .getElementType();

    auto scaleVal = static_cast<float>(pow(10, decimals));
    scale = b.create<arith::ConstantOp>(
        loc, FloatAttr::get(elementType, scaleVal));
    newOp = b.create<arith::MulFOp>(loc, newOp, scale);
  }

  // Round half-to-even, matching torch.round semantics.
  newOp = b.create<math::RoundEvenOp>(loc, newOp);

  if (decimals) {
    newOp = b.create<arith::DivFOp>(loc, newOp, scale);
  }

  return newOp;
}
@vivekkhandelwal1 (Collaborator) commented:

It's better to add it to DecomposeComplexOps.cpp.
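
For illustration, a minimal sketch of what such a pattern in DecomposeComplexOps.cpp could look like, assuming torch-mlir's usual rewrite-pattern conventions (this is not the PR's actual code; builder and accessor names follow the upstream dialect, and <cmath> is assumed to be included):

namespace {
// Decompose aten.round.decimals into aten.mul.Scalar -> aten.round ->
// aten.div.Scalar when `decimals` is a non-zero constant.
class DecomposeAtenRoundDecimalsOp
    : public OpRewritePattern<AtenRoundDecimalsOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenRoundDecimalsOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    int64_t decimals;
    if (!matchPattern(op.getDecimals(), m_TorchConstantInt(&decimals)))
      return rewriter.notifyMatchFailure(
          op, "non-constant decimals is not supported");

    Value result = op.getSelf();
    Value scale;
    if (decimals) {
      // scale = 10 ** decimals
      scale = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getF64FloatAttr(std::pow(10.0, decimals)));
      result =
          rewriter.create<AtenMulScalarOp>(loc, op.getType(), result, scale);
    }
    result = rewriter.create<AtenRoundOp>(loc, op.getType(), result);
    if (decimals)
      result =
          rewriter.create<AtenDivScalarOp>(loc, op.getType(), result, scale);

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace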

@mmanzoorTT (Contributor, Author) commented:

Thanks for the suggestion, @vivekkhandelwal1. Updated.

@mmanzoorTT force-pushed the tenstorrent/torch-round-decimals branch 2 times, most recently from 1684186 to f460ca2 on May 16, 2025 15:16
* Added decomposition to aten.round
* Added test to projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@mmanzoorTT force-pushed the tenstorrent/torch-round-decimals branch from f460ca2 to fbc9352 on May 16, 2025 15:18