Skip to content

Commit

Permalink
cggi: fix MSB/LSB ordering of lut2/lut3 inputs in lut_lincomb
Browse files Browse the repository at this point in the history
 canonicalization

cggi.lut2 and cggi.lut3's operands are ordered from MSB to LSB, but the lut_lincomb canonicalization uses a coeff vector of <1, 2 (,4)> instead of ordering the first bit in the MSB.

PiperOrigin-RevId: 702059440
  • Loading branch information
asraa authored and copybara-github committed Dec 2, 2024
1 parent 57c78fe commit 372219b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions lib/Dialect/CGGI/IR/CGGIOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ std::optional<ValueRange> Lut2Op::getLookupTableInputs() {
}

LogicalResult Lut2Op::canonicalize(Lut2Op op, PatternRewriter &rewriter) {
SmallVector<int32_t> coeffs2 = {1, 2};
SmallVector<int32_t> coeffs2 = {2, 1};
auto createLutLinCombOp = rewriter.create<LutLinCombOp>(
op.getLoc(), op.getOutput().getType(), op.getOperands(), coeffs2,
op.getLookupTable());
Expand All @@ -33,7 +33,7 @@ std::optional<ValueRange> Lut3Op::getLookupTableInputs() {
}

LogicalResult Lut3Op::canonicalize(Lut3Op op, PatternRewriter &rewriter) {
SmallVector<int> coeffs3 = {1, 2, 4};
SmallVector<int> coeffs3 = {4, 2, 1};

auto createLutLinCombOp = rewriter.create<LutLinCombOp>(
op.getLoc(), op.getOutput().getType(), op.getOperands(), coeffs3,
Expand Down
4 changes: 2 additions & 2 deletions tests/Dialect/CGGI/Transforms/lut_canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func.func @require_post_pass_toposort_lut3(%arg0: tensor<8x!ct_ty>) -> !ct_ty {
%1 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty>
%2 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty>

// CHECK: cggi.lut_lincomb %extracted, %extracted_0, %extracted_1 {coefficients = array<i32: 1, 2, 4>, lookup_table = 8 : ui8}
// CHECK: cggi.lut_lincomb %extracted, %extracted_0, %extracted_1 {coefficients = array<i32: 4, 2, 1>, lookup_table = 8 : ui8}
%r1 = cggi.lut3 %0, %1, %2 {lookup_table = 8 : ui8} : !ct_ty

return %r1 : !ct_ty
Expand All @@ -27,7 +27,7 @@ func.func @require_post_pass_toposort_lut2(%arg0: tensor<8x!ct_ty>) -> !ct_ty {
%1 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty>
%2 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty>

// CHECK: cggi.lut_lincomb %extracted, %extracted_0 {coefficients = array<i32: 1, 2>, lookup_table = 8 : ui8}
// CHECK: cggi.lut_lincomb %extracted, %extracted_0 {coefficients = array<i32: 2, 1>, lookup_table = 8 : ui8}
%r1 = cggi.lut2 %0, %1 {lookup_table = 8 : ui8} : !ct_ty

return %r1 : !ct_ty
Expand Down
4 changes: 2 additions & 2 deletions tests/Dialect/CGGI/Transforms/straight_line_vectorizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
// CHECK-LABEL: add_one
// CHECK-COUNT-9: cggi.lut3
// CHECK: cggi.lut3 %[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]] {lookup_table = 105 : ui8} : tensor<6x!lwe.lwe_ciphertext
// CANONICAL: cggi.lut_lincomb %[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]] {coefficients = array<i32: 1, 2, 4>, lookup_table = 105 : ui8} : tensor<6x!lwe.lwe_ciphertext
// CANONICAL: cggi.lut_lincomb %[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]] {coefficients = array<i32: 4, 2, 1>, lookup_table = 105 : ui8} : tensor<6x!lwe.lwe_ciphertext
func.func @add_one(%arg0: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> {
%true = arith.constant true
%false = arith.constant false
Expand Down Expand Up @@ -94,7 +94,7 @@ func.func @require_post_pass_toposort(%arg0: tensor<8x!ct_ty>) -> tensor<8x!ct_t
// cggi.not occurring after its single result.


// CHECK-CANONICAL: cggi.lut_lincomb %[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]] {coefficients = array<i32: 1, 2, 4>, lookup_table = 8 : ui8} : tensor<7x!lwe.lwe_ciphertext
// CHECK-CANONICAL: cggi.lut_lincomb %[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]] {coefficients = array<i32: 4, 2, 1>, lookup_table = 8 : ui8} : tensor<7x!lwe.lwe_ciphertext
// CHECK-CANONICAL: cggi.not
// CHECK-CANONICAL: cggi.lut_lincomb
%from_elements = tensor.from_elements %r1, %r2, %r3, %r4, %r5, %r6, %r7, %x : tensor<8x!ct_ty>
Expand Down

0 comments on commit 372219b

Please sign in to comment.