diff --git a/lib/Transforms/LayoutPropagation/LayoutPropagation.cpp b/lib/Transforms/LayoutPropagation/LayoutPropagation.cpp index 63a2627a1..958c2fa52 100644 --- a/lib/Transforms/LayoutPropagation/LayoutPropagation.cpp +++ b/lib/Transforms/LayoutPropagation/LayoutPropagation.cpp @@ -36,6 +36,11 @@ using tensor_ext::LayoutAttr; #define GEN_PASS_DEF_LAYOUTPROPAGATION #include "lib/Transforms/LayoutPropagation/LayoutPropagation.h.inc" +struct CompatibilityResult { + bool compatible; + std::optional diag; +}; + struct LayoutPropagation : impl::LayoutPropagationBase { using LayoutPropagationBase::LayoutPropagationBase; @@ -60,14 +65,11 @@ struct LayoutPropagation : impl::LayoutPropagationBase { // Return true if the operand layouts are compatible for the operation, and // false if not. Include an InFlightDiagnostic if an operand is encountered // that requires a layout, but none has been set. - std::pair> - hasCompatibleArgumentLayouts(Operation *op); + CompatibilityResult hasCompatibleArgumentLayouts(Operation *op); // Op-specific compatibility functions - std::pair> - hasCompatibleArgumentLayouts(ReduceOp op); - std::pair> - hasCompatibleArgumentLayouts(VecmatOp op); + CompatibilityResult hasCompatibleArgumentLayouts(ReduceOp op); + CompatibilityResult hasCompatibleArgumentLayouts(VecmatOp op); // Insert conversion ops to rectify incompatible operand layouts void rectifyIncompatibleOperandLayouts(Operation *op); @@ -169,49 +171,65 @@ LogicalResult LayoutPropagation::visitOperation(Operation *op) { .Default([&](Operation *op) { return success(); }); } -std::pair> -LayoutPropagation::hasCompatibleArgumentLayouts(Operation *op) { - // FIXME: type switch on special case ops - if (isa(op)) { - return {true, std::nullopt}; - } +CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts( + Operation *op) { + return TypeSwitch(op) + // Trivially true ops + .Case( + [&](auto op) { return CompatibilityResult{true, std::nullopt}; }) + // Ops with special rules + .Case([&](auto op) { return visitOperation(op); }) + // By default, assume operands must all have the same layout. + .Default([&](Operation *op) { + std::optional firstFoundLayout; + + for (auto &operand : op->getOpOperands()) { + if (isa(operand.get().getType())) { + if (!assignedLayouts.contains(operand.get())) { + // If the operand has no layout, we can't propagate layout + // information to the result. + return CompatibilityResult{ + false, op->emitError("operand has no assigned layout")}; + } + AffineMap layout = assignedLayouts.at(operand.get()); + + if (!firstFoundLayout.has_value()) firstFoundLayout = layout; + if (layout != firstFoundLayout.value()) { + return CompatibilityResult{false, std::nullopt}; + } + } + } - if (isa(op)) { - // Currently only support secret vectors and plaintext matrices. - auto vecmatOp = cast(op); - Value vec = vecmatOp.lhs(); - Value mat = vecmatOp.rhs(); - if (isSecret(mat, solver) || !isSecret(vec, solver)) { - return {false, - op->emitError("Only secret vectors and plaintext matrices are " - "supported for linalg.vecmat")}; - } + return CompatibilityResult{true, std::nullopt}; + }); +} - if (!assignedLayouts.contains(vec)) { - return {false, op->emitError("vector operand has no assigned layout")}; - } - return {true, std::nullopt}; +CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts( + ReduceOp op) { + // The arguments of a ReduceOp are the tensor(s) to reduce and the + // initializer values for the reduction. + for (const auto &[input, init] : llvm::zip(op.getInputs(), op.getInits())) { + // FIXME: what is compatible here? } - // By default, assume operands must all have the same layout. - std::optional firstFoundLayout; - - for (auto &operand : op->getOpOperands()) { - if (isa(operand.get().getType())) { - if (!assignedLayouts.contains(operand.get())) { - // If the operand has no layout, we can't propagate layout - // information to the result. - return {false, op->emitError("operand has no assigned layout")}; - } - AffineMap layout = assignedLayouts.at(operand.get()); + return {true, std::nullopt}; +} - if (!firstFoundLayout.has_value()) firstFoundLayout = layout; - if (layout != firstFoundLayout.value()) { - return {false, std::nullopt}; - } - } +CompatibilityResult LayoutPropagation::hasCompatibleArgumentLayouts( + VecmatOp op) { + // Currently only support secret vectors and plaintext matrices. + auto vecmatOp = cast(op); + Value vec = vecmatOp.lhs(); + Value mat = vecmatOp.rhs(); + if (isSecret(mat, solver) || !isSecret(vec, solver)) { + return {false, + op->emitError("Only secret vectors and plaintext matrices are " + "supported for linalg.vecmat")}; } + if (!assignedLayouts.contains(vec)) { + return {false, op->emitError("vector operand has no assigned layout")}; + } return {true, std::nullopt}; }