[mlir][SCF][GPU] Add DeviceMaskingAttrInterface #146943

Merged: 2 commits into main on Jul 7, 2025

Conversation

nicolasvasilache
Contributor

@nicolasvasilache nicolasvasilache commented Jul 3, 2025

This revision adds DeviceMaskingAttrInterface and extends DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface and DeviceMaskingAttrInterface.

Support is added to GPUTransformOps to take advantage of this information and lower to block/warpgroup/warp/thread specialization when mapped to linear ids.

The revision also connects to scf::ForallOp and uses the new attribute to implement warp specialization.
The implementation is in the form of a GPUMappingMaskAttr, which can be additionally passed to the scf.forall.mapping attribute to specify a mask on compute resources that should be active.

In the first implementation the masking is a bitfield that specifies for each processing unit whether it is active or not.
In the future, we may want to implement this as a symbol to refer to dynamically defined values.
Extending op semantics with an operand is deemed too intrusive at this time.
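The bitfield semantics described above can be modeled with a small scalar sketch (hypothetical helper names, not part of the patch): physical id `i` is active iff bit `i` of the mask is set, and its logical id is the popcount of the mask bits strictly below `i`.

```cpp
#include <cassert>
#include <cstdint>

// Portable 64-bit popcount (std::popcount requires C++20).
static unsigned popcount64(uint64_t x) {
  unsigned n = 0;
  for (; x; x >>= 1)
    n += x & 1;
  return n;
}

// Physical id `i` is active iff bit `i` of the mask is set.
bool isActiveId(uint64_t mask, unsigned physicalId) {
  return (mask >> physicalId) & 1;
}

// The logical id of an active physical id is the number of active ids
// below it, i.e. popcount(mask & ((1 << physicalId) - 1)).
unsigned logicalId(uint64_t mask, unsigned physicalId) {
  uint64_t filter = (uint64_t(1) << physicalId) - 1;
  return popcount64(mask & filter);
}
```

For mask `0b110100` (bits 2, 4 and 5 set), physical ids 2, 4 and 5 map to logical ids 0, 1 and 2, matching the worked example in the GPUDialect.cpp hunk of the patch.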

@llvmbot
Member

llvmbot commented Jul 3, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir-gpu

Author: Nicolas Vasilache (nicolasvasilache)

Changes

…lOp and use it to implement warp specialization.

This revision adds DeviceMaskingAttrInterface and extends DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface and DeviceMaskingAttrInterface.

The first implementation is in the form of a GPUMappingMaskAttr, which can be additionally passed to the scf.forall.mapping attribute to specify a mask on compute resources that should be active.

Support is added to GPUTransformOps to take advantage of this information and lower to block/warpgroup/warp/thread specialization when mapped to linear ids.


Patch is 35.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146943.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td (+18)
  • (modified) mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h (+10-5)
  • (modified) mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td (+44-1)
  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+12)
  • (modified) mlir/lib/Dialect/GPU/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+45)
  • (modified) mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp (+39-19)
  • (modified) mlir/lib/Dialect/GPU/TransformOps/Utils.cpp (+73-27)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+37-6)
  • (modified) mlir/test/Dialect/GPU/transform-gpu-failing.mlir (+61)
  • (modified) mlir/test/Dialect/GPU/transform-gpu.mlir (+81)
  • (modified) mlir/test/Dialect/SCF/invalid.mlir (+18)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
index 63f228ca3157f..e8540027e7b77 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
@@ -252,6 +252,24 @@ def GPULaneMappingAttr
   }];
 }
 
+def GPUMappingMaskAttr : GPU_Attr<"GPUMappingMask", "mask", [
+  DeclareAttrInterfaceMethods<DeviceMaskingAttrInterface> ] >  {
+  let parameters = (ins "uint64_t":$mask);
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    Attribute describing how to filter the processing units that a
+    region is mapped to.
+
+    In the first implementation the masking is a bitfield that specifies for
+    each processing unit whether it is active or not.
+
+    In the future, we may want to implement this as a symbol to refer to
+    dynamically defined values.
+
+    Extending op semantics with an operand is deemed too intrusive at this time.
+  }];
+}
+
 def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
   DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
   let parameters = (ins
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index de512ded59fec..0a11b8f8d3fa0 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -78,7 +78,8 @@ struct GpuIdBuilder {
 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuBlockIdBuilder : public GpuIdBuilder {
-  GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
+  GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
+                    DeviceMaskingAttrInterface mask = nullptr);
 };
 
 /// Builder for warpgroup ids used to map scf.forall to reindexed warpgroups.
@@ -88,7 +89,8 @@ struct GpuBlockIdBuilder : public GpuIdBuilder {
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
   GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
-                        bool useLinearMapping = false);
+                        bool useLinearMapping = false,
+                        DeviceMaskingAttrInterface mask = nullptr);
   int64_t warpSize = 32;
   /// In the future this may be configured by the transformation.
   static constexpr int64_t kNumWarpsPerGroup = 4;
@@ -101,7 +103,8 @@ struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuWarpIdBuilder : public GpuIdBuilder {
   GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
-                   bool useLinearMapping = false);
+                   bool useLinearMapping = false,
+                   DeviceMaskingAttrInterface mask = nullptr);
   int64_t warpSize = 32;
 };
 
@@ -111,7 +114,8 @@ struct GpuWarpIdBuilder : public GpuIdBuilder {
 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuThreadIdBuilder : public GpuIdBuilder {
-  GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
+  GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
+                     DeviceMaskingAttrInterface mask = nullptr);
 };
 
 /// Builder for lane id.
@@ -119,7 +123,8 @@ struct GpuThreadIdBuilder : public GpuIdBuilder {
 /// as 1D sizes for predicate generation.
 /// This `useLinearMapping` case is the only supported case.
 struct GpuLaneIdBuilder : public GpuIdBuilder {
-  GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused);
+  GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused,
+                   DeviceMaskingAttrInterface mask = nullptr);
   int64_t warpSize = 32;
 };
 
diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
index 96db2a40cf58e..353aaf05bee0c 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -60,8 +60,51 @@ def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
   ];
 }
 
+def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Attribute interface describing how to filter the processing units that a
+    region is mapped to.
+
+    A popcount can be applied to determine the logical linear index that a
+    physical processing unit is responsible for.
+  }];
+
+ let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the logical active id for a given physical id.
+        Expects a physicalLinearMappingId of I64Type.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getLogicalLinearMappingId",
+      /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dynamic condition determining whether a given physical id is
+        active under the mask.
+        Expects a physicalLinearMappingId of I64Type.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getIsActiveIdPredicate",
+      /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the maximal number of physical ids supported.
+        This is to account for temporary implementation limitations (e.g. i64)
+        and fail gracefully with actionable error messages.
+      }],
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getMaxNumPhysicalIds",
+      /*args=*/(ins)
+    >,
+  ];
+}
+
 def DeviceMappingArrayAttr :
-  TypedArrayAttrBase<DeviceMappingAttrInterface,
+  TypedArrayAttrBase<AnyAttrOf<[DeviceMappingAttrInterface, DeviceMaskingAttrInterface]>,
   "Device Mapping array attribute"> { }
 
 #endif // MLIR_DEVICEMAPPINGINTERFACE
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8b14cef7437d4..2d15544e871b3 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -611,6 +611,18 @@ def ForallOp : SCF_Op<"forall", [
     /// Returns operations within scf.forall.in_parallel whose destination
     /// operand is the block argument `bbArg`.
     SmallVector<Operation*> getCombiningOps(BlockArgument bbArg);
+
+    /// Returns the subset of DeviceMappingArrayAttrs of type
+    /// DeviceMappingAttrInterface.
+    SmallVector<DeviceMappingAttrInterface> getDeviceMappingAttrs();
+
+    /// Returns the at most one DeviceMaskingAttrInterface in the mapping.
+    /// If more than one DeviceMaskingAttrInterface is specified, returns
+    /// failure. If no mapping is present, returns nullptr.
+    FailureOr<DeviceMaskingAttrInterface> getDeviceMaskingAttr();
+
+    /// Returns true if the mapping specified for this forall op is linear.
+    bool usesLinearMapping();
   }];
 }
 
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index c8c53374d676b..4862d1f722785 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRGPUDialect
   MLIRFunctionInterfaces
   MLIRInferIntRangeInterface
   MLIRIR
+  MLIRMathDialect
   MLIRMemRefDialect
   MLIRSideEffectInterfaces
   MLIRSupport
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 56631f1aac084..9d74c23c24cc8 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -120,6 +121,50 @@ int64_t GPULaneMappingAttr::getRelativeIndex() const {
              : getMappingId();
 }
 
+int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
+
+///                 8       4       0
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+///
+/// Active physical (resp. logical) is  2 (0), 4 (1) and 5 (2).
+/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
+///
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+/// Example filter: 0 0 0 0 1 1 1 1 1
+/// Intersection  : 0 0 0 0 1 0 1 0 0
+/// PopCnt        : 2
+Value GPUMappingMaskAttr::getLogicalLinearMappingId(
+    OpBuilder &b, Value physicalLinearMappingId) const {
+  Location loc = physicalLinearMappingId.getLoc();
+  Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
+  Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+  Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
+  filter = b.create<arith::SubIOp>(loc, filter, one);
+  Value filteredId = b.create<arith::AndIOp>(loc, mask, filter);
+  return b.create<math::CtPopOp>(loc, filteredId);
+}
+
+///                 8       4       0
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+///
+/// Active physical (resp. logical) is  2 (0), 4 (1) and 5 (2).
+/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
+///
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+/// Example filter: 0 0 0 1 0 0 0 0 0
+/// Intersection  : 0 0 0 1 0 0 0 0 0
+/// Cmp           : 1
+Value GPUMappingMaskAttr::getIsActiveIdPredicate(
+    OpBuilder &b, Value physicalLinearMappingId) const {
+  Location loc = physicalLinearMappingId.getLoc();
+  Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
+  Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+  Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
+  Value filtered = b.create<arith::AndIOp>(loc, mask, filter);
+  Value zero = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
+  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, filtered, zero);
+}
+
 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getAddressSpace());
 }
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 63f87d9b5877e..a8eaa20928b7f 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -351,16 +351,25 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
     seen.insert(map);
   }
 
-  auto isLinear = [](Attribute a) {
-    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
+  auto isLinear = [](DeviceMappingAttrInterface attr) {
+    return attr.isLinearMapping();
   };
-  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
-      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
+  if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
+      !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
     return definiteFailureHelper(
         transformOp, forallOp,
         "cannot mix linear and non-linear mapping modes");
   }
 
+  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+      forallOp.getDeviceMaskingAttr();
+  if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
+      !forallOp.usesLinearMapping()) {
+    return definiteFailureHelper(
+        transformOp, forallOp,
+        "device masking is only available in linear mapping mode");
+  }
+
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -381,9 +390,7 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
   if (forallOp.getNumResults() > 0)
     return definiteFailureHelper(transformOp, forallOp,
                                  "only bufferized scf.forall can be mapped");
-  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
-                              forallOp.getMapping()->getValue().front())
-                              .isLinearMapping();
+  bool useLinearMapping = forallOp.usesLinearMapping();
   // TODO: This would be more natural with support for Optional<EnumParameter>
   // in GPUDeviceMappingAttr.
   int64_t maxNumMappingsSupported =
@@ -682,12 +689,17 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
 
   // The BlockIdBuilder adapts to whatever is thrown at it.
   bool useLinearMapping = false;
-  if (topLevelForallOp.getMapping()) {
-    auto mappingAttr = cast<DeviceMappingAttrInterface>(
-        topLevelForallOp.getMapping()->getValue().front());
-    useLinearMapping = mappingAttr.isLinearMapping();
-  }
-  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);
+  if (topLevelForallOp.getMapping())
+    useLinearMapping = topLevelForallOp.usesLinearMapping();
+
+  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+      topLevelForallOp.getDeviceMaskingAttr();
+  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
+  assert((!*maybeMaskingAttr || useLinearMapping) &&
+         "masking requires linear mapping");
+
+  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
+                                      *maybeMaskingAttr);
 
   diag = mlir::transform::gpu::mapForallToBlocksImpl(
       rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
@@ -744,8 +756,7 @@ static DiagnosedSilenceableFailure
 getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                    int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
-  auto mappingAttr = cast<DeviceMappingAttrInterface>(
-      forallOp.getMapping()->getValue().front());
+  auto mappingAttr = forallOp.getDeviceMappingAttrs().front();
   bool useLinearMapping = mappingAttr.isLinearMapping();
 
   // Sanity checks that may result in runtime verification errors.
@@ -768,21 +779,30 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
   if (!diag.succeeded())
     return diag;
 
+  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+      forallOp.getDeviceMaskingAttr();
+  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
+  assert((!*maybeMaskingAttr || useLinearMapping) &&
+         "masking requires linear mapping");
+
   // Start mapping.
   MLIRContext *ctx = forallOp.getContext();
   gpuIdBuilder =
       TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
           .Case([&](GPUWarpgroupMappingAttr) {
-            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
+            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
+                                         *maybeMaskingAttr);
           })
           .Case([&](GPUWarpMappingAttr) {
-            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
+            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
+                                    *maybeMaskingAttr);
           })
           .Case([&](GPUThreadMappingAttr) {
-            return GpuThreadIdBuilder(ctx, useLinearMapping);
+            return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr);
           })
           .Case([&](GPULaneMappingAttr) {
-            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping);
+            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
+                                    *maybeMaskingAttr);
           })
           .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
             llvm_unreachable("unknown mapping attribute");
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 795d643c05912..d1969dbc82997 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -44,7 +44,7 @@ using namespace mlir::transform::gpu;
 #define DEBUG_TYPE "gpu-transforms"
 
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
 
 /// Build predicates to filter execution by only the activeIds. Along each
@@ -120,10 +120,23 @@ static Value buildLinearId(RewriterBase &rewriter, Location loc,
 /// it in the basis of `forallMappingSizes`. The linear id builder returns an
 /// n-D vector of ids for indexing and 1-D size + id for predicate generation.
 template <typename ThreadOrBlockIdOp>
-static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
-  auto res = [multiplicity](RewriterBase &rewriter, Location loc,
-                            ArrayRef<int64_t> forallMappingSizes,
-                            ArrayRef<int64_t> originalBasis) {
+static GpuIdBuilderFnType
+commonLinearIdBuilderFn(int64_t multiplicity = 1,
+                        DeviceMaskingAttrInterface mask = nullptr) {
+  auto res = [multiplicity, mask](RewriterBase &rewriter, Location loc,
+                                  ArrayRef<int64_t> forallMappingSizes,
+                                  ArrayRef<int64_t> originalBasis) {
+    // 0. Early-exit mask case.
+    if (mask) {
+      if (computeProduct(originalBasis) >
+          mask.getMaxNumPhysicalIds() * multiplicity) {
+        return IdBuilderResult{
+            /*errorMsg=*/std::string(
+                "mask representation too short to capture all physical ids: ") +
+            std::to_string(mask.getMaxNumPhysicalIds())};
+      }
+    }
+
     // 1. Compute linearId.
     SmallVector<OpFoldResult> originalBasisOfr =
         getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
@@ -132,9 +145,25 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
 
     // 2. Compute scaledLinearId.
     AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+    OpFoldResult scaledLinearIdOfr = affine::makeComposedFoldedAffineApply(
         rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
 
+    // 2.b. Adjust with mask if needed.
+    Value scaledLinearIdI64;
+    Value scaledLinearId =
+        getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearIdOfr);
+    if (mask) {
+      scaledLinearId =
+          getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearIdOfr);
+      scaledLinearIdI64 = rewriter.create<arith::IndexCastUIOp>(
+          loc, rewriter.getI64Type(), scaledLinearId);
+      Value logicalLinearIdI64 =
+          mask.getLogicalLinearMappingId(rewriter, scaledLinearIdI64);
+      scaledLinearId = rewriter.create<arith::IndexCastUIOp>(
+          loc, rewriter.getIndexType(), logicalLinearIdI64);
+      LDBG("------adjusting linearId with mask: " << scaledLinearId);
+    }
+
     // 3. Compute remapped indices.
     SmallVector<Value> ids;
     // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
@@ -148,15 +177,23 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
           affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
     }
 
-    // 4. Handle predicates using physicalLinearId.
     std::string errorMsg;
     SmallVector<Value> predicateOps;
-    FailureOr<SmallVector<Value>> maybePredicateOps =
-        buildPredicates(rewriter, loc, physicalLinearId,
-                        computeProduct(forallMappingSizes) * multiplicity,
-                        computeProduct(originalBasis), errorMsg);
-    if (succeeded(maybePredicateOps))
-      predicateOps = *maybePredicateOps;
+    // 4. If mask present, it takes precedence to determine predication.
+    if (mask) {
+      Value isActiveIdPredicate =
+          mask.getIsActiveIdPredicate(rewriter, scaledLinearIdI64);
+      LDBG("------adjusting predicate with mask: " << isActiveIdPredicate);
+      predicateOps.push_back(isActiveIdPredicate);
+    } else {
+      // 4.b. Otherwise, handle predicates using physicalLinearId.
+      FailureOr<SmallVector<Value>> maybePredicateOps =
+          buildPredicates(rewriter, loc, physicalLinearId,
+                          computeProduct(forallMappingSizes) * multiplicity,
+  ...
[truncated]

@nicolasvasilache nicolasvasilache force-pushed the users/nico/map-to-lanes-3 branch 4 times, most recently from ad456bb to 85aa5f8 Compare July 3, 2025 20:54
Member

@ftynse ftynse left a comment


Please address comments, okay otherwise

@nicolasvasilache nicolasvasilache force-pushed the users/nico/map-to-lanes-2 branch 2 times, most recently from ed41e82 to 76ecaf8 Compare July 7, 2025 13:24
Base automatically changed from users/nico/map-to-lanes-2 to main July 7, 2025 13:42
…lOp and use it to implement warp specialization.

This revision adds DeviceMaskingAttrInterface and extends
DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface
and DeviceMaskingAttrInterface.

The first implementation is in the form of a GPUMappingMaskAttr, which
can be additionally passed to the scf.forall.mapping attribute to
specify a mask on compute resources that should be active.

Support is added to GPUTransformOps to take advantage of this
information and lower to block/warpgroup/warp/thread specialization when
mapped to linear ids.

Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
@joker-eph
Collaborator

Nit: Please fix the title (annoyingly) wrapped by GitHub here.

@nicolasvasilache nicolasvasilache changed the title [mlir][SCF][GPU] Add DeviceMaskingAttrInterface support to scf::Foral… [mlir][SCF][GPU] Add DeviceMaskingAttrInterface Jul 7, 2025
@nicolasvasilache nicolasvasilache force-pushed the users/nico/map-to-lanes-3 branch from 0cb1bc9 to 8a0fe9c Compare July 7, 2025 15:53

github-actions bot commented Jul 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@nicolasvasilache nicolasvasilache force-pushed the users/nico/map-to-lanes-3 branch from 8a0fe9c to fd39e3f Compare July 7, 2025 16:03
@nicolasvasilache nicolasvasilache merged commit 2b28d10 into main Jul 7, 2025
7 of 9 checks passed
@nicolasvasilache nicolasvasilache deleted the users/nico/map-to-lanes-3 branch July 7, 2025 16:06