Skip to content

Commit

Permalink
simplify some secretness lattice lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Jan 21, 2025
1 parent 3917e51 commit 8303e53
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 97 deletions.
2 changes: 1 addition & 1 deletion lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void annotateDimension(Operation *top, DataFlowSolver *solver) {
if (op->getNumResults() == 0) {
return;
}
if (!ensureSecretness(op->getResult(0), solver)) {
if (!isSecret(op->getResult(0), solver)) {
return;
}
op->setAttr("dimension",
Expand Down
4 changes: 2 additions & 2 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ static int getMaxLevel(Operation *top, DataFlowSolver *solver) {
if (op->getNumResults() == 0) {
return;
}
if (!ensureSecretness(op->getResult(0), solver)) {
if (!isSecret(op->getResult(0), solver)) {
return;
}
// ensure result is secret
Expand Down Expand Up @@ -121,7 +121,7 @@ void annotateLevel(Operation *top, DataFlowSolver *solver) {
if (op->getNumResults() == 0) {
return;
}
if (!ensureSecretness(op->getResult(0), solver)) {
if (!isSecret(op->getResult(0), solver)) {
return;
}
auto level = getLevel(op->getResult(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
}

// skip secret generic op; we decide inside generic op block
if (!isa<secret::GenericOp>(op) &&
ensureSecretness(op->getResults(), solver)) {
if (!isa<secret::GenericOp>(op) && isSecret(op->getResults(), solver)) {
auto decisionVar = model.AddBinaryVariable("InsertRelin_" + name);
decisionVariables.insert(std::make_pair(op, decisionVar));
}
Expand All @@ -116,7 +115,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
std::string varName = "Degree_" + name;
for (Value result : op->getResults()) {
// skip secret generic ops
if (isa<secret::GenericOp>(op) || !ensureSecretness(result, solver)) {
if (isa<secret::GenericOp>(op) || !isSecret(result, solver)) {
continue;
}

Expand All @@ -143,7 +142,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
for (BlockArgument arg : block.getArguments()) {
if (!ensureSecretness(arg, solver)) {
if (!isSecret(arg, solver)) {
continue;
}

Expand Down Expand Up @@ -225,7 +224,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
.Case<tensor_ext::RotateOp, secret::YieldOp>([&](auto op) {
for (Value operand : op->getOperands()) {
// skip non secret argument
if (!ensureSecretness(operand, solver)) {
if (!isSecret(operand, solver)) {
continue;
}
if (!keyBasisVars.contains(operand)) {
Expand All @@ -248,12 +247,11 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
llvm::TypeSwitch<Operation &>(*op)
.Case<arith::MulIOp, arith::MulFOp>([&](auto op) {
// if plain mul, skip
if (!ensureSecretness(op.getResult(), solver)) {
if (!isSecret(op.getResult(), solver)) {
return;
}
// ct-ct mul
if (ensureSecretness(op.getLhs(), solver) &&
ensureSecretness(op.getRhs(), solver)) {
if (isSecret(op.getLhs(), solver) && isSecret(op.getRhs(), solver)) {
auto lhsDegreeVar = keyBasisVars.at(op.getLhs());
auto rhsDegreeVar = keyBasisVars.at(op.getRhs());
auto resultBeforeRelinVar = beforeRelinVars.at(op.getResult());
Expand Down Expand Up @@ -302,7 +300,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
if (isa<secret::GenericOp>(op)) {
return;
}
if (!ensureSecretness(op.getResults(), solver)) {
if (!isSecret(op.getResults(), solver)) {
return;
}
SmallVector<OpOperand *, 4> secretOperands;
Expand Down Expand Up @@ -343,7 +341,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
if (isa<secret::GenericOp>(op)) {
return;
}
if (!ensureSecretness(op->getResults(), solver)) {
if (!isSecret(op->getResults(), solver)) {
return;
}

Expand Down
23 changes: 11 additions & 12 deletions lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace heir {

void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) {
auto operand = lattice->getAnchor();
bool isSecret = isa<secret::SecretType>(operand.getType());
bool secretness = isa<secret::SecretType>(operand.getType());

Operation *operation = nullptr;
// Get defining operation for operand
Expand All @@ -34,7 +34,7 @@ void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) {
if (auto genericOp = dyn_cast<secret::GenericOp>(*operation)) {
if (OpOperand *genericOperand =
genericOp.getOpOperandForBlockArgument(operand)) {
isSecret = isa<secret::SecretType>(genericOperand->get().getType());
secretness = isa<secret::SecretType>(genericOperand->get().getType());
}
}

Expand All @@ -48,24 +48,24 @@ void SecretnessAnalysis::setToEntryState(SecretnessLattice *lattice) {
blockArgs.begin();

// Check if it has secret type
isSecret = isa<secret::SecretType>(funcOp.getArgumentTypes()[index]);
secretness = isa<secret::SecretType>(funcOp.getArgumentTypes()[index]);

// check if it is annotated as {secret.secret}
auto attrs = funcOp.getArgAttrs();
if (attrs) {
auto arr = attrs->getValue();
if (auto dictattr = dyn_cast<DictionaryAttr>(arr[index])) {
for (auto attr : dictattr) {
isSecret =
isSecret ||
secretness =
secretness ||
attr.getName() == secret::SecretDialect::kArgSecretAttrName.str();
break;
}
}
}
}

propagateIfChanged(lattice, lattice->join(Secretness(isSecret)));
propagateIfChanged(lattice, lattice->join(Secretness(secretness)));
}

LogicalResult SecretnessAnalysis::visitOperation(
Expand Down Expand Up @@ -133,7 +133,7 @@ void annotateSecretness(Operation *top, DataFlowSolver *solver) {
});
}

bool ensureSecretness(Value value, DataFlowSolver *solver) {
bool isSecret(Value value, DataFlowSolver *solver) {
auto *lattice = solver->lookupState<SecretnessLattice>(value);
if (!lattice) {
return false;
Expand All @@ -144,20 +144,19 @@ bool ensureSecretness(Value value, DataFlowSolver *solver) {
return lattice->getValue().getSecretness();
}

bool ensureSecretness(ValueRange values, DataFlowSolver *solver) {
bool isSecret(ValueRange values, DataFlowSolver *solver) {
if (values.empty()) {
return false;
}
return std::all_of(values.begin(), values.end(), [&](Value value) {
return ensureSecretness(value, solver);
});
return std::all_of(values.begin(), values.end(),
[&](Value value) { return isSecret(value, solver); });
}

void getSecretOperands(Operation *op,
SmallVectorImpl<OpOperand *> &secretOperands,
DataFlowSolver *solver) {
for (auto &operand : op->getOpOperands()) {
if (ensureSecretness(operand.get(), solver)) {
if (isSecret(operand.get(), solver)) {
secretOperands.push_back(&operand);
}
}
Expand Down
10 changes: 5 additions & 5 deletions lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class SecretnessAnalysisDependent {
* @return true if the value is secret, false if the secretness of the value
* is unknown or false.
*/
bool ensureSecretness(Operation *op, Value value) {
bool isSecretInternal(Operation *op, Value value) {
// create dependency on SecretnessAnalysis
auto *lattice =
getChildAnalysis()->template getOrCreateFor<SecretnessLattice>(
Expand All @@ -157,7 +157,7 @@ class SecretnessAnalysisDependent {
void getSecretResults(Operation *op,
SmallVectorImpl<OpResult> &secretResults) {
for (const auto &result : op->getOpResults()) {
if (ensureSecretness(op, result)) {
if (isSecretInternal(op, result)) {
secretResults.push_back(result);
}
}
Expand All @@ -176,7 +176,7 @@ class SecretnessAnalysisDependent {
void getSecretOperands(Operation *op,
SmallVectorImpl<OpOperand *> &secretOperands) {
for (auto &operand : op->getOpOperands()) {
if (ensureSecretness(op, operand.get())) {
if (isSecretInternal(op, operand.get())) {
secretOperands.push_back(&operand);
}
}
Expand All @@ -188,9 +188,9 @@ void annotateSecretness(Operation *top, DataFlowSolver *solver);

// this method is used when DataFlowSolver has finished running the secretness
// analysis
bool ensureSecretness(Value value, DataFlowSolver *solver);
bool isSecret(Value value, DataFlowSolver *solver);

bool ensureSecretness(ValueRange values, DataFlowSolver *solver);
bool isSecret(ValueRange values, DataFlowSolver *solver);

void getSecretOperands(Operation *op,
SmallVectorImpl<OpOperand *> &secretOperands,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,8 @@ struct ConvertLinalgMatmul : public OpRewritePattern<mlir::linalg::MatmulOp> {
// Determine if the left or right operand is secret to determine which
// matrix to diagonalize, or if both are secret or both are public, then
// return failure.
auto isSecret = [&](Value value) {
auto *operandLookup = solver->lookupState<SecretnessLattice>(value);
Secretness operandSecretness =
operandLookup ? operandLookup->getValue() : Secretness();
return (operandSecretness.isInitialized() &&
operandSecretness.getSecretness());
};

bool isLeftOperandSecret = isSecret(op.getInputs()[0]);
bool isRightOperandSecret = isSecret(op.getInputs()[1]);
bool isLeftOperandSecret = isSecret(op.getInputs()[0], solver);
bool isRightOperandSecret = isSecret(op.getInputs()[1], solver);

LLVM_DEBUG({
llvm::dbgs() << "Left operand is secret: " << isLeftOperandSecret << "\n"
Expand Down
17 changes: 7 additions & 10 deletions lib/Dialect/Secret/Transforms/DistributeGeneric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
// Ensure that the loop bound operands are also validated. If they are
// secret types, then return a failure - we cannot distribute through a
// loop with secret bounds.
auto isSecret = [&](OpFoldResult v) {
auto hasSecretType = [&](OpFoldResult v) {
if (auto value = dyn_cast<Value>(v)) {
if (auto *genericOperand =
genericOp.getOpOperandForBlockArgument(value)) {
Expand All @@ -183,9 +183,9 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
return false;
};
if ((loop.getLoopLowerBounds().has_value() &&
llvm::any_of(loop.getLoopLowerBounds().value(), isSecret)) ||
llvm::any_of(loop.getLoopLowerBounds().value(), hasSecretType)) ||
(loop.getLoopUpperBounds().has_value() &&
llvm::any_of(loop.getLoopUpperBounds().value(), isSecret))) {
llvm::any_of(loop.getLoopUpperBounds().value(), hasSecretType))) {
LLVM_DEBUG(genericOp.emitRemark()
<< "cannot distribute through a LoopLikeInterface with "
"secret bounds");
Expand All @@ -200,15 +200,12 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
DenseMap<Value, Value> newInitsToOperands;
for (auto [operand, blockArg] : llvm::zip(
clonedLoop.getInitsMutable(), clonedLoop.getRegionIterArgs())) {
auto yieldedIterValue = clonedLoop.getTiedLoopYieldedValue(blockArg);
auto *yieldedIterValue = clonedLoop.getTiedLoopYieldedValue(blockArg);
if (isa<SecretType>(operand.get().getType())) {
blockArg.setType(operand.get().getType());
} else if (solver
->lookupState<SecretnessLattice>(
loop.getYieldedValues()[yieldedIterValue
->getOperandNumber()])
->getValue()
.getSecretness() &&
} else if (isSecret(loop.getYieldedValues()[yieldedIterValue
->getOperandNumber()],
solver) &&
!isa<SecretType>(operand.get().getType())) {
// The initial value of an iter_arg yielded by the original loop must
// be promoted to a secret and added to the new generic's operands if
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,7 @@ struct ConvertTosaSigmoid : public OpRewritePattern<mlir::tosa::SigmoidOp> {

LogicalResult matchAndRewrite(mlir::tosa::SigmoidOp op,
PatternRewriter &rewriter) const override {
auto isSecret = [&](Value value) {
auto *operandLookup = solver->lookupState<SecretnessLattice>(value);
Secretness operandSecretness =
operandLookup ? operandLookup->getValue() : Secretness();
return (operandSecretness.isInitialized() &&
operandSecretness.getSecretness());
};

// Do not support lowering for non-secret operands
bool operandIsSecret = isSecret(op.getOperand());
bool operandIsSecret = isSecret(op.getOperand(), solver);
if (!operandIsSecret) {
return failure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,8 @@ struct SecretForToStaticForConversion : OpRewritePattern<scf::ForOp> {

LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
auto *lowerBoundSecretnessLattice =
solver->lookupState<SecretnessLattice>(forOp.getLowerBound());

auto *upperBoundSecretnessLattice =
solver->lookupState<SecretnessLattice>(forOp.getUpperBound());

if (!lowerBoundSecretnessLattice && !upperBoundSecretnessLattice)
return failure();

// Get secretness state of the lower and upper bounds
bool isLowerBoundSecret =
lowerBoundSecretnessLattice &&
lowerBoundSecretnessLattice->getValue().getSecretness();
bool isUpperBoundSecret =
upperBoundSecretnessLattice &&
upperBoundSecretnessLattice->getValue().getSecretness();
bool isLowerBoundSecret = isSecret(forOp.getLowerBound(), solver);
bool isUpperBoundSecret = isSecret(forOp.getUpperBound(), solver);

// If both bounds are non-secret constants, return
if (!isLowerBoundSecret && !isUpperBoundSecret) return failure();
Expand Down
10 changes: 3 additions & 7 deletions lib/Transforms/SecretInsertMgmt/SecretInsertMgmtCKKS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ bool isTensorInSlots(Operation *top, DataFlowSolver *solver, int slotNumber) {
LogicalResult result = walkAndValidateValues(
top,
[&](Value value) {
auto secretness =
solver->lookupState<SecretnessLattice>(value)->getValue();
if (secretness.isInitialized() && secretness.getSecretness()) {
auto secret = isSecret(value, solver);
if (secret) {
auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
if (tensorTy) {
// TODO(#913): Multidimensional tensors with a single non-unit
Expand All @@ -88,10 +87,7 @@ bool isTensorInSlots(Operation *top, DataFlowSolver *solver, int slotNumber) {
void annotateTensorExtractAsNotSlotExtract(Operation *top,
DataFlowSolver *solver) {
top->walk([&](tensor::ExtractOp extractOp) {
auto secretness =
solver->lookupState<SecretnessLattice>(extractOp.getOperand(0))
->getValue();
if (secretness.isInitialized() && secretness.getSecretness()) {
if (isSecret(extractOp.getOperand(0), solver)) {
extractOp->setAttr("slot_extract",
BoolAttr::get(extractOp.getContext(), false));
}
Expand Down
22 changes: 8 additions & 14 deletions lib/Transforms/SecretInsertMgmt/SecretInsertMgmtPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,15 @@ template <typename MulOp>
LogicalResult MultRelinearize<MulOp>::matchAndRewrite(
MulOp mulOp, PatternRewriter &rewriter) const {
Value result = mulOp.getResult();
auto secret = solver->lookupState<SecretnessLattice>(result)->getValue();
// if not secret, skip
if (!secret.isInitialized() || !secret.getSecretness()) {
bool secret = isSecret(result, solver);
if (!secret) {
return success();
}

// if mul const, skip
for (auto operand : mulOp.getOperands()) {
auto secretness =
solver->lookupState<SecretnessLattice>(operand)->getValue();
if (!secretness.isInitialized() || !secretness.getSecretness()) {
auto secret = isSecret(operand, solver);
if (!secret) {
return success();
}
}
Expand All @@ -58,8 +56,8 @@ LogicalResult ModReduceBefore<Op>::matchAndRewrite(
// guard against secret::YieldOp
if (op->getResults().size() > 0) {
for (auto result : op->getResults()) {
auto secret = solver->lookupState<SecretnessLattice>(result)->getValue();
if (!secret.isInitialized() || !secret.getSecretness()) {
bool secret = isSecret(result, solver);
if (!secret) {
return success();
}
}
Expand All @@ -82,12 +80,8 @@ LogicalResult ModReduceBefore<Op>::matchAndRewrite(
// use map in case we have same operands
DenseMap<Value, LevelState::LevelType> operandsInsertLevel;
for (auto operand : op.getOperands()) {
auto secretness =
solver->lookupState<SecretnessLattice>(operand)->getValue();
if (!secretness.isInitialized()) {
return failure();
}
if (!secretness.getSecretness()) {
bool secret = isSecret(operand, solver);
if (!secret) {
continue;
}
auto levelLattice = solver->lookupState<LevelLattice>(operand)->getValue();
Expand Down

0 comments on commit 8303e53

Please sign in to comment.