Skip to content

Commit

Permalink
snapshot partial progress
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Jan 31, 2025
1 parent 61d3085 commit 95bcbd1
Showing 1 changed file with 60 additions and 42 deletions.
102 changes: 60 additions & 42 deletions lib/Transforms/LayoutPropagation/LayoutPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ using tensor_ext::LayoutAttr;
#define GEN_PASS_DEF_LAYOUTPROPAGATION
#include "lib/Transforms/LayoutPropagation/LayoutPropagation.h.inc"

struct CompatibilityResult {
bool compatible;
std::optional<InFlightDiagnostic> diag;
};

struct LayoutPropagation : impl::LayoutPropagationBase<LayoutPropagation> {
using LayoutPropagationBase::LayoutPropagationBase;

Expand All @@ -60,14 +65,11 @@ struct LayoutPropagation : impl::LayoutPropagationBase<LayoutPropagation> {
// Return true if the operand layouts are compatible for the operation, and
// false if not. Include an InFlightDiagnostic if an operand is encountered
// that requires a layout, but none has been set.
std::pair<bool, std::optional<InFlightDiagnostic>>
hasCompatibleArgumentLayouts(Operation *op);
CompatibilityResult hasCompatibleArgumentLayouts(Operation *op);

// Op-specific compatibility functions
std::pair<bool, std::optional<InFlightDiagnostic>>
hasCompatibleArgumentLayouts(ReduceOp op);
std::pair<bool, std::optional<InFlightDiagnostic>>
hasCompatibleArgumentLayouts(VecmatOp op);
CompatibilityResult hasCompatibleArgumentLayouts(ReduceOp op);
CompatibilityResult hasCompatibleArgumentLayouts(VecmatOp op);

// Insert conversion ops to rectify incompatible operand layouts
void rectifyIncompatibleOperandLayouts(Operation *op);
Expand Down Expand Up @@ -169,49 +171,65 @@ LogicalResult LayoutPropagation::visitOperation(Operation *op) {
.Default([&](Operation *op) { return success(); });
}

std::pair<bool, std::optional<InFlightDiagnostic>>
LayoutPropagation::hasCompatibleArgumentLayouts(Operation *op) {
// FIXME: type switch on special case ops
if (isa<func::FuncOp, GenericOp, YieldOp>(op)) {
return {true, std::nullopt};
}
CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts(
Operation *op) {
return TypeSwitch<Operation *, CompatibilityResult>(op)
// Trivially true ops
.Case<func::FuncOp, GenericOp, YieldOp>(
[&](auto op) { return CompatibilityResult{true, std::nullopt}; })
// Ops with special rules
.Case<ReduceOp, VecmatOp>([&](auto op) { return visitOperation(op); })
// By default, assume operands must all have the same layout.
.Default([&](Operation *op) {
std::optional<AffineMap> firstFoundLayout;

for (auto &operand : op->getOpOperands()) {
if (isa<RankedTensorType>(operand.get().getType())) {
if (!assignedLayouts.contains(operand.get())) {
// If the operand has no layout, we can't propagate layout
// information to the result.
return CompatibilityResult{
false, op->emitError("operand has no assigned layout")};
}
AffineMap layout = assignedLayouts.at(operand.get());

if (!firstFoundLayout.has_value()) firstFoundLayout = layout;
if (layout != firstFoundLayout.value()) {
return CompatibilityResult{false, std::nullopt};
}
}
}

if (isa<VecmatOp>(op)) {
// Currently only support secret vectors and plaintext matrices.
auto vecmatOp = cast<linalg::ContractionOpInterface>(op);
Value vec = vecmatOp.lhs();
Value mat = vecmatOp.rhs();
if (isSecret(mat, solver) || !isSecret(vec, solver)) {
return {false,
op->emitError("Only secret vectors and plaintext matrices are "
"supported for linalg.vecmat")};
}
return CompatibilityResult{true, std::nullopt};
});
}

if (!assignedLayouts.contains(vec)) {
return {false, op->emitError("vector operand has no assigned layout")};
}
return {true, std::nullopt};
CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts(
ReduceOp op) {
// The arguments of a ReduceOp are the tensor(s) to reduce and the
// initializer values for the reduction.
for (const auto &[input, init] : llvm::zip(op.getInputs(), op.getInits())) {
// FIXME: what is compatible here?
}

// By default, assume operands must all have the same layout.
std::optional<AffineMap> firstFoundLayout;

for (auto &operand : op->getOpOperands()) {
if (isa<RankedTensorType>(operand.get().getType())) {
if (!assignedLayouts.contains(operand.get())) {
// If the operand has no layout, we can't propagate layout
// information to the result.
return {false, op->emitError("operand has no assigned layout")};
}
AffineMap layout = assignedLayouts.at(operand.get());
return {true, std::nullopt};
}

if (!firstFoundLayout.has_value()) firstFoundLayout = layout;
if (layout != firstFoundLayout.value()) {
return {false, std::nullopt};
}
}
CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts(
VecmatOp op) {
// Currently only support secret vectors and plaintext matrices.
auto vecmatOp = cast<linalg::ContractionOpInterface>(op);
Value vec = vecmatOp.lhs();
Value mat = vecmatOp.rhs();
if (isSecret(mat, solver) || !isSecret(vec, solver)) {
return {false,
op->emitError("Only secret vectors and plaintext matrices are "
"supported for linalg.vecmat")};
}

if (!assignedLayouts.contains(vec)) {
return {false, op->emitError("vector operand has no assigned layout")};
}
return {true, std::nullopt};
}

Expand Down

0 comments on commit 95bcbd1

Please sign in to comment.