Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use isSecret in place of manual lattice lookups where possible #1278

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void annotateDimension(Operation *top, DataFlowSolver *solver) {
if (op->getNumResults() == 0) {
return;
}
if (!ensureSecretness(op->getResult(0), solver)) {
if (!isSecret(op->getResult(0), solver)) {
return;
}
op->setAttr("dimension",
Expand Down
4 changes: 2 additions & 2 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ static int getMaxLevel(Operation *top, DataFlowSolver *solver) {
if (op->getNumResults() == 0) {
return;
}
if (!ensureSecretness(op->getResult(0), solver)) {
if (!isSecret(op->getResult(0), solver)) {
return;
}
// ensure result is secret
Expand Down Expand Up @@ -121,7 +121,7 @@ void annotateLevel(Operation *top, DataFlowSolver *solver) {
if (op->getNumResults() == 0) {
return;
}
if (!ensureSecretness(op->getResult(0), solver)) {
if (!isSecret(op->getResult(0), solver)) {
return;
}
auto level = getLevel(op->getResult(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
}

// skip secret generic op; we decide inside generic op block
if (!isa<secret::GenericOp>(op) &&
ensureSecretness(op->getResults(), solver)) {
if (!isa<secret::GenericOp>(op) && isSecret(op->getResults(), solver)) {
auto decisionVar = model.AddBinaryVariable("InsertRelin_" + name);
decisionVariables.insert(std::make_pair(op, decisionVar));
}
Expand All @@ -116,7 +115,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
std::string varName = "Degree_" + name;
for (Value result : op->getResults()) {
// skip secret generic ops
if (isa<secret::GenericOp>(op) || !ensureSecretness(result, solver)) {
if (isa<secret::GenericOp>(op) || !isSecret(result, solver)) {
continue;
}

Expand All @@ -143,7 +142,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
for (BlockArgument arg : block.getArguments()) {
if (!ensureSecretness(arg, solver)) {
if (!isSecret(arg, solver)) {
continue;
}

Expand Down Expand Up @@ -225,7 +224,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
.Case<tensor_ext::RotateOp, secret::YieldOp>([&](auto op) {
for (Value operand : op->getOperands()) {
// skip non secret argument
if (!ensureSecretness(operand, solver)) {
if (!isSecret(operand, solver)) {
continue;
}
if (!keyBasisVars.contains(operand)) {
Expand All @@ -248,12 +247,11 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
llvm::TypeSwitch<Operation &>(*op)
.Case<arith::MulIOp, arith::MulFOp>([&](auto op) {
// if plain mul, skip
if (!ensureSecretness(op.getResult(), solver)) {
if (!isSecret(op.getResult(), solver)) {
return;
}
// ct-ct mul
if (ensureSecretness(op.getLhs(), solver) &&
ensureSecretness(op.getRhs(), solver)) {
if (isSecret(op.getLhs(), solver) && isSecret(op.getRhs(), solver)) {
auto lhsDegreeVar = keyBasisVars.at(op.getLhs());
auto rhsDegreeVar = keyBasisVars.at(op.getRhs());
auto resultBeforeRelinVar = beforeRelinVars.at(op.getResult());
Expand Down Expand Up @@ -302,7 +300,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
if (isa<secret::GenericOp>(op)) {
return;
}
if (!ensureSecretness(op.getResults(), solver)) {
if (!isSecret(op.getResults(), solver)) {
return;
}
SmallVector<OpOperand *, 4> secretOperands;
Expand Down Expand Up @@ -343,7 +341,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
if (isa<secret::GenericOp>(op)) {
return;
}
if (!ensureSecretness(op->getResults(), solver)) {
if (!isSecret(op->getResults(), solver)) {
return;
}

Expand Down
23 changes: 11 additions & 12 deletions lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace heir {

void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) {
auto operand = lattice->getAnchor();
bool isSecret = isa<secret::SecretType>(operand.getType());
bool secretness = isa<secret::SecretType>(operand.getType());

Operation *operation = nullptr;
// Get defining operation for operand
Expand All @@ -34,7 +34,7 @@ void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) {
if (auto genericOp = dyn_cast<secret::GenericOp>(*operation)) {
if (OpOperand *genericOperand =
genericOp.getOpOperandForBlockArgument(operand)) {
isSecret = isa<secret::SecretType>(genericOperand->get().getType());
secretness = isa<secret::SecretType>(genericOperand->get().getType());
}
}

Expand All @@ -48,24 +48,24 @@ void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) {
blockArgs.begin();

// Check if it has secret type
isSecret = isa<secret::SecretType>(funcOp.getArgumentTypes()[index]);
secretness = isa<secret::SecretType>(funcOp.getArgumentTypes()[index]);

// check if it is annotated as {secret.secret}
auto attrs = funcOp.getArgAttrs();
if (attrs) {
auto arr = attrs->getValue();
if (auto dictattr = dyn_cast<DictionaryAttr>(arr[index])) {
for (auto attr : dictattr) {
isSecret =
isSecret ||
secretness =
secretness ||
attr.getName() == secret::SecretDialect::kArgSecretAttrName.str();
break;
}
}
}
}

propagateIfChanged(lattice, lattice->join(Secretness(isSecret)));
propagateIfChanged(lattice, lattice->join(Secretness(secretness)));
}

LogicalResult SecretnessAnalysis::visitOperation(
Expand Down Expand Up @@ -133,7 +133,7 @@ void annotateSecretness(Operation *top, DataFlowSolver *solver) {
});
}

bool ensureSecretness(Value value, DataFlowSolver *solver) {
bool isSecret(Value value, DataFlowSolver *solver) {
auto *lattice = solver->lookupState<SecretnessLattice>(value);
if (!lattice) {
return false;
Expand All @@ -144,20 +144,19 @@ bool ensureSecretness(Value value, DataFlowSolver *solver) {
return lattice->getValue().getSecretness();
}

bool ensureSecretness(ValueRange values, DataFlowSolver *solver) {
bool isSecret(ValueRange values, DataFlowSolver *solver) {
if (values.empty()) {
return false;
}
return std::all_of(values.begin(), values.end(), [&](Value value) {
return ensureSecretness(value, solver);
});
return std::all_of(values.begin(), values.end(),
[&](Value value) { return isSecret(value, solver); });
}

void getSecretOperands(Operation *op,
SmallVectorImpl<OpOperand *> &secretOperands,
DataFlowSolver *solver) {
for (auto &operand : op->getOpOperands()) {
if (ensureSecretness(operand.get(), solver)) {
if (isSecret(operand.get(), solver)) {
secretOperands.push_back(&operand);
}
}
Expand Down
10 changes: 5 additions & 5 deletions lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class SecretnessAnalysisDependent {
* @return true if the value is secret, false if the secretness of the value
* is unknown or false.
*/
bool ensureSecretness(Operation *op, Value value) {
bool isSecretInternal(Operation *op, Value value) {
// create dependency on SecretnessAnalysis
auto *lattice =
getChildAnalysis()->template getOrCreateFor<SecretnessLattice>(
Expand All @@ -157,7 +157,7 @@ class SecretnessAnalysisDependent {
void getSecretResults(Operation *op,
SmallVectorImpl<OpResult> &secretResults) {
for (const auto &result : op->getOpResults()) {
if (ensureSecretness(op, result)) {
if (isSecretInternal(op, result)) {
secretResults.push_back(result);
}
}
Expand All @@ -176,7 +176,7 @@ class SecretnessAnalysisDependent {
void getSecretOperands(Operation *op,
SmallVectorImpl<OpOperand *> &secretOperands) {
for (auto &operand : op->getOpOperands()) {
if (ensureSecretness(op, operand.get())) {
if (isSecretInternal(op, operand.get())) {
secretOperands.push_back(&operand);
}
}
Expand All @@ -188,9 +188,9 @@ void annotateSecretness(Operation *top, DataFlowSolver *solver);

// this method is used when DataFlowSolver has finished running the secretness
// analysis
bool ensureSecretness(Value value, DataFlowSolver *solver);
bool isSecret(Value value, DataFlowSolver *solver);

bool ensureSecretness(ValueRange values, DataFlowSolver *solver);
bool isSecret(ValueRange values, DataFlowSolver *solver);

void getSecretOperands(Operation *op,
SmallVectorImpl<OpOperand *> &secretOperands,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,8 @@ struct ConvertLinalgMatmul : public OpRewritePattern<mlir::linalg::MatmulOp> {
// Determine if the left or right operand is secret to determine which
// matrix to diagonalize, or if both are secret or both are public, then
// return failure.
auto isSecret = [&](Value value) {
auto *operandLookup = solver->lookupState<SecretnessLattice>(value);
Secretness operandSecretness =
operandLookup ? operandLookup->getValue() : Secretness();
return (operandSecretness.isInitialized() &&
operandSecretness.getSecretness());
};

bool isLeftOperandSecret = isSecret(op.getInputs()[0]);
bool isRightOperandSecret = isSecret(op.getInputs()[1]);
bool isLeftOperandSecret = isSecret(op.getInputs()[0], solver);
bool isRightOperandSecret = isSecret(op.getInputs()[1], solver);

LLVM_DEBUG({
llvm::dbgs() << "Left operand is secret: " << isLeftOperandSecret << "\n"
Expand Down
17 changes: 7 additions & 10 deletions lib/Dialect/Secret/Transforms/DistributeGeneric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
// Ensure that the loop bound operands are also validated. If they are
// secret types, then return a failure - we cannot distribute through a
// loop with secret bounds.
auto isSecret = [&](OpFoldResult v) {
auto hasSecretType = [&](OpFoldResult v) {
if (auto value = dyn_cast<Value>(v)) {
if (auto *genericOperand =
genericOp.getOpOperandForBlockArgument(value)) {
Expand All @@ -183,9 +183,9 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
return false;
};
if ((loop.getLoopLowerBounds().has_value() &&
llvm::any_of(loop.getLoopLowerBounds().value(), isSecret)) ||
llvm::any_of(loop.getLoopLowerBounds().value(), hasSecretType)) ||
(loop.getLoopUpperBounds().has_value() &&
llvm::any_of(loop.getLoopUpperBounds().value(), isSecret))) {
llvm::any_of(loop.getLoopUpperBounds().value(), hasSecretType))) {
LLVM_DEBUG(genericOp.emitRemark()
<< "cannot distribute through a LoopLikeInterface with "
"secret bounds");
Expand All @@ -200,15 +200,12 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
DenseMap<Value, Value> newInitsToOperands;
for (auto [operand, blockArg] : llvm::zip(
clonedLoop.getInitsMutable(), clonedLoop.getRegionIterArgs())) {
auto yieldedIterValue = clonedLoop.getTiedLoopYieldedValue(blockArg);
auto *yieldedIterValue = clonedLoop.getTiedLoopYieldedValue(blockArg);
if (isa<SecretType>(operand.get().getType())) {
blockArg.setType(operand.get().getType());
} else if (solver
->lookupState<SecretnessLattice>(
loop.getYieldedValues()[yieldedIterValue
->getOperandNumber()])
->getValue()
.getSecretness() &&
} else if (isSecret(loop.getYieldedValues()[yieldedIterValue
->getOperandNumber()],
solver) &&
!isa<SecretType>(operand.get().getType())) {
// The initial value of an iter_arg yielded by the original loop must
// be promoted to a secret and added to the new generic's operands if
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,7 @@ struct ConvertTosaSigmoid : public OpRewritePattern<mlir::tosa::SigmoidOp> {

LogicalResult matchAndRewrite(mlir::tosa::SigmoidOp op,
PatternRewriter &rewriter) const override {
auto isSecret = [&](Value value) {
auto *operandLookup = solver->lookupState<SecretnessLattice>(value);
Secretness operandSecretness =
operandLookup ? operandLookup->getValue() : Secretness();
return (operandSecretness.isInitialized() &&
operandSecretness.getSecretness());
};

// Do not support lowering for non-secret operands
bool operandIsSecret = isSecret(op.getOperand());
bool operandIsSecret = isSecret(op.getOperand(), solver);
if (!operandIsSecret) {
return failure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,8 @@ struct SecretForToStaticForConversion : OpRewritePattern<scf::ForOp> {

LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
auto *lowerBoundSecretnessLattice =
solver->lookupState<SecretnessLattice>(forOp.getLowerBound());

auto *upperBoundSecretnessLattice =
solver->lookupState<SecretnessLattice>(forOp.getUpperBound());

if (!lowerBoundSecretnessLattice && !upperBoundSecretnessLattice)
return failure();

// Get secretness state of the lower and upper bounds
bool isLowerBoundSecret =
lowerBoundSecretnessLattice &&
lowerBoundSecretnessLattice->getValue().getSecretness();
bool isUpperBoundSecret =
upperBoundSecretnessLattice &&
upperBoundSecretnessLattice->getValue().getSecretness();
bool isLowerBoundSecret = isSecret(forOp.getLowerBound(), solver);
bool isUpperBoundSecret = isSecret(forOp.getUpperBound(), solver);

// If both bounds are non-secret constants, return
if (!isLowerBoundSecret && !isUpperBoundSecret) return failure();
Expand Down
10 changes: 3 additions & 7 deletions lib/Transforms/SecretInsertMgmt/SecretInsertMgmtCKKS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ bool isTensorInSlots(Operation *top, DataFlowSolver *solver, int slotNumber) {
LogicalResult result = walkAndValidateValues(
top,
[&](Value value) {
auto secretness =
solver->lookupState<SecretnessLattice>(value)->getValue();
if (secretness.isInitialized() && secretness.getSecretness()) {
auto secret = isSecret(value, solver);
if (secret) {
auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
if (tensorTy) {
// TODO(#913): Multidimensional tensors with a single non-unit
Expand All @@ -88,10 +87,7 @@ bool isTensorInSlots(Operation *top, DataFlowSolver *solver, int slotNumber) {
void annotateTensorExtractAsNotSlotExtract(Operation *top,
DataFlowSolver *solver) {
top->walk([&](tensor::ExtractOp extractOp) {
auto secretness =
solver->lookupState<SecretnessLattice>(extractOp.getOperand(0))
->getValue();
if (secretness.isInitialized() && secretness.getSecretness()) {
if (isSecret(extractOp.getOperand(0), solver)) {
extractOp->setAttr("slot_extract",
BoolAttr::get(extractOp.getContext(), false));
}
Expand Down
22 changes: 8 additions & 14 deletions lib/Transforms/SecretInsertMgmt/SecretInsertMgmtPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,15 @@ template <typename MulOp>
LogicalResult MultRelinearize<MulOp>::matchAndRewrite(
MulOp mulOp, PatternRewriter &rewriter) const {
Value result = mulOp.getResult();
auto secret = solver->lookupState<SecretnessLattice>(result)->getValue();
// if not secret, skip
if (!secret.isInitialized() || !secret.getSecretness()) {
bool secret = isSecret(result, solver);
if (!secret) {
return success();
}

// if mul const, skip
for (auto operand : mulOp.getOperands()) {
auto secretness =
solver->lookupState<SecretnessLattice>(operand)->getValue();
if (!secretness.isInitialized() || !secretness.getSecretness()) {
auto secret = isSecret(operand, solver);
if (!secret) {
return success();
}
}
Expand All @@ -58,8 +56,8 @@ LogicalResult ModReduceBefore<Op>::matchAndRewrite(
// guard against secret::YieldOp
if (op->getResults().size() > 0) {
for (auto result : op->getResults()) {
auto secret = solver->lookupState<SecretnessLattice>(result)->getValue();
if (!secret.isInitialized() || !secret.getSecretness()) {
bool secret = isSecret(result, solver);
if (!secret) {
return success();
}
}
Expand All @@ -82,12 +80,8 @@ LogicalResult ModReduceBefore<Op>::matchAndRewrite(
// use map in case we have same operands
DenseMap<Value, LevelState::LevelType> operandsInsertLevel;
for (auto operand : op.getOperands()) {
auto secretness =
solver->lookupState<SecretnessLattice>(operand)->getValue();
if (!secretness.isInitialized()) {
return failure();
}
if (!secretness.getSecretness()) {
bool secret = isSecret(operand, solver);
if (!secret) {
continue;
}
auto levelLattice = solver->lookupState<LevelLattice>(operand)->getValue();
Expand Down
Loading