diff --git a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD index d0e495306e..1e4b53b053 100644 --- a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD +++ b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD @@ -14,7 +14,9 @@ cc_library( deps = [ ":pass_inc_gen", "@heir//lib/Analysis/SecretnessAnalysis", + "@heir//lib/Dialect/Secret/IR:Dialect", "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@heir//lib/Utils:ConversionUtils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:Analysis", diff --git a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.cpp b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.cpp index 4f151ea432..521c1fa45b 100644 --- a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.cpp +++ b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.cpp @@ -4,7 +4,10 @@ #include #include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" +#include "lib/Dialect/Secret/IR/SecretOps.h" +#include "lib/Dialect/Secret/IR/SecretTypes.h" #include "lib/Dialect/TensorExt/IR/TensorExtOps.h" +#include "lib/Utils/ConversionUtils.h" #include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "llvm/include/llvm/Support/LogicalResult.h" // from @llvm-project @@ -16,6 +19,7 @@ #include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project @@ -23,13 +27,14 @@ #include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#define DEBUG_TYPE "linalg-to-arith" +#define DEBUG_TYPE "linalg-to-tensor-ext" namespace mlir { namespace heir { @@ -46,19 +51,21 @@ int calculateIndexHelper(bool isLeftOperandSecret, int dim0, int dim1, int i, if (isLeftOperandSecret) { return ((i + j) % dim0) * dim1 + (j % dim1); } else { // right operand is secret - return (i % dim0) * dim0 + ((i + j) % dim1); + return (i % dim0) * dim1 + ((i + j) % dim1); } } template Value diagonalizeMatrix(ImplicitLocOpBuilder builder, - DenseElementsAttr denseAttr, bool isLeftOperandSecret) { - // Algorithm for diagonalizing the matrix: + DenseElementsAttr denseAttr, bool isLeftOperandSecret, + int maxTilingSize) { + // Algorithm for diagonalizing the matrix into a square matrix of size + // maxTilingSize x maxTilingSize. // There are two loops, an outer loop and an inner loop. 
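  // (A quick worked check of the index formula in calculateIndexHelper above,
  // assuming a public right-hand matrix with dim0 = 2 and dim1 = 4, flattened
  // row-major:
  //   i = 1, j = 2  ->  (1 % 2) * 4 + ((1 + 2) % 4) = 4 + 3 = 7, i.e. row 1, column 3
  //   i = 1, j = 3  ->  (1 % 2) * 4 + ((1 + 3) % 4) = 4 + 0 = 4, i.e. row 1, column 0
  // The old formula, (i % dim0) * dim0 + ((i + j) % dim1), would have produced
  // 5 and 2 for the same (i, j), pointing at the wrong entries of the 2x4 matrix.)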
// The outer loop is the for loop that goes from 0 to number of rows in the - // diagonalized, transposed matrix (transposedDimensions[0]). + // diagonalized, transposed matrix (maxTilingSize). // The inner loop is the for loop that goes from 0 to number of columns in - // the diagonalized, transposed matrix (transposedDimensions[1]). + // the diagonalized, transposed matrix (maxTilingSize). // At each iteration of the inner loop, we extract the correct diagonal // element from the matrix. Let's take an example: // @@ -85,8 +92,8 @@ Value diagonalizeMatrix(ImplicitLocOpBuilder builder, // multiplication (which is done via the helper function // calculateIndexHelper): // - // for i = 0 to transposedDimensions[0]: - // for j = 0 to transposedDimensions[1]: + // for i = 0 to maxTilingSize: + // for j = 0 to maxTilingSize: // row_index = (i + j) % dim0 // column_index = j % dim1 // index = row_index * dim1 + column_index @@ -98,16 +105,21 @@ Value diagonalizeMatrix(ImplicitLocOpBuilder builder, auto type = denseAttr.getElementType(); auto dims = denseAttr.getType().getShape(); + auto denseAttrValues = denseAttr.getValues(); auto dim0 = dims[0]; auto dim1 = dims[1]; - SmallVector transposedDimensions({dim1, dim0}); + SmallVector transposedDimensions({maxTilingSize, maxTilingSize}); SmallVector diagonalElements; diagonalElements.reserve(denseAttr.getNumElements()); - for (int i = 0; i < transposedDimensions[0]; ++i) { - for (int j = 0; j < transposedDimensions[1]; ++j) { + for (int i = 0; i < maxTilingSize; ++i) { + for (int j = 0; j < maxTilingSize; ++j) { int index = calculateIndexHelper(isLeftOperandSecret, dim0, dim1, i, j); - auto value = denseAttr.getValues()[index]; + LLVM_DEBUG({ + llvm::dbgs() << "i: " << i << ", j: " << j << ", index: " << index + << ", dim0: " << dim0 << ", dim1: " << dim1 << "\n"; + }); + auto value = denseAttrValues[index]; diagonalElements.push_back(value); } } @@ -117,31 +129,50 @@ Value diagonalizeMatrix(ImplicitLocOpBuilder builder, return builder.create(diagonalizedDenseElementsAttr); } +template +Value duplicateBias(ImplicitLocOpBuilder builder, DenseElementsAttr biasAttr, + bool isLeftOperandSecret, int maxTilingSize) { + auto type = biasAttr.getElementType(); + + int numElements = biasAttr.getNumElements(); + + SmallVector duplicatedDimensions({1, maxTilingSize}); + if (!isLeftOperandSecret) { + duplicatedDimensions = {maxTilingSize, 1}; + } + + SmallVector newBiasElements; + newBiasElements.reserve(maxTilingSize); + for (int i = 0; i < maxTilingSize; ++i) { + newBiasElements.push_back(biasAttr.getValues()[i % numElements]); + } + auto duplicatedBiasType = RankedTensorType::get(duplicatedDimensions, type); + auto duplicatedBiasDenseElementsAttr = + DenseElementsAttr::get(duplicatedBiasType, newBiasElements); + return builder.create(duplicatedBiasDenseElementsAttr); +} + template -Value multiplyDiagonalizedMatrixWithVector(ImplicitLocOpBuilder builder, - Value diagonalizedMatrix, - Value secretValues, Value bias, - bool isLeftOperandSecret) { +Value multiplyDiagonalizedMatrixWithVector( + ImplicitLocOpBuilder builder, Value diagonalizedMatrix, + ArrayRef originalMatrixDimensions, Value secretValues, Value bias, + bool isLeftOperandSecret, int maxTilingSize) { // The code below emits the following code for vector-matrix // multiplication (matrix-vector multiplication is similar): // %sum = bias // %rotated_vector = secretValues - // for %i = 0 to transposedDim0 - 1: - // %extractedSlice = extract_slice %newMatrix[%i, 0] [1, transposedDim1] [1, - // 1] - // 
%multiplied = %rotated_vector * %extractedSlice %sum = %sum + %multiplied + // for %i = 0 to originalMatrixDimensions[1] - 1: + // %extractedSlice = extract_slice %newMatrix[%i, 0] + // [1, originalMatrixDimensions[0]] [1, 1] + // %multiplied = %rotated_vector * %extractedSlice + // %sum = %sum + %multiplied // %rotated_vector = rotate %rotated_vector, 1 - // %lastExtracted = extract_slice %newMatrix[transposedDim0-1, 0] [1, - // transposedDim1] [1, 1] + // %lastExtracted = extract_slice %newMatrix[maxTilingSize-1, 0] [1, + // originalMatrixDimensions[0]] [1, 1] // %final_sum = %sum + %lastExtracted - // At this point, we can rotate and sum if needed. (Squat packing is left as a - // TODO until we resolve the shape mismatch issue.) + // At this point, we can rotate and sum if needed. // return %final_sum - auto shape = cast(diagonalizedMatrix.getType()).getShape(); - auto transposedDim0 = shape[0]; - auto transposedDim1 = shape[1]; - // Build a constant index 1. auto indexOne = builder.create(1); @@ -149,19 +180,17 @@ Value multiplyDiagonalizedMatrixWithVector(ImplicitLocOpBuilder builder, // ExtractSliceOp. SmallVector sizes(2); if (isLeftOperandSecret) { - sizes = {builder.getIndexAttr(1), builder.getIndexAttr(transposedDim1)}; + sizes = {builder.getIndexAttr(1), builder.getIndexAttr(maxTilingSize)}; } else { - sizes = {builder.getIndexAttr(transposedDim0), builder.getIndexAttr(1)}; + sizes = {builder.getIndexAttr(maxTilingSize), builder.getIndexAttr(1)}; } SmallVector strides(2, builder.getIndexAttr(1)); // Setup parameters for the affine for loop. SmallVector iterArgs({bias, secretValues}); - int numLoops; - if (isLeftOperandSecret) { - numLoops = transposedDim0; - } else { // right operand is secret - numLoops = transposedDim1; + int numLoops = originalMatrixDimensions[0]; + if (numLoops > originalMatrixDimensions[1]) { + numLoops = originalMatrixDimensions[1]; } // Build the affine for loop. @@ -204,11 +233,11 @@ Value multiplyDiagonalizedMatrixWithVector(ImplicitLocOpBuilder builder, // ExtractSliceOp. SmallVector lastOffsets(2); if (isLeftOperandSecret) { - lastOffsets = {builder.getIndexAttr(transposedDim1 - 1), + lastOffsets = {builder.getIndexAttr(originalMatrixDimensions[0] - 1), builder.getIndexAttr(0)}; } else { lastOffsets = {builder.getIndexAttr(0), - builder.getIndexAttr(transposedDim1 - 1)}; + builder.getIndexAttr(originalMatrixDimensions[0] - 1)}; } auto lastExtracted = builder.create( diagonalizedMatrix, lastOffsets, sizes, strides); @@ -216,25 +245,100 @@ Value multiplyDiagonalizedMatrixWithVector(ImplicitLocOpBuilder builder, // Calculates the final scalar multiplication and sum. 
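  // As a worked example of the loop bound and the tail handling below (a sketch
  // using the shapes from the rectangular matrix-vector test in this patch):
  // for a public 2x4 matrix with tiling size 4, numLoops = min(2, 4) = 2, so the
  // affine.for above runs a single iteration (i = 0); the tail code that follows
  // multiplies and adds the slice at index originalMatrixDimensions[0] - 1 = 1,
  // and the rotate-and-sum step after that performs exactLogBase2(4 / 2) = 1
  // rotation, by maxTilingSize / 2 = 2 slots.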
auto lastMultiplied = builder.create(forOp.getResults()[1], lastExtracted); - auto finalSum = builder.create(forOp.getResults()[0], lastMultiplied); - return finalSum; + auto finalSumWithoutRotateAndSum = + builder.create(forOp.getResults()[0], lastMultiplied); + + int numRotationsAndSums; + if (isLeftOperandSecret) { + numRotationsAndSums = llvm::APInt(32, originalMatrixDimensions[0] / + originalMatrixDimensions[1]) + .exactLogBase2(); + } else { + numRotationsAndSums = llvm::APInt(32, originalMatrixDimensions[1] / + originalMatrixDimensions[0]) + .exactLogBase2(); + } + + // Rotate and sum if needed + Value sumInProgress = finalSumWithoutRotateAndSum; + int rotationValue = maxTilingSize; + for (int i = 0; i < numRotationsAndSums; ++i) { + rotationValue /= 2; + auto rotatedTensor = builder.create( + sumInProgress, + builder.create(builder.getIndexAttr(rotationValue))); + sumInProgress = builder.create(sumInProgress, rotatedTensor); + } + + return sumInProgress; } -struct ConvertLinalgMatmul : public OpRewritePattern { +class ReplicatedTensorTypeConverter : public TypeConverter { + private: + int maxTilingSize; + + public: + ReplicatedTensorTypeConverter(int maxTilingSize) + : maxTilingSize(maxTilingSize) { + addConversion([](Type type) { return type; }); + + addConversion([this](RankedTensorType type) -> Type { + // Assuming 2-d operations only + if (type.getShape()[0] == 1) { + return RankedTensorType::get({1, this->maxTilingSize}, + type.getElementType()); + } else if (type.getShape()[1] == 1) { + return RankedTensorType::get({this->maxTilingSize, 1}, + type.getElementType()); + } else { + return RankedTensorType::get({this->maxTilingSize, this->maxTilingSize}, + type.getElementType()); + } + }); + + // Convert secret tensors to secret tensors of the right size. + addConversion([this](secret::SecretType type) -> Type { + return secret::SecretType::get(this->convertType(type.getValueType())); + }); + } +}; + +struct SecretGenericOpLinalgMatmulConversion + : public OpConversionPattern { private: DataFlowSolver *solver; + int maxTilingSize; public: - ConvertLinalgMatmul(DataFlowSolver *solver, mlir::MLIRContext *context) - : OpRewritePattern(context), solver(solver) {} + using OpConversionPattern::OpConversionPattern; + + SecretGenericOpLinalgMatmulConversion(const TypeConverter &converter, + DataFlowSolver *solver, + mlir::MLIRContext *context, + int maxTilingSize) + : OpConversionPattern(converter, context), + solver(solver), + maxTilingSize(maxTilingSize) {} + + LogicalResult matchAndRewrite( + secret::GenericOp genericOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + if (genericOp.getBody()->getOperations().size() > 2) { + // Each secret.generic should contain at most one instruction - + // secret-distribute-generic can be used to distribute through the + // arithmetic ops. + return failure(); + } - using OpRewritePattern::OpRewritePattern; + auto &innerOp = genericOp.getBody()->getOperations().front(); + if (!isa(innerOp)) { + return failure(); + } - LogicalResult matchAndRewrite(mlir::linalg::MatmulOp op, - PatternRewriter &rewriter) const override { // Determine if the left or right operand is secret to determine which // matrix to diagonalize, or if both are secret or both are public, then // return failure. 
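    // For example (operand names are illustrative; the shapes come from the
    // tests in this patch): with
    //   linalg.matmul ins(%secret_vec, %matrix : tensor<1x4xi16>, tensor<4x4xi16>)
    // the left operand is secret and inputs[1] is the public matrix that gets
    // diagonalized, while with
    //   linalg.matmul ins(%matrix, %secret_vec : tensor<4x4xi16>, tensor<4x1xi16>)
    // the right operand is secret and inputs[0] is the public matrix. When both
    // operands are secret, or neither is, the op is left alone; the dynamic
    // legality rule registered in runOnOperation below marks such ops legal.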
+ mlir::linalg::MatmulOp op = cast(innerOp); auto isSecret = [&](Value value) { auto *operandLookup = solver->lookupState(value); Secretness operandSecretness = @@ -246,101 +350,130 @@ struct ConvertLinalgMatmul : public OpRewritePattern { bool isLeftOperandSecret = isSecret(op.getInputs()[0]); bool isRightOperandSecret = isSecret(op.getInputs()[1]); - LLVM_DEBUG({ - llvm::dbgs() << "Left operand is secret: " << isLeftOperandSecret << "\n" - << "Right operand is secret: " << isRightOperandSecret - << "\n"; - }); - - // Error out if both are secret or both are public - if ((isLeftOperandSecret && isRightOperandSecret) || - (!isLeftOperandSecret && !isRightOperandSecret)) { - return failure(); - } auto inputs = op.getInputs(); auto outputs = op.getOutputs(); // Assign if the left operand is secret - Value secretValues = inputs[0]; + int secretValuesIndex = 0; Value publicMatrix = inputs[1]; if (isRightOperandSecret) { - std::swap(secretValues, publicMatrix); + publicMatrix = inputs[0]; + secretValuesIndex = 1; } auto matrixTensorType = cast(publicMatrix.getType()); auto bias = outputs[0]; auto dimensions = matrixTensorType.getShape(); - int64_t dim0 = dimensions[0]; // This is the number of rows in the matrix. - int64_t dim1 = dimensions[1]; // This is the number of columns in the - // matrix. + int64_t dim0 = dimensions[0]; // This is the number of rows in the + // matrix. + int64_t dim1 = dimensions[1]; // This is the number of columns + // in the matrix. - // If one of these dimensions is not a power of two, then we can't do the - // Halevi-Shoup or Squat Packing Matrix Multiplication conversion. + // If one of these dimensions is not a power of two, then we can't do + // the Halevi-Shoup or Squat Packing Matrix Multiplication conversion. if (!isPowerOfTwo(dim0) || !isPowerOfTwo(dim1)) { return failure(); } - // If the matrix is not a square matrix, then we are doing squat packing. - // TODO: Implement squat packing. - if (dim0 != dim1) { - return failure(); - } - // Diagonalize the matrix only if the matrix is a constant. - auto constantValues = + auto matrixConstantValues = dyn_cast(publicMatrix.getDefiningOp()); - if (!constantValues) { + auto biasConstantValues = dyn_cast(bias.getDefiningOp()); + if (!matrixConstantValues || !biasConstantValues) { return failure(); } DenseElementsAttr denseAttr = - dyn_cast(constantValues.getValueAttr()); + dyn_cast(matrixConstantValues.getValueAttr()); + DenseElementsAttr biasAttr = + dyn_cast(biasConstantValues.getValueAttr()); // If the constant values don't have a dense attribute, then we can't // diagonalize the matrix. - if (!denseAttr) { + if (!denseAttr || !biasAttr) { return failure(); } auto type = denseAttr.getElementType(); + auto originalShape = denseAttr.getType().getShape(); + if (!type.isIntOrFloat()) { return failure(); } - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value result; - - // First, modify the matrix to be a diagonal matrix. We'll simply create a - // copy of the weight matrix diagonalized, and if the old weight matrix is - // not used, then dead code elimination pass will remove it. - - // After that, we create code for multiplying the matrix with rotations of - // the vector. 
- if (type.isInteger()) { - Value diagonalizedMatrix = - diagonalizeMatrix(b, denseAttr, isLeftOperandSecret); - - result = - multiplyDiagonalizedMatrixWithVector( - b, diagonalizedMatrix, secretValues, bias, isLeftOperandSecret); - } else { // floating point - Value diagonalizedMatrix = - diagonalizeMatrix(b, denseAttr, isLeftOperandSecret); - - result = - multiplyDiagonalizedMatrixWithVector( - b, diagonalizedMatrix, secretValues, bias, isLeftOperandSecret); + // Define local function pointers or lambdas that refer to the functions + auto diagMatrixInt = diagonalizeMatrix; + auto duplicateBiasInt = duplicateBias; + auto multDiagMatrixWithVectorInt = + multiplyDiagonalizedMatrixWithVector; + + auto diagMatrixFloat = diagonalizeMatrix; + auto duplicateBiasFloat = duplicateBias; + auto multDiagMatrixWithVectorFloat = + multiplyDiagonalizedMatrixWithVector; + + SmallVector genericOpInputs; + for (OpOperand &operand : innerOp.getOpOperands()) { + if (auto *secretArg = + genericOp.getOpOperandForBlockArgument(operand.get())) { + genericOpInputs.push_back( + adaptor.getODSOperands(0)[secretArg->getOperandNumber()]); + } else { + genericOpInputs.push_back(operand.get()); + } } - rewriter.replaceOp(op, result); + + SmallVector genericOpOutputTypes; + auto result = getTypeConverter()->convertTypes(genericOp.getResultTypes(), + genericOpOutputTypes); + if (failed(result)) return failure(); + + auto newGeneric = rewriter.create( + genericOp.getLoc(), genericOpInputs, genericOpOutputTypes, + [&](OpBuilder &builder, Location loc, ValueRange blockArguments) { + // The blockArguments should include the secret vector and public + // matrix. + + Value secretValues = blockArguments[secretValuesIndex]; + ImplicitLocOpBuilder b(loc, rewriter); + Value result; + + // Compute diagonalized matrix and duplicated bias inside the body. + if (type.isInteger()) { + Value diagonalizedMatrix = + diagMatrixInt(b, denseAttr, isLeftOperandSecret, maxTilingSize); + Value duplicatedBias = duplicateBiasInt( + b, biasAttr, isLeftOperandSecret, maxTilingSize); + result = multDiagMatrixWithVectorInt( + b, diagonalizedMatrix, originalShape, secretValues, + duplicatedBias, isLeftOperandSecret, maxTilingSize); + } else { // Floating point + Value diagonalizedMatrix = diagMatrixFloat( + b, denseAttr, isLeftOperandSecret, maxTilingSize); + Value duplicatedBias = duplicateBiasFloat( + b, biasAttr, isLeftOperandSecret, maxTilingSize); + result = multDiagMatrixWithVectorFloat( + b, diagonalizedMatrix, originalShape, secretValues, + duplicatedBias, isLeftOperandSecret, maxTilingSize); + } + // Yield the final result. + b.create(loc, result); + }); + + // Replace the original operation with the new genericOp + rewriter.replaceOp(genericOp, newGeneric); return success(); } }; struct LinalgToTensorExt : public impl::LinalgToTensorExtBase { + using LinalgToTensorExtBase::LinalgToTensorExtBase; + void runOnOperation() override { MLIRContext *context = &getContext(); auto *module = getOperation(); + ConversionTarget target(*context); DataFlowSolver solver; solver.load(); @@ -355,14 +488,43 @@ struct LinalgToTensorExt return; } + // TODO: loop through all of the secret values, figure out the tiling size. + // For now, take tilingSize as a command line argument. 
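      // A rough sketch of what that TODO could look like (illustrative only; it
      // assumes llvm::PowerOf2Ceil from llvm/Support/MathExtras.h and that
      // scanning the operand types of secret.generic ops is sufficient):
      //
      //   int64_t maxDim = 1;
      //   module->walk([&](secret::GenericOp op) {
      //     for (Type t : op->getOperandTypes())
      //       if (auto secretTy = dyn_cast<secret::SecretType>(t))
      //         if (auto tensorTy =
      //                 dyn_cast<RankedTensorType>(secretTy.getValueType()))
      //           for (int64_t d : tensorTy.getShape())
      //             maxDim = std::max<int64_t>(maxDim, d);
      //   });
      //   int64_t derivedTilingSize = llvm::PowerOf2Ceil(maxDim);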
+ + ReplicatedTensorTypeConverter replicatedTypeConverter(tilingSize); RewritePatternSet patterns(context); - patterns.add(&solver, context); + patterns.add( + replicatedTypeConverter, &solver, context, tilingSize); + target.addDynamicallyLegalOp([&](secret::GenericOp op) { + auto matmulOp = dyn_cast( + op.getBody()->getOperations().front()); + if (!matmulOp) { + return true; + } + auto isSecret = [&](Value value) { + auto *operandLookup = solver.lookupState(value); + Secretness operandSecretness = + operandLookup ? operandLookup->getValue() : Secretness(); + return (operandSecretness.isInitialized() && + operandSecretness.getSecretness()); + }; + + bool isLeftOperandSecret = isSecret(matmulOp.getInputs()[0]); + bool isRightOperandSecret = isSecret(matmulOp.getInputs()[1]); + + // Error out if both are secret or both are public + if ((isLeftOperandSecret && isRightOperandSecret) || + (!isLeftOperandSecret && !isRightOperandSecret)) { + return true; + } + return false; + }); + + addStructuralConversionPatterns(replicatedTypeConverter, patterns, target); // Run pattern matching and conversion - // TODO (#1221): Investigate whether folding (default: on) can be skipped - // here. - if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { return signalPassFailure(); } } diff --git a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.td b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.td index 4f8e1e92a0..b9b549585f 100644 --- a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.td +++ b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.td @@ -9,10 +9,24 @@ def LinalgToTensorExt : Pass<"linalg-to-tensor-ext"> { let description = [{ This pass lowers the `linalg.matmul` to a mixture of affine, tensor, and tensor_ext operations via the Halevi-Shoup and squat matrix multiplication algorithms. + + We assume that the input and output values are replicated. This makes + aligning the matrix multiplications easier (though not necessarily optimal). + For example, when multiplying a 1x4 vector with a 4x2 matrix, the bias and output + will be a 1x2 vector. However, due to requiring tensor sizes to match, and + assuming replication, the matrix will be expanded to a 4x4 matrix and the output + to a 1x4 vector (where the output is replicated twice). + + For now, the tilingSize is a command line parameter that determines the + maximum secret vector size used in the Halevi-Shoup and squat matrix + multiplication algorithms. It can be specified via --linalg-to-tensor-ext=tiling-size=16. }]; let dependentDialects = [ "mlir::heir::tensor_ext::TensorExtDialect", ]; + let options = [ + Option<"tilingSize", "tiling-size", "int", "16", "tiling size of the Halevi-Shoup and squat packing matrix multiplication algorithms"> + ]; } #endif // LIB_DIALECT_LINALG_CONVERSIONS_LINALGTOTENSOREXT_LINALGTOTENSOREXT_TD_ diff --git a/lib/Dialect/TensorExt/IR/TensorExtCanonicalization.td b/lib/Dialect/TensorExt/IR/TensorExtCanonicalization.td index 7675e9f2af..e68f25fc80 100644 --- a/lib/Dialect/TensorExt/IR/TensorExtCanonicalization.td +++ b/lib/Dialect/TensorExt/IR/TensorExtCanonicalization.td @@ -29,6 +29,9 @@ def DropZeroRotation : Pat< [(IsZeroIntAttr $c0)] >; +// Currently commented out because it doesn't work for multi-dimensional tensors. +// Will be uncommented and fixed by Asra's PR. Commenting this out causes various +// other tests to fail. 
// rotate %t, x -> rotate %t, x mod size def NormalizeRotationIndex : Pat< (TensorExt_RotateOp $tensor, (Arith_ConstantOp:$shiftOp APIntAttr:$shiftAmount)), diff --git a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_small_fc_network.mlir b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_small_fc_network.mlir new file mode 100644 index 0000000000..f0af9b7c48 --- /dev/null +++ b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_small_fc_network.mlir @@ -0,0 +1,47 @@ +// This test verifies that a small fully connected network lowers without returning +// an error. +// TODO: write a test that verifies the correctness of the lowering. + +// RUN: heir-opt %s --linalg-to-tensor-ext=tiling-size=4 --tosa-to-secret-arith --canonicalize | FileCheck %s + +// CHECK: func.func @test_float_small_fc_network(%[[ARG:.*]]: !secret.secret>) +module { +func.func @test_float_small_fc_network(%input : !secret.secret>) -> !secret.secret> { + %matrix1 = arith.constant dense<[[1.0, 2.0, 3.0, 4.0]]> : tensor<1x4xf32> + %bias1 = arith.constant dense<[[5.0, 6.0, 7.0, 8.0]]> : tensor<1x4xf32> + %layer1 = secret.generic ins (%input : !secret.secret>) { + ^bb0(%converted_input1: tensor<1x1xf32>): + %0 = linalg.matmul ins(%converted_input1, %matrix1 : tensor<1x1xf32>, tensor<1x4xf32>) outs(%bias1 : tensor<1x4xf32>) -> tensor<1x4xf32> + secret.yield %0 : tensor<1x4xf32> + } -> !secret.secret> + + %activation_layer1 = secret.generic ins (%layer1 : !secret.secret>) { + ^bb0(%converted_activation_layer_vec1: tensor<1x4xf32>): + %0 = tosa.sigmoid %converted_activation_layer_vec1 : (tensor<1x4xf32>) -> tensor<1x4xf32> + secret.yield %0 : tensor<1x4xf32> + } -> !secret.secret> + + %matrix2 = arith.constant dense<[[10.0, 20.0, 30.0, 40.0], [50.0, 60.0, 70.0, 80.0], [90.0, 100.0, 110.0, 120.0], [130.0, 140.0, 150.0, 160.0]]> : tensor<4x4xf32> + %bias2 = arith.constant dense<[[170.0, 180.0, 190.0, 200.0]]> : tensor<1x4xf32> + %layer2 = secret.generic ins (%layer1 : !secret.secret>) { + ^bb0(%converted_vec2: tensor<1x4xf32>): + %1 = linalg.matmul ins(%converted_vec2, %matrix2 : tensor<1x4xf32>, tensor<4x4xf32>) outs(%bias2 : tensor<1x4xf32>) -> tensor<1x4xf32> + secret.yield %1 : tensor<1x4xf32> + } -> !secret.secret> + + %activation_layer2 = secret.generic ins (%layer2 : !secret.secret>) { + ^bb0(%converted_activation_layer_vec2: tensor<1x4xf32>): + %0 = tosa.sigmoid %converted_activation_layer_vec2 : (tensor<1x4xf32>) -> tensor<1x4xf32> + secret.yield %0 : tensor<1x4xf32> + } -> !secret.secret> + + %matrix3 = arith.constant dense<[[100.0], [200.0], [300.0], [400.0]]> : tensor<4x1xf32> + %bias3 = arith.constant dense<[[500.0]]> : tensor<1x1xf32> + %layer3 = secret.generic ins (%activation_layer2 : !secret.secret>) { + ^bb0(%converted_vec3: tensor<1x4xf32>): + %0 = linalg.matmul ins(%converted_vec3, %matrix3 : tensor<1x4xf32>, tensor<4x1xf32>) outs(%bias3 : tensor<1x1xf32>) -> tensor<1x1xf32> + secret.yield %0 : tensor<1x1xf32> + } -> !secret.secret> + return %layer3 : !secret.secret> +} +} diff --git a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_vector_small_matrix_matmul.mlir b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_vector_small_matrix_matmul.mlir new file mode 100644 index 0000000000..c73ea870e2 --- /dev/null +++ b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_vector_small_matrix_matmul.mlir @@ -0,0 +1,33 @@ +// RUN: heir-opt %s --linalg-to-tensor-ext=tiling-size=4 --canonicalize | FileCheck %s + +// CHECK: func.func 
@test_float_vector_small_matrix_matmul(%[[ARG:.*]]: !secret.secret>) +// CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense<5.{{0*}}e+00> : tensor<1x4xf32> +// CHECK-DAG: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense +// CHECK-SAME{LITERAL}: <[[ +// CHECK-SAME: 1.{{0*}}e+00, 2.{{0*}}e+00, 3.{{0*}}e+00, 4.{{0*}}e+00], [2.{{0*}}e+00, 3.{{0*}}e+00, 4.{{0*}}e+00, 1.{{0*}}e+00], [3.{{0*}}e+00, 4.{{0*}}e+00, 1.{{0*}}e+00, 2.{{0*}}e+00], [4.{{0*}}e+00, 1.{{0*}}e+00, 2.{{0*}}e+00, 3.{{0*}}e+00 +// CHECK-SAME{LITERAL}: ]]> +// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][3, 0] [1, 4] [1, 1] +// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) +// CHECK: ^bb0(%[[ARG_CONVERTED:.*]]: tensor<1x4xf32>): +// CHECK: %[[MUL:.*]] = arith.mulf %[[ARG_CONVERTED]], %[[SLICE]] +// CHECK: %[[SUM:.*]] = arith.addf %[[MUL]], %[[BIAS]] +// CHECK: %[[ROTATE1:.*]] = tensor_ext.rotate %[[SUM]], %[[TWO]] +// CHECK: %[[ROTATE_AND_SUM_1:.*]] = arith.addf %[[SUM]], %[[ROTATE1]] +// CHECK: %[[ROTATE2:.*]] = tensor_ext.rotate %[[ROTATE_AND_SUM_1]], %[[ONE]] +// CHECK: %[[FINAL_SUM:.*]] = arith.addf %[[ROTATE_AND_SUM_1]], %[[ROTATE2]] +// CHECK: secret.yield %[[FINAL_SUM]] +// CHECK: return %[[OUT]] +module { +func.func @test_float_vector_small_matrix_matmul(%vec : !secret.secret>) -> !secret.secret> { + %matrix = arith.constant dense<[[1.0], [2.0], [3.0], [4.0]]> : tensor<4x1xf32> + %bias = arith.constant dense<[[5.0]]> : tensor<1x1xf32> + %out = secret.generic ins (%vec : !secret.secret>) { + ^bb0(%converted_vec: tensor<1x4xf32>): + %0 = linalg.matmul ins(%converted_vec, %matrix : tensor<1x4xf32>, tensor<4x1xf32>) outs(%bias : tensor<1x1xf32>) -> tensor<1x1xf32> + secret.yield %0 : tensor<1x1xf32> + } -> !secret.secret> + return %out : !secret.secret> +} +} diff --git a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_vector_square_matrix_matmul_op.mlir b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_vector_square_matrix_matmul_op.mlir index 98469ff14f..2050ed9728 100644 --- a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_vector_square_matrix_matmul_op.mlir +++ b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/float_vector_square_matrix_matmul_op.mlir @@ -1,30 +1,30 @@ -// RUN: heir-opt %s --linalg-to-tensor-ext | FileCheck %s +// RUN: heir-opt %s --linalg-to-tensor-ext=tiling-size=4 --canonicalize | FileCheck %s -// CHECK: func.func @test_float_vector_square_matrix_linalg_to_arith(%[[ARG:.*]]: !secret.secret>) -// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index -// CHECK: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense +// CHECK: func.func @test_float_vector_square_matrix_matmul(%[[ARG:.*]]: !secret.secret>) +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense // CHECK-SAME{LITERAL}: <[[ // CHECK-SAME: 1.{{0*}}e+00, 6.{{0*}}e+00, 1.1{{0*}}e+01, 1.6{{0*}}e+01], [5.{{0*}}e+00, 1.{{0*}}e+01, 1.5{{0*}}e+01, 4.{{0*}}e+00], [9.{{0*}}e+00, 1.4{{0*}}e+01, 3.{{0*}}e+00, 8.{{0*}}e+00], [1.3{{0*}}e+01, 2.{{0*}}e+00, 7.{{0*}}e+00, 1.2{{0*}}e+01 // CHECK-SAME{LITERAL}: ]]> -// CHECK: %[[BIAS:.*]] = arith.constant dense +// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense // CHECK-SAME{LITERAL}: <[[ // CHECK-SAME: 1.7{{0*}}e+01, 1.8{{0*}}e+01, 1.9{{0*}}e+01, 2.{{0*}}e+01 // CHECK-SAME{LITERAL}: ]]> -// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) 
-// CHECK: ^bb0(%[[ARG_CONVERTED:.*]]: tensor<1x4xf16>): -// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]]) -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][%[[I]], 0] [1, 4] [1, 1] -// CHECK: %[[MUL:.*]] = arith.mulf %[[ROTATED_VEC]], %[[SLICE]] -// CHECK: %[[UPDATED_SUM:.*]] = arith.addf %[[RUNNING_SUM]], %[[MUL]] -// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]] -// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]] -// CHECK: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][3, 0] [1, 4] [1, 1] -// CHECK: %[[LAST_MUL:.*]] = arith.mulf %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]] -// CHECK: %[[FINAL_SUM:.*]] = arith.addf %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]] -// CHECK: secret.yield %[[FINAL_SUM]] -// CHECK: return %[[OUT]] +// CHECK-DAG: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][3, 0] [1, 4] [1, 1] +// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) +// CHECK: ^bb0(%[[ARG_CONVERTED:.*]]: tensor<1x4xf16>): +// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]]) +// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][%[[I]], 0] [1, 4] [1, 1] +// CHECK: %[[MUL:.*]] = arith.mulf %[[ROTATED_VEC]], %[[SLICE]] +// CHECK: %[[UPDATED_SUM:.*]] = arith.addf %[[RUNNING_SUM]], %[[MUL]] +// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]] +// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]] +// CHECK: %[[LAST_MUL:.*]] = arith.mulf %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]] +// CHECK: %[[FINAL_SUM:.*]] = arith.addf %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]] +// CHECK: secret.yield %[[FINAL_SUM]] +// CHECK: return %[[OUT]] module { -func.func @test_float_vector_square_matrix_linalg_to_arith(%vec : !secret.secret>) -> !secret.secret> { +func.func @test_float_vector_square_matrix_matmul(%vec : !secret.secret>) -> !secret.secret> { %matrix = arith.constant dense<[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]> : tensor<4x4xf16> %bias = arith.constant dense<[[17.0, 18.0, 19.0, 20.0]]> : tensor<1x4xf16> %out = secret.generic ins (%vec : !secret.secret>) { diff --git a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_rect_matrix_vector_matmul_op.mlir b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_rect_matrix_vector_matmul_op.mlir new file mode 100644 index 0000000000..2f5f9aba9c --- /dev/null +++ b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_rect_matrix_vector_matmul_op.mlir @@ -0,0 +1,36 @@ +// RUN: heir-opt %s --linalg-to-tensor-ext=tiling-size=4 --canonicalize | FileCheck %s + +// CHECK: func.func @test_integer_rect_matrix_vector_matmul(%[[ARG:.*]]: !secret.secret>) +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense +// CHECK-SAME{LITERAL}: <[[17], [18], [17], [18]]> : tensor<4x1xi16> +// CHECK-DAG: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense +// CHECK-SAME{LITERAL}: <[[1, 2, 3, 4], [6, 7, 8, 5], [3, 4, 1, 2], [8, 5, 6, 7]]> : tensor<4x4xi16> +// CHECK-DAG: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 1] [4, 1] [1, 1] +// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) +// CHECK: 
^bb0(%[[ARG_CONVERTED:.*]]: tensor<4x1xi16>): +// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 1 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]]) +// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, %[[I]]] [4, 1] [1, 1] +// CHECK: %[[MUL:.*]] = arith.muli %[[ROTATED_VEC]], %[[SLICE]] +// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]] +// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]] +// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]] +// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]] +// CHECK: %[[BEFORE_ROTATE_AND_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]] +// CHECK: %[[ROTATED_SUM:.*]] = tensor_ext.rotate %[[BEFORE_ROTATE_AND_SUM]], %[[TWO]] +// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[BEFORE_ROTATE_AND_SUM]], %[[ROTATED_SUM]] +// CHECK: secret.yield %[[FINAL_SUM]] +// CHECK: return %[[OUT]] +module { +func.func @test_integer_rect_matrix_vector_matmul(%vec : !secret.secret>) -> !secret.secret> { + %matrix = arith.constant dense<[[1, 2, 3, 4], [5, 6, 7, 8]]> : tensor<2x4xi16> + %bias = arith.constant dense<[[17], [18]]> : tensor<2x1xi16> + %out = secret.generic ins (%vec : !secret.secret>) { + ^bb0(%converted_vec: tensor<4x1xi16>): + %0 = linalg.matmul ins(%matrix, %converted_vec : tensor<2x4xi16>, tensor<4x1xi16>) outs(%bias : tensor<2x1xi16>) -> tensor<2x1xi16> + secret.yield %0 : tensor<2x1xi16> + } -> !secret.secret> + return %out : !secret.secret> +} +} diff --git a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_small_vector_matrix_matmul_op.mlir b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_small_vector_matrix_matmul_op.mlir new file mode 100644 index 0000000000..25ea461c66 --- /dev/null +++ b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_small_vector_matrix_matmul_op.mlir @@ -0,0 +1,33 @@ +// RUN: heir-opt %s --linalg-to-tensor-ext=tiling-size=4 --canonicalize | FileCheck %s + +// CHECK: func.func @test_integer_square_matrix_vector_matmul(%[[ARG:.*]]: !secret.secret>) +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense +// CHECK-SAME{LITERAL}: <[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]> : tensor<4x4xi16> +// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense +// CHECK-SAME{LITERAL}: <[[17], [18], [19], [20]]> : tensor<4x1xi16> +// CHECK-DAG: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 3] [4, 1] [1, 1] +// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) +// CHECK: ^bb0(%[[ARG_CONVERTED:.*]]: tensor<4x1xi16>): +// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]]) +// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, %[[I]]] [4, 1] [1, 1] +// CHECK: %[[MUL:.*]] = arith.muli %[[ROTATED_VEC]], %[[SLICE]] +// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]] +// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]] +// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]] +// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]] +// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]] +// CHECK: secret.yield %[[FINAL_SUM]] +// CHECK: return %[[OUT]] +module { +func.func @test_integer_square_matrix_vector_matmul(%vec : 
!secret.secret>) -> !secret.secret> { + %matrix = arith.constant dense<[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]> : tensor<4x4xi16> + %bias = arith.constant dense<[[17], [18], [19], [20]]> : tensor<4x1xi16> + %out = secret.generic ins (%vec : !secret.secret>) { + ^bb0(%converted_vec: tensor<4x1xi16>): + %0 = linalg.matmul ins(%matrix, %converted_vec : tensor<4x4xi16>, tensor<4x1xi16>) outs(%bias : tensor<4x1xi16>) -> tensor<4x1xi16> + secret.yield %0 : tensor<4x1xi16> + } -> !secret.secret> + return %out : !secret.secret> +} +} diff --git a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_square_matrix_vector_matmul_op.mlir b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_square_matrix_vector_matmul_op.mlir index 82c8855806..7dfe8f19c8 100644 --- a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_square_matrix_vector_matmul_op.mlir +++ b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_square_matrix_vector_matmul_op.mlir @@ -1,33 +1,33 @@ -// RUN: heir-opt %s --linalg-to-tensor-ext | FileCheck %s +// RUN: heir-opt %s --linalg-to-tensor-ext=tiling-size=4 --canonicalize | FileCheck %s -// CHECK: func.func @test_integer_square_matrix_vector_linalg_to_arith(%[[ARG:.*]]: !secret.secret>) -// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index -// CHECK: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense -// CHECK-SAME{LITERAL}: <[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]> : tensor<4x4xi16> -// CHECK: %[[BIAS:.*]] = arith.constant dense -// CHECK-SAME{LITERAL}: <[[17], [18], [19], [20]]> : tensor<4x1xi16> -// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) -// CHECK: ^bb0(%[[ARG_CONVERTED:.*]]: tensor<4x1xi16>): -// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]]) -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, %[[I]]] [4, 1] [1, 1] -// CHECK: %[[MUL:.*]] = arith.muli %[[ROTATED_VEC]], %[[SLICE]] -// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]] -// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]] -// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]] -// CHECK: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][0, 3] [4, 1] [1, 1] -// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]] -// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]] -// CHECK: secret.yield %[[FINAL_SUM]] -// CHECK: return %[[OUT]] +// CHECK: func.func @test_integer_vector_square_matrix_matmul(%[[ARG:.*]]: !secret.secret>) +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense +// CHECK-SAME{LITERAL}: <[[1, 6, 11, 16], [5, 10, 15, 4], [9, 14, 3, 8], [13, 2, 7, 12]]> : tensor<4x4xi16> +// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense +// CHECK-SAME{LITERAL}: <[[17, 18, 19, 20]]> : tensor<1x4xi16> +// CHECK-DAG: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][3, 0] [1, 4] [1, 1] +// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) +// CHECK: ^bb0(%[[ARG_CONVERTED:.*]]: tensor<1x4xi16>): +// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]]) +// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][%[[I]], 0] [1, 4] [1, 1] +// CHECK: %[[MUL:.*]] = 
arith.muli %[[ROTATED_VEC]], %[[SLICE]] +// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]] +// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]] +// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]] +// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]] +// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]] +// CHECK: secret.yield %[[FINAL_SUM]] +// CHECK: return %[[OUT]] module { -func.func @test_integer_square_matrix_vector_linalg_to_arith(%vec : !secret.secret>) -> !secret.secret> { +func.func @test_integer_vector_square_matrix_matmul(%vec : !secret.secret>) -> !secret.secret> { %matrix = arith.constant dense<[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]> : tensor<4x4xi16> - %bias = arith.constant dense<[[17], [18], [19], [20]]> : tensor<4x1xi16> - %out = secret.generic ins (%vec : !secret.secret>) { - ^bb0(%converted_vec: tensor<4x1xi16>): - %0 = linalg.matmul ins(%matrix, %converted_vec : tensor<4x4xi16>, tensor<4x1xi16>) outs(%bias : tensor<4x1xi16>) -> tensor<4x1xi16> - secret.yield %0 : tensor<4x1xi16> - } -> !secret.secret> - return %out : !secret.secret> + %bias = arith.constant dense<[[17, 18, 19, 20]]> : tensor<1x4xi16> + %out = secret.generic ins (%vec : !secret.secret>) { + ^bb0(%converted_vec: tensor<1x4xi16>): + %0 = linalg.matmul ins(%converted_vec, %matrix : tensor<1x4xi16>, tensor<4x4xi16>) outs(%bias : tensor<1x4xi16>) -> tensor<1x4xi16> + secret.yield %0 : tensor<1x4xi16> + } -> !secret.secret> + return %out : !secret.secret> } } diff --git a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_vector_square_matrix_matmul_op.mlir b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_vector_square_matrix_matmul_op.mlir index 781cd5cbc4..7dfe8f19c8 100644 --- a/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_vector_square_matrix_matmul_op.mlir +++ b/tests/Dialect/LinAlg/Conversions/linalg_to_tensor_ext/integer_vector_square_matrix_matmul_op.mlir @@ -1,26 +1,26 @@ -// RUN: heir-opt %s --linalg-to-tensor-ext | FileCheck %s +// RUN: heir-opt %s --linalg-to-tensor-ext=tiling-size=4 --canonicalize | FileCheck %s -// CHECK: func.func @test_integer_vector_square_matrix_linalg_to_arith(%[[ARG:.*]]: !secret.secret>) -// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index -// CHECK: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense +// CHECK: func.func @test_integer_vector_square_matrix_matmul(%[[ARG:.*]]: !secret.secret>) +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIAGONALIZED_MATRIX:.*]] = arith.constant dense // CHECK-SAME{LITERAL}: <[[1, 6, 11, 16], [5, 10, 15, 4], [9, 14, 3, 8], [13, 2, 7, 12]]> : tensor<4x4xi16> -// CHECK: %[[BIAS:.*]] = arith.constant dense +// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense // CHECK-SAME{LITERAL}: <[[17, 18, 19, 20]]> : tensor<1x4xi16> -// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) -// CHECK: ^bb0(%[[ARG_CONVERTED:.*]]: tensor<1x4xi16>): -// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]]) -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][%[[I]], 0] [1, 4] [1, 1] -// CHECK: %[[MUL:.*]] = arith.muli %[[ROTATED_VEC]], %[[SLICE]] -// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]] -// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]] -// CHECK: affine.yield 
%[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]] -// CHECK: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][3, 0] [1, 4] [1, 1] -// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]] -// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]] -// CHECK: secret.yield %[[FINAL_SUM]] -// CHECK: return %[[OUT]] +// CHECK-DAG: %[[LAST_SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][3, 0] [1, 4] [1, 1] +// CHECK: %[[OUT:.*]] = secret.generic ins(%[[ARG]] : !secret.secret>) +// CHECK: ^bb0(%[[ARG_CONVERTED:.*]]: tensor<1x4xi16>): +// CHECK: %[[FOR_LOOP_OUT:.*]]:2 = affine.for %[[I:.*]] = 0 to 3 iter_args(%[[RUNNING_SUM:.*]] = %[[BIAS]], %[[ROTATED_VEC:.*]] = %[[ARG_CONVERTED]]) +// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[DIAGONALIZED_MATRIX]][%[[I]], 0] [1, 4] [1, 1] +// CHECK: %[[MUL:.*]] = arith.muli %[[ROTATED_VEC]], %[[SLICE]] +// CHECK: %[[UPDATED_SUM:.*]] = arith.addi %[[RUNNING_SUM]], %[[MUL]] +// CHECK: %[[UPDATED_ROTATED_VEC:.*]] = tensor_ext.rotate %[[ROTATED_VEC]], %[[ONE]] +// CHECK: affine.yield %[[UPDATED_SUM]], %[[UPDATED_ROTATED_VEC]] +// CHECK: %[[LAST_MUL:.*]] = arith.muli %[[FOR_LOOP_OUT]]#1, %[[LAST_SLICE]] +// CHECK: %[[FINAL_SUM:.*]] = arith.addi %[[FOR_LOOP_OUT]]#0, %[[LAST_MUL]] +// CHECK: secret.yield %[[FINAL_SUM]] +// CHECK: return %[[OUT]] module { -func.func @test_integer_vector_square_matrix_linalg_to_arith(%vec : !secret.secret>) -> !secret.secret> { +func.func @test_integer_vector_square_matrix_matmul(%vec : !secret.secret>) -> !secret.secret> { %matrix = arith.constant dense<[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]> : tensor<4x4xi16> %bias = arith.constant dense<[[17, 18, 19, 20]]> : tensor<1x4xi16> %out = secret.generic ins (%vec : !secret.secret>) {