-
Notifications
You must be signed in to change notification settings - Fork 510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TorchToLinalg]Lower torch.gcd to linalg and scf #3732
base: main
Are you sure you want to change the base?
[TorchToLinalg]Lower torch.gcd to linalg and scf #3732
Conversation
7673a8f
to
0815cd1
Compare
force push: added math::cttz instead of counting trailing zeros manually |
projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Outdated
Show resolved
Hide resolved
Add verify() method to check if tensors are of integer type. Also check if tensors are of same shape, or if the second tensor is a single element tensor. Add e2e tests. Put them into onnx and stablehlo xfailed sets.
0815cd1
to
53b1ec3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @bratislavSyrmia, I would like you to explore a non-loop path for this lowering, since these kind of lowerings usually causes issues in the downstream pipeline especially the code-generation part.
@vivekkhandelwal1 out of curiosity, do you have an algorithm/solution in mind? The best I could think of was to use |
Hi @IanWood1, I did not spend time on thinking about it that's why I asked @bratislavSyrmia to explore the possibility of any such solution. But if there exists a solution based on |
I have thought about it but I have no idea how I would find the greatest common divisor between two numbers without using loops |
This PR is acceptable and reviewable as is and I don't think it should be blocked because of a downstream's inability to deal with it. The lowering output is perfectly valid IR and it's contributing an otherwise missing lowering. Using loops itself, with a |
auto other = adaptor.getOther(); // tensor B of the same size | ||
auto loc = op.getLoc(); | ||
|
||
TensorType resultType = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto
auto gcdPayloadBody = [&](OpBuilder &b, Location loc, | ||
ValueRange genericInstructionArgs) { | ||
auto A = genericInstructionArgs[0]; | ||
A = b.create<mlir::math::AbsIOp>(loc, A); | ||
auto B = genericInstructionArgs[1]; | ||
B = b.create<mlir::math::AbsIOp>(loc, B); | ||
auto zero = b.create<mlir::arith::ConstantIntOp>(loc, 0, A.getType()); | ||
|
||
Value AtrailingZerosCount = | ||
b.create<mlir::math::CountTrailingZerosOp>(loc, A); | ||
Value BtrailingZerosCount = | ||
b.create<mlir::math::CountTrailingZerosOp>(loc, B); | ||
auto smalerZerosCount = b.create<mlir::arith::MinSIOp>( | ||
loc, AtrailingZerosCount, BtrailingZerosCount); | ||
auto shiftedA = b.create<mlir::arith::ShRSIOp>(loc, A, smalerZerosCount); | ||
auto shiftedB = b.create<mlir::arith::ShRSIOp>(loc, B, smalerZerosCount); | ||
|
||
auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc, | ||
mlir::ValueRange innerLoopArgs) { | ||
Value min = b.create<mlir::arith::MinSIOp>(loc, innerLoopArgs[0], | ||
innerLoopArgs[1]); | ||
Value max = b.create<mlir::arith::MaxSIOp>(loc, innerLoopArgs[0], | ||
innerLoopArgs[1]); | ||
|
||
auto cmp = b.create<mlir::arith::CmpIOp>( | ||
loc, mlir::arith::CmpIPredicate::ne, min, zero); | ||
b.create<mlir::scf::ConditionOp>(loc, cmp, ValueRange{min, max}); | ||
}; | ||
auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc, | ||
mlir::ValueRange innerLoopArgs) { | ||
Value min = innerLoopArgs[0]; | ||
Value max = innerLoopArgs[1]; | ||
max = b.create<mlir::arith::SubIOp>(loc, max, min); | ||
|
||
Value maxTrailingZerosCount = | ||
b.create<mlir::math::CountTrailingZerosOp>(loc, max); | ||
max = b.create<mlir::arith::ShRSIOp>(loc, max, maxTrailingZerosCount); | ||
b.create<mlir::scf::YieldOp>(loc, ValueRange{min, max}); | ||
}; | ||
|
||
auto findGcdWhileOp = b.create<mlir::scf::WhileOp>( | ||
loc, TypeRange{shiftedA.getType(), shiftedB.getType()}, | ||
ValueRange{shiftedA, shiftedB}, findGcdConditionBlock, | ||
findGcdBodyBlock); | ||
|
||
Value gcdResult = findGcdWhileOp.getResult(1); | ||
gcdResult = | ||
b.create<mlir::arith::ShLIOp>(loc, gcdResult, smalerZerosCount); | ||
|
||
b.create<linalg::YieldOp>(loc, gcdResult); | ||
}; | ||
|
||
other = torch_to_linalg::createElementwiseLinalgGeneric( | ||
rewriter, loc, ValueRange{self, other}, | ||
cast<TensorType>(self.getType()).getElementType(), gcdPayloadBody); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing code comments for all the major blocks and a high-level description on the top for the lowering.
" }\n" | ||
" return %0#1 : !torch.int\n" | ||
" }\n" | ||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this part moved?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The file is auto-generated so I think it's a quirk of update_abstract_interp_lib.sh
. Its also checked by ci here:
torch-mlir/.github/workflows/ci.yml
Line 77 in 99115dc
bash build_tools/ci/check_generated_sources.sh |
auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc, | ||
mlir::ValueRange innerLoopArgs) { | ||
Value min = b.create<mlir::arith::MinSIOp>(loc, innerLoopArgs[0], | ||
innerLoopArgs[1]); | ||
Value max = b.create<mlir::arith::MaxSIOp>(loc, innerLoopArgs[0], | ||
innerLoopArgs[1]); | ||
|
||
auto cmp = b.create<mlir::arith::CmpIOp>( | ||
loc, mlir::arith::CmpIPredicate::ne, min, zero); | ||
b.create<mlir::scf::ConditionOp>(loc, cmp, ValueRange{min, max}); | ||
}; | ||
auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc, | ||
mlir::ValueRange innerLoopArgs) { | ||
Value min = innerLoopArgs[0]; | ||
Value max = innerLoopArgs[1]; | ||
max = b.create<mlir::arith::SubIOp>(loc, max, min); | ||
|
||
Value maxTrailingZerosCount = | ||
b.create<mlir::math::CountTrailingZerosOp>(loc, max); | ||
max = b.create<mlir::arith::ShRSIOp>(loc, max, maxTrailingZerosCount); | ||
b.create<mlir::scf::YieldOp>(loc, ValueRange{min, max}); | ||
}; | ||
|
||
auto findGcdWhileOp = b.create<mlir::scf::WhileOp>( | ||
loc, TypeRange{shiftedA.getType(), shiftedB.getType()}, | ||
ValueRange{shiftedA, shiftedB}, findGcdConditionBlock, | ||
findGcdBodyBlock); | ||
|
||
Value gcdResult = findGcdWhileOp.getResult(1); | ||
gcdResult = | ||
b.create<mlir::arith::ShLIOp>(loc, gcdResult, smalerZerosCount); | ||
|
||
b.create<linalg::YieldOp>(loc, gcdResult); | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need mlir::
anywhere here since there is a namespace using
for it at the top.
Add verify() method to check if tensors are of
integer type. Also check if tensors are of same shape, or if the second tensor is a single element tensor.
Add e2e tests. Put them into onnx and stablehlo
xfailed sets.