[TorchToLinalg]Lower torch.gcd to linalg and scf #3732

Open: bratislavSyrmia wants to merge 2 commits into main from lower_torch_aten_gcd_to_linalg_and_scf

bratislavSyrmia

Add a verify() method to check that the tensors are of integer type. Also check that the tensors have the same shape, or that the second tensor is a single-element tensor.

Add e2e tests and put them into the onnx and stablehlo xfail sets.
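
The shape rule in the description can be illustrated with a minimal sketch in plain C++. It assumes fully static shapes represented as vectors of dimension sizes; `numElements` and `gcdShapesCompatible` are hypothetical names used only for illustration and are not the PR's verify() implementation.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the PR): element count of a static shape.
// A rank-0 shape (empty vector) has exactly one element.
static int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t d : shape)
    n *= d;
  return n;
}

// Hypothetical stand-in for the shape part of the check: the operands are
// accepted when the shapes match exactly, or when the second operand holds
// a single element.
static bool gcdShapesCompatible(const std::vector<int64_t> &self,
                                const std::vector<int64_t> &other) {
  return self == other || numElements(other) == 1;
}
```

Under this reading, shapes {2, 3} and {1} are accepted, while {2, 3} and {3} are rejected even though general broadcasting would allow them.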

@bratislavSyrmia
Author

Force push: now uses math::cttz instead of counting trailing zeros manually.

@bratislavSyrmia force-pushed the lower_torch_aten_gcd_to_linalg_and_scf branch from 0815cd1 to 53b1ec3 on October 1, 2024 09:18
Collaborator

@vivekkhandelwal1 left a comment

Hi @bratislavSyrmia, I would like you to explore a non-loop path for this lowering, since this kind of lowering usually causes issues in the downstream pipeline, especially in code generation.

@IanWood1
Contributor

@vivekkhandelwal1 out of curiosity, do you have an algorithm/solution in mind? The best I could think of was to use linalg.generic's library_call attr and define a GCD function but maybe that runs into the same problem.

@vivekkhandelwal1
Collaborator

Hi @IanWood1, I did not spend time thinking about it, which is why I asked @bratislavSyrmia to explore the possibility of such a solution.

But if there exists a solution based on linalg.generic, it would still be a better approach than the current one.

@bratislavSyrmia
Author

I have thought about it, but I have no idea how I would find the greatest common divisor of two numbers without using loops.

@bondhugula

This PR is acceptable and reviewable as is, and I don't think it should be blocked because of a downstream project's inability to deal with it. The lowering output is perfectly valid IR, and it contributes an otherwise missing lowering. As for the loops themselves: a while loop (Euclid's GCD) takes O(log n) steps, whereas an alternative approach with a countable for loop would take O(n) steps. Downstream users could potentially add a lowering that uses a "GCD" intrinsic call to link with a library if really needed.
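
To make the step-count comparison concrete, here is a minimal scalar sketch in plain C++ of the two loop shapes. It is one possible reading of the comment, not code from the PR: `gcdEuclid` and `gcdCounted` are hypothetical names, and the counted variant simply tries every candidate divisor.

```cpp
#include <algorithm>
#include <cstdint>

// While-loop form (Euclid): the trip count depends on the data and is
// logarithmic in the operands, which is the shape scf.while expresses.
static uint64_t gcdEuclid(uint64_t a, uint64_t b) {
  while (b != 0) {
    uint64_t r = a % b;
    a = b;
    b = r;
  }
  return a;
}

// Counted-loop form: the trip count min(a, b) is known before the loop
// starts, at the cost of a linear number of steps.
static uint64_t gcdCounted(uint64_t a, uint64_t b) {
  if (a == 0)
    return b;
  if (b == 0)
    return a;
  uint64_t best = 1;
  uint64_t limit = std::min(a, b);
  for (uint64_t d = 1; d <= limit; ++d) {
    if (a % d == 0 && b % d == 0)
      best = d; // remember the largest common divisor seen so far
  }
  return best;
}
```

Both functions return the same value for the same inputs; the difference is only whether the number of iterations is known up front.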

auto other = adaptor.getOther(); // tensor B of the same size
auto loc = op.getLoc();

TensorType resultType =

auto

Comment on lines +233 to +288
auto gcdPayloadBody = [&](OpBuilder &b, Location loc,
                          ValueRange genericInstructionArgs) {
  auto A = genericInstructionArgs[0];
  A = b.create<mlir::math::AbsIOp>(loc, A);
  auto B = genericInstructionArgs[1];
  B = b.create<mlir::math::AbsIOp>(loc, B);
  auto zero = b.create<mlir::arith::ConstantIntOp>(loc, 0, A.getType());

  Value AtrailingZerosCount =
      b.create<mlir::math::CountTrailingZerosOp>(loc, A);
  Value BtrailingZerosCount =
      b.create<mlir::math::CountTrailingZerosOp>(loc, B);
  auto smalerZerosCount = b.create<mlir::arith::MinSIOp>(
      loc, AtrailingZerosCount, BtrailingZerosCount);
  auto shiftedA = b.create<mlir::arith::ShRSIOp>(loc, A, smalerZerosCount);
  auto shiftedB = b.create<mlir::arith::ShRSIOp>(loc, B, smalerZerosCount);

  auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
                                   mlir::ValueRange innerLoopArgs) {
    Value min = b.create<mlir::arith::MinSIOp>(loc, innerLoopArgs[0],
                                               innerLoopArgs[1]);
    Value max = b.create<mlir::arith::MaxSIOp>(loc, innerLoopArgs[0],
                                               innerLoopArgs[1]);

    auto cmp = b.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::ne, min, zero);
    b.create<mlir::scf::ConditionOp>(loc, cmp, ValueRange{min, max});
  };
  auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
                              mlir::ValueRange innerLoopArgs) {
    Value min = innerLoopArgs[0];
    Value max = innerLoopArgs[1];
    max = b.create<mlir::arith::SubIOp>(loc, max, min);

    Value maxTrailingZerosCount =
        b.create<mlir::math::CountTrailingZerosOp>(loc, max);
    max = b.create<mlir::arith::ShRSIOp>(loc, max, maxTrailingZerosCount);
    b.create<mlir::scf::YieldOp>(loc, ValueRange{min, max});
  };

  auto findGcdWhileOp = b.create<mlir::scf::WhileOp>(
      loc, TypeRange{shiftedA.getType(), shiftedB.getType()},
      ValueRange{shiftedA, shiftedB}, findGcdConditionBlock,
      findGcdBodyBlock);

  Value gcdResult = findGcdWhileOp.getResult(1);
  gcdResult =
      b.create<mlir::arith::ShLIOp>(loc, gcdResult, smalerZerosCount);

  b.create<linalg::YieldOp>(loc, gcdResult);
};

other = torch_to_linalg::createElementwiseLinalgGeneric(
    rewriter, loc, ValueRange{self, other},
    cast<TensorType>(self.getType()).getElementType(), gcdPayloadBody);

Code comments are missing for all the major blocks, and there is no high-level description of the lowering at the top.
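
As a rough aid for such a description, the payload appears to follow the binary (Stein-style) GCD scheme: take absolute values, factor out the common power of two, repeatedly subtract the smaller value from the larger and strip trailing zeros, then shift the common power of two back in. The following is a scalar sketch in plain C++ under that reading, not the PR's code. `countTrailingZeros` is a hypothetical stand-in for math.cttz, and the explicit zero checks are additions made here because shifting a 64-bit value by 64 is undefined in C++.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical stand-in for math.cttz: number of trailing zero bits.
// Returns 64 for x == 0.
static unsigned countTrailingZeros(uint64_t x) {
  if (x == 0)
    return 64;
  unsigned n = 0;
  while ((x & 1) == 0) {
    x >>= 1;
    ++n;
  }
  return n;
}

// Scalar sketch of the per-element computation (assumed reading of the payload).
static uint64_t binaryGcdSketch(int64_t a, int64_t b) {
  // Absolute values, computed in unsigned arithmetic to avoid signed overflow.
  uint64_t ua = a < 0 ? 0 - static_cast<uint64_t>(a) : static_cast<uint64_t>(a);
  uint64_t ub = b < 0 ? 0 - static_cast<uint64_t>(b) : static_cast<uint64_t>(b);
  // Added guard (not in the payload): gcd(0, x) == x.
  if (ua == 0)
    return ub;
  if (ub == 0)
    return ua;

  // Factor out the power of two shared by both operands.
  unsigned common = std::min(countTrailingZeros(ua), countTrailingZeros(ub));
  ua >>= common;
  ub >>= common;

  while (true) {
    // "Condition" part: order the pair and stop once the smaller value is 0.
    uint64_t lo = std::min(ua, ub);
    uint64_t hi = std::max(ua, ub);
    if (lo == 0)
      return hi << common; // restore the shared power of two

    // "Body" part: subtract, then drop trailing zeros from the difference.
    uint64_t diff = hi - lo;
    if (diff != 0) // added guard: avoid shifting by 64
      diff >>= countTrailingZeros(diff);
    ua = lo;
    ub = diff;
  }
}
```

For example, with inputs 12 and 18 the shared factor is 2, the loop reduces the pair (6, 9) to (0, 3), and the result is 3 shifted left by one bit.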

" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"

Why is this part moved?

Contributor

The file is auto-generated, so I think it's a quirk of update_abstract_interp_lib.sh. It's also checked by CI here:

bash build_tools/ci/check_generated_sources.sh

Comment on lines +250 to +283
  auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
                                   mlir::ValueRange innerLoopArgs) {
    Value min = b.create<mlir::arith::MinSIOp>(loc, innerLoopArgs[0],
                                               innerLoopArgs[1]);
    Value max = b.create<mlir::arith::MaxSIOp>(loc, innerLoopArgs[0],
                                               innerLoopArgs[1]);

    auto cmp = b.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::ne, min, zero);
    b.create<mlir::scf::ConditionOp>(loc, cmp, ValueRange{min, max});
  };
  auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
                              mlir::ValueRange innerLoopArgs) {
    Value min = innerLoopArgs[0];
    Value max = innerLoopArgs[1];
    max = b.create<mlir::arith::SubIOp>(loc, max, min);

    Value maxTrailingZerosCount =
        b.create<mlir::math::CountTrailingZerosOp>(loc, max);
    max = b.create<mlir::arith::ShRSIOp>(loc, max, maxTrailingZerosCount);
    b.create<mlir::scf::YieldOp>(loc, ValueRange{min, max});
  };

  auto findGcdWhileOp = b.create<mlir::scf::WhileOp>(
      loc, TypeRange{shiftedA.getType(), shiftedB.getType()},
      ValueRange{shiftedA, shiftedB}, findGcdConditionBlock,
      findGcdBodyBlock);

  Value gcdResult = findGcdWhileOp.getResult(1);
  gcdResult =
      b.create<mlir::arith::ShLIOp>(loc, gcdResult, smalerZerosCount);

  b.create<linalg::YieldOp>(loc, gcdResult);
};

You don't need mlir:: anywhere here, since there is a using directive for that namespace at the top.

4 participants