diff --git a/lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp b/lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp index aab98be6b..9288ebfe6 100644 --- a/lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp +++ b/lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp @@ -108,7 +108,7 @@ void annotateDimension(Operation *top, DataFlowSolver *solver) { if (op->getNumResults() == 0) { return; } - if (!ensureSecretness(op->getResult(0), solver)) { + if (!isSecret(op->getResult(0), solver)) { return; } op->setAttr("dimension", diff --git a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp index 0cc8e3f52..6d54c5907 100644 --- a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp +++ b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp @@ -84,7 +84,7 @@ static int getMaxLevel(Operation *top, DataFlowSolver *solver) { if (op->getNumResults() == 0) { return; } - if (!ensureSecretness(op->getResult(0), solver)) { + if (!isSecret(op->getResult(0), solver)) { return; } // ensure result is secret @@ -121,7 +121,7 @@ void annotateLevel(Operation *top, DataFlowSolver *solver) { if (op->getNumResults() == 0) { return; } - if (!ensureSecretness(op->getResult(0), solver)) { + if (!isSecret(op->getResult(0), solver)) { return; } auto level = getLevel(op->getResult(0)); diff --git a/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.cpp b/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.cpp index 2af2f07af..b18121bcd 100644 --- a/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.cpp +++ b/lib/Analysis/OptimizeRelinearizationAnalysis/OptimizeRelinearizationAnalysis.cpp @@ -105,8 +105,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { } // skip secret generic op; we decide inside generic op block - if (!isa(op) && - ensureSecretness(op->getResults(), solver)) { + if (!isa(op) && isSecret(op->getResults(), solver)) { auto decisionVar = model.AddBinaryVariable("InsertRelin_" + name); decisionVariables.insert(std::make_pair(op, decisionVar)); } @@ -116,7 +115,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { std::string varName = "Degree_" + name; for (Value result : op->getResults()) { // skip secret generic ops - if (isa(op) || !ensureSecretness(result, solver)) { + if (isa(op) || !isSecret(result, solver)) { continue; } @@ -143,7 +142,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { for (BlockArgument arg : block.getArguments()) { - if (!ensureSecretness(arg, solver)) { + if (!isSecret(arg, solver)) { continue; } @@ -225,7 +224,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { .Case([&](auto op) { for (Value operand : op->getOperands()) { // skip non secret argument - if (!ensureSecretness(operand, solver)) { + if (!isSecret(operand, solver)) { continue; } if (!keyBasisVars.contains(operand)) { @@ -248,12 +247,11 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { llvm::TypeSwitch(*op) .Case([&](auto op) { // if plain mul, skip - if (!ensureSecretness(op.getResult(), solver)) { + if (!isSecret(op.getResult(), solver)) { return; } // ct-ct mul - if (ensureSecretness(op.getLhs(), solver) && - ensureSecretness(op.getRhs(), solver)) { + if (isSecret(op.getLhs(), solver) && isSecret(op.getRhs(), solver)) { auto lhsDegreeVar = keyBasisVars.at(op.getLhs()); auto rhsDegreeVar = keyBasisVars.at(op.getRhs()); auto resultBeforeRelinVar = beforeRelinVars.at(op.getResult()); @@ -302,7 +300,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { if (isa(op)) { return; } - if (!ensureSecretness(op.getResults(), solver)) { + if (!isSecret(op.getResults(), solver)) { return; } SmallVector secretOperands; @@ -343,7 +341,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { if (isa(op)) { return; } - if (!ensureSecretness(op->getResults(), solver)) { + if (!isSecret(op->getResults(), solver)) { return; } diff --git a/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp b/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp index c61db8727..9993fcab9 100644 --- a/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp +++ b/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp @@ -19,7 +19,7 @@ namespace heir { void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) { auto operand = lattice->getAnchor(); - bool isSecret = isa(operand.getType()); + bool secretness = isa(operand.getType()); Operation *operation = nullptr; // Get defining operation for operand @@ -34,7 +34,7 @@ void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) { if (auto genericOp = dyn_cast(*operation)) { if (OpOperand *genericOperand = genericOp.getOpOperandForBlockArgument(operand)) { - isSecret = isa(genericOperand->get().getType()); + secretness = isa(genericOperand->get().getType()); } } @@ -48,7 +48,7 @@ void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) { blockArgs.begin(); // Check if it has secret type - isSecret = isa(funcOp.getArgumentTypes()[index]); + secretness = isa(funcOp.getArgumentTypes()[index]); // check if it is annotated as {secret.secret} auto attrs = funcOp.getArgAttrs(); @@ -56,8 +56,8 @@ void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) { auto arr = attrs->getValue(); if (auto dictattr = dyn_cast(arr[index])) { for (auto attr : dictattr) { - isSecret = - isSecret || + secretness = + secretness || attr.getName() == secret::SecretDialect::kArgSecretAttrName.str(); break; } @@ -65,7 +65,7 @@ void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) { } } - propagateIfChanged(lattice, lattice->join(Secretness(isSecret))); + propagateIfChanged(lattice, lattice->join(Secretness(secretness))); } LogicalResult SecretnessAnalysis::visitOperation( @@ -133,7 +133,7 @@ void annotateSecretness(Operation *top, DataFlowSolver *solver) { }); } -bool ensureSecretness(Value value, DataFlowSolver *solver) { +bool isSecret(Value value, DataFlowSolver *solver) { auto *lattice = solver->lookupState(value); if (!lattice) { return false; @@ -144,20 +144,19 @@ bool ensureSecretness(Value value, DataFlowSolver *solver) { return lattice->getValue().getSecretness(); } -bool ensureSecretness(ValueRange values, DataFlowSolver *solver) { +bool isSecret(ValueRange values, DataFlowSolver *solver) { if (values.empty()) { return false; } - return std::all_of(values.begin(), values.end(), [&](Value value) { - return ensureSecretness(value, solver); - }); + return std::all_of(values.begin(), values.end(), + [&](Value value) { return isSecret(value, solver); }); } void getSecretOperands(Operation *op, SmallVectorImpl &secretOperands, DataFlowSolver *solver) { for (auto &operand : op->getOpOperands()) { - if (ensureSecretness(operand.get(), solver)) { + if (isSecret(operand.get(), solver)) { secretOperands.push_back(&operand); } } diff --git a/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h b/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h index 13f6ebdc3..c64de1776 100644 --- a/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h +++ b/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h @@ -133,7 +133,7 @@ class SecretnessAnalysisDependent { * @return true if the value is secret, false if the secretness of the value * is unknown or false. */ - bool ensureSecretness(Operation *op, Value value) { + bool isSecretInternal(Operation *op, Value value) { // create dependency on SecretnessAnalysis auto *lattice = getChildAnalysis()->template getOrCreateFor( @@ -157,7 +157,7 @@ class SecretnessAnalysisDependent { void getSecretResults(Operation *op, SmallVectorImpl &secretResults) { for (const auto &result : op->getOpResults()) { - if (ensureSecretness(op, result)) { + if (isSecretInternal(op, result)) { secretResults.push_back(result); } } @@ -176,7 +176,7 @@ class SecretnessAnalysisDependent { void getSecretOperands(Operation *op, SmallVectorImpl &secretOperands) { for (auto &operand : op->getOpOperands()) { - if (ensureSecretness(op, operand.get())) { + if (isSecretInternal(op, operand.get())) { secretOperands.push_back(&operand); } } @@ -188,9 +188,9 @@ void annotateSecretness(Operation *top, DataFlowSolver *solver); // this method is used when DataFlowSolver has finished running the secretness // analysis -bool ensureSecretness(Value value, DataFlowSolver *solver); +bool isSecret(Value value, DataFlowSolver *solver); -bool ensureSecretness(ValueRange values, DataFlowSolver *solver); +bool isSecret(ValueRange values, DataFlowSolver *solver); void getSecretOperands(Operation *op, SmallVectorImpl &secretOperands, diff --git a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.cpp b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.cpp index 4f151ea43..fb23acfc6 100644 --- a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.cpp +++ b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.cpp @@ -235,16 +235,8 @@ struct ConvertLinalgMatmul : public OpRewritePattern { // Determine if the left or right operand is secret to determine which // matrix to diagonalize, or if both are secret or both are public, then // return failure. - auto isSecret = [&](Value value) { - auto *operandLookup = solver->lookupState(value); - Secretness operandSecretness = - operandLookup ? operandLookup->getValue() : Secretness(); - return (operandSecretness.isInitialized() && - operandSecretness.getSecretness()); - }; - - bool isLeftOperandSecret = isSecret(op.getInputs()[0]); - bool isRightOperandSecret = isSecret(op.getInputs()[1]); + bool isLeftOperandSecret = isSecret(op.getInputs()[0], solver); + bool isRightOperandSecret = isSecret(op.getInputs()[1], solver); LLVM_DEBUG({ llvm::dbgs() << "Left operand is secret: " << isLeftOperandSecret << "\n" diff --git a/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp b/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp index d0f3e9834..9864215b2 100644 --- a/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp +++ b/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp @@ -171,7 +171,7 @@ struct SplitGeneric : public OpRewritePattern { // Ensure that the loop bound operands are also validated. If they are // secret types, then return a failure - we cannot distribute through a // loop with secret bounds. - auto isSecret = [&](OpFoldResult v) { + auto hasSecretType = [&](OpFoldResult v) { if (auto value = dyn_cast(v)) { if (auto *genericOperand = genericOp.getOpOperandForBlockArgument(value)) { @@ -183,9 +183,9 @@ struct SplitGeneric : public OpRewritePattern { return false; }; if ((loop.getLoopLowerBounds().has_value() && - llvm::any_of(loop.getLoopLowerBounds().value(), isSecret)) || + llvm::any_of(loop.getLoopLowerBounds().value(), hasSecretType)) || (loop.getLoopUpperBounds().has_value() && - llvm::any_of(loop.getLoopUpperBounds().value(), isSecret))) { + llvm::any_of(loop.getLoopUpperBounds().value(), hasSecretType))) { LLVM_DEBUG(genericOp.emitRemark() << "cannot distribute through a LoopLikeInterface with " "secret bounds"); @@ -200,15 +200,12 @@ struct SplitGeneric : public OpRewritePattern { DenseMap newInitsToOperands; for (auto [operand, blockArg] : llvm::zip( clonedLoop.getInitsMutable(), clonedLoop.getRegionIterArgs())) { - auto yieldedIterValue = clonedLoop.getTiedLoopYieldedValue(blockArg); + auto *yieldedIterValue = clonedLoop.getTiedLoopYieldedValue(blockArg); if (isa(operand.get().getType())) { blockArg.setType(operand.get().getType()); - } else if (solver - ->lookupState( - loop.getYieldedValues()[yieldedIterValue - ->getOperandNumber()]) - ->getValue() - .getSecretness() && + } else if (isSecret(loop.getYieldedValues()[yieldedIterValue + ->getOperandNumber()], + solver) && !isa(operand.get().getType())) { // The initial value of an iter_arg yielded by the original loop must // be promoted to a secret and added to the new generic's operands if diff --git a/lib/Dialect/TOSA/Conversions/TosaToSecretArith/TosaToSecretArith.cpp b/lib/Dialect/TOSA/Conversions/TosaToSecretArith/TosaToSecretArith.cpp index aa66403f3..c14a6f432 100644 --- a/lib/Dialect/TOSA/Conversions/TosaToSecretArith/TosaToSecretArith.cpp +++ b/lib/Dialect/TOSA/Conversions/TosaToSecretArith/TosaToSecretArith.cpp @@ -62,16 +62,7 @@ struct ConvertTosaSigmoid : public OpRewritePattern { LogicalResult matchAndRewrite(mlir::tosa::SigmoidOp op, PatternRewriter &rewriter) const override { - auto isSecret = [&](Value value) { - auto *operandLookup = solver->lookupState(value); - Secretness operandSecretness = - operandLookup ? operandLookup->getValue() : Secretness(); - return (operandSecretness.isInitialized() && - operandSecretness.getSecretness()); - }; - - // Do not support lowering for non-secret operands - bool operandIsSecret = isSecret(op.getOperand()); + bool operandIsSecret = isSecret(op.getOperand(), solver); if (!operandIsSecret) { return failure(); } diff --git a/lib/Transforms/ConvertSecretForToStaticFor/ConvertSecretForToStaticFor.cpp b/lib/Transforms/ConvertSecretForToStaticFor/ConvertSecretForToStaticFor.cpp index 4f9c81e06..89f28a0f4 100644 --- a/lib/Transforms/ConvertSecretForToStaticFor/ConvertSecretForToStaticFor.cpp +++ b/lib/Transforms/ConvertSecretForToStaticFor/ConvertSecretForToStaticFor.cpp @@ -36,22 +36,8 @@ struct SecretForToStaticForConversion : OpRewritePattern { LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override { - auto *lowerBoundSecretnessLattice = - solver->lookupState(forOp.getLowerBound()); - - auto *upperBoundSecretnessLattice = - solver->lookupState(forOp.getUpperBound()); - - if (!lowerBoundSecretnessLattice && !upperBoundSecretnessLattice) - return failure(); - - // Get secretness state of the lower and upper bounds - bool isLowerBoundSecret = - lowerBoundSecretnessLattice && - lowerBoundSecretnessLattice->getValue().getSecretness(); - bool isUpperBoundSecret = - upperBoundSecretnessLattice && - upperBoundSecretnessLattice->getValue().getSecretness(); + bool isLowerBoundSecret = isSecret(forOp.getLowerBound(), solver); + bool isUpperBoundSecret = isSecret(forOp.getUpperBound(), solver); // If both bounds are non-secret constants, return if (!isLowerBoundSecret && !isUpperBoundSecret) return failure(); diff --git a/lib/Transforms/SecretInsertMgmt/SecretInsertMgmtCKKS.cpp b/lib/Transforms/SecretInsertMgmt/SecretInsertMgmtCKKS.cpp index f92242c10..58adf050f 100644 --- a/lib/Transforms/SecretInsertMgmt/SecretInsertMgmtCKKS.cpp +++ b/lib/Transforms/SecretInsertMgmt/SecretInsertMgmtCKKS.cpp @@ -62,9 +62,8 @@ bool isTensorInSlots(Operation *top, DataFlowSolver *solver, int slotNumber) { LogicalResult result = walkAndValidateValues( top, [&](Value value) { - auto secretness = - solver->lookupState(value)->getValue(); - if (secretness.isInitialized() && secretness.getSecretness()) { + auto secret = isSecret(value, solver); + if (secret) { auto tensorTy = dyn_cast(value.getType()); if (tensorTy) { // TODO(#913): Multidimensional tensors with a single non-unit @@ -88,10 +87,7 @@ bool isTensorInSlots(Operation *top, DataFlowSolver *solver, int slotNumber) { void annotateTensorExtractAsNotSlotExtract(Operation *top, DataFlowSolver *solver) { top->walk([&](tensor::ExtractOp extractOp) { - auto secretness = - solver->lookupState(extractOp.getOperand(0)) - ->getValue(); - if (secretness.isInitialized() && secretness.getSecretness()) { + if (isSecret(extractOp.getOperand(0), solver)) { extractOp->setAttr("slot_extract", BoolAttr::get(extractOp.getContext(), false)); } diff --git a/lib/Transforms/SecretInsertMgmt/SecretInsertMgmtPatterns.cpp b/lib/Transforms/SecretInsertMgmt/SecretInsertMgmtPatterns.cpp index 2dd5f33a3..c490d5e72 100644 --- a/lib/Transforms/SecretInsertMgmt/SecretInsertMgmtPatterns.cpp +++ b/lib/Transforms/SecretInsertMgmt/SecretInsertMgmtPatterns.cpp @@ -28,17 +28,15 @@ template LogicalResult MultRelinearize::matchAndRewrite( MulOp mulOp, PatternRewriter &rewriter) const { Value result = mulOp.getResult(); - auto secret = solver->lookupState(result)->getValue(); - // if not secret, skip - if (!secret.isInitialized() || !secret.getSecretness()) { + bool secret = isSecret(result, solver); + if (!secret) { return success(); } // if mul const, skip for (auto operand : mulOp.getOperands()) { - auto secretness = - solver->lookupState(operand)->getValue(); - if (!secretness.isInitialized() || !secretness.getSecretness()) { + auto secret = isSecret(operand, solver); + if (!secret) { return success(); } } @@ -58,8 +56,8 @@ LogicalResult ModReduceBefore::matchAndRewrite( // guard against secret::YieldOp if (op->getResults().size() > 0) { for (auto result : op->getResults()) { - auto secret = solver->lookupState(result)->getValue(); - if (!secret.isInitialized() || !secret.getSecretness()) { + bool secret = isSecret(result, solver); + if (!secret) { return success(); } } @@ -82,12 +80,8 @@ LogicalResult ModReduceBefore::matchAndRewrite( // use map in case we have same operands DenseMap operandsInsertLevel; for (auto operand : op.getOperands()) { - auto secretness = - solver->lookupState(operand)->getValue(); - if (!secretness.isInitialized()) { - return failure(); - } - if (!secretness.getSecretness()) { + bool secret = isSecret(operand, solver); + if (!secret) { continue; } auto levelLattice = solver->lookupState(operand)->getValue();