Skip to content

[LV] Unify interleaved load handling for fixed and scalable VFs. nfc #146914

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Mel-Chen
Copy link
Contributor

@Mel-Chen Mel-Chen commented Jul 3, 2025

This patch modifies VPInterleaveRecipe::execute to handle both fixed and scalable VFs using a single loop.

@llvmbot
Copy link
Member

llvmbot commented Jul 3, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Mel Chen (Mel-Chen)

Changes

This patch modifies VPInterleaveRecipe::execute to handle both fixed and scalable VFs using a single loop.


Full diff: https://github.com/llvm/llvm-project/pull/146914.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+17-35)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 06511b61a67c3..2fc2447deb3fb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -3442,7 +3442,6 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
   VPValue *BlockInMask = getMask();
   VPValue *Addr = getAddr();
   Value *ResAddr = State.get(Addr, VPLane(0));
-  Value *PoisonVec = PoisonValue::get(VecTy);
 
   auto CreateGroupMask = [&BlockInMask, &State,
                           &InterleaveFactor](Value *MaskForGaps) -> Value * {
@@ -3481,6 +3480,7 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
     Instruction *NewLoad;
     if (BlockInMask || MaskForGaps) {
       Value *GroupMask = CreateGroupMask(MaskForGaps);
+      Value *PoisonVec = PoisonValue::get(VecTy);
       NewLoad = State.Builder.CreateMaskedLoad(VecTy, ResAddr,
                                                Group->getAlign(), GroupMask,
                                                PoisonVec, "wide.masked.vec");
@@ -3490,57 +3490,39 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
     Group->addMetadata(NewLoad);
 
     ArrayRef<VPValue *> VPDefs = definedValues();
-    const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
     if (VecTy->isScalableTy()) {
       // Scalable vectors cannot use arbitrary shufflevectors (only splats),
       // so must use intrinsics to deinterleave.
       assert(InterleaveFactor <= 8 &&
              "Unsupported deinterleave factor for scalable vectors");
-      Value *Deinterleave = State.Builder.CreateIntrinsic(
+      NewLoad = State.Builder.CreateIntrinsic(
           getDeinterleaveIntrinsicID(InterleaveFactor), NewLoad->getType(),
           NewLoad,
           /*FMFSource=*/nullptr, "strided.vec");
+    }
 
-      for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
-        Instruction *Member = Group->getMember(I);
-        Value *StridedVec = State.Builder.CreateExtractValue(Deinterleave, I);
-        if (!Member) {
-          // This value is not needed as it's not used
-          cast<Instruction>(StridedVec)->eraseFromParent();
-          continue;
-        }
-        // If this member has different type, cast the result type.
-        if (Member->getType() != ScalarTy) {
-          VectorType *OtherVTy = VectorType::get(Member->getType(), State.VF);
-          StridedVec =
-              createBitOrPointerCast(State.Builder, StridedVec, OtherVTy, DL);
-        }
-
-        if (Group->isReverse())
-          StridedVec = State.Builder.CreateVectorReverse(StridedVec, "reverse");
-
-        State.set(VPDefs[J], StridedVec);
-        ++J;
-      }
+    auto CreateStridedVector = [&InterleaveFactor, &State,
+                                &NewLoad](unsigned Index) -> Value * {
+      assert(Index < InterleaveFactor && "Illegal group index");
+      if (State.VF.isScalable())
+        return State.Builder.CreateExtractValue(NewLoad, Index);
 
-      return;
-    }
-    assert(!State.VF.isScalable() && "VF is assumed to be non scalable.");
+      // For fixed length VF, use shuffle to extract the sub-vectors from the
+      // wide load.
+      auto StrideMask =
+          createStrideMask(Index, InterleaveFactor, State.VF.getFixedValue());
+      return State.Builder.CreateShuffleVector(NewLoad, StrideMask,
+                                               "strided.vec");
+    };
 
-    // For each member in the group, shuffle out the appropriate data from the
-    // wide loads.
-    unsigned J = 0;
-    for (unsigned I = 0; I < InterleaveFactor; ++I) {
+    for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
       Instruction *Member = Group->getMember(I);
 
       // Skip the gaps in the group.
       if (!Member)
         continue;
 
-      auto StrideMask =
-          createStrideMask(I, InterleaveFactor, State.VF.getFixedValue());
-      Value *StridedVec =
-          State.Builder.CreateShuffleVector(NewLoad, StrideMask, "strided.vec");
+      Value *StridedVec = CreateStridedVector(I);
 
       // If this member has different type, cast the result type.
       if (Member->getType() != ScalarTy) {

@llvmbot
Copy link
Member

llvmbot commented Jul 3, 2025

@llvm/pr-subscribers-vectorizers

Author: Mel Chen (Mel-Chen)

Changes

This patch modifies VPInterleaveRecipe::execute to handle both fixed and scalable VFs using a single loop.


Full diff: https://github.com/llvm/llvm-project/pull/146914.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+17-35)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 06511b61a67c3..2fc2447deb3fb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -3442,7 +3442,6 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
   VPValue *BlockInMask = getMask();
   VPValue *Addr = getAddr();
   Value *ResAddr = State.get(Addr, VPLane(0));
-  Value *PoisonVec = PoisonValue::get(VecTy);
 
   auto CreateGroupMask = [&BlockInMask, &State,
                           &InterleaveFactor](Value *MaskForGaps) -> Value * {
@@ -3481,6 +3480,7 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
     Instruction *NewLoad;
     if (BlockInMask || MaskForGaps) {
       Value *GroupMask = CreateGroupMask(MaskForGaps);
+      Value *PoisonVec = PoisonValue::get(VecTy);
       NewLoad = State.Builder.CreateMaskedLoad(VecTy, ResAddr,
                                                Group->getAlign(), GroupMask,
                                                PoisonVec, "wide.masked.vec");
@@ -3490,57 +3490,39 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
     Group->addMetadata(NewLoad);
 
     ArrayRef<VPValue *> VPDefs = definedValues();
-    const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
     if (VecTy->isScalableTy()) {
       // Scalable vectors cannot use arbitrary shufflevectors (only splats),
       // so must use intrinsics to deinterleave.
       assert(InterleaveFactor <= 8 &&
              "Unsupported deinterleave factor for scalable vectors");
-      Value *Deinterleave = State.Builder.CreateIntrinsic(
+      NewLoad = State.Builder.CreateIntrinsic(
           getDeinterleaveIntrinsicID(InterleaveFactor), NewLoad->getType(),
           NewLoad,
           /*FMFSource=*/nullptr, "strided.vec");
+    }
 
-      for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
-        Instruction *Member = Group->getMember(I);
-        Value *StridedVec = State.Builder.CreateExtractValue(Deinterleave, I);
-        if (!Member) {
-          // This value is not needed as it's not used
-          cast<Instruction>(StridedVec)->eraseFromParent();
-          continue;
-        }
-        // If this member has different type, cast the result type.
-        if (Member->getType() != ScalarTy) {
-          VectorType *OtherVTy = VectorType::get(Member->getType(), State.VF);
-          StridedVec =
-              createBitOrPointerCast(State.Builder, StridedVec, OtherVTy, DL);
-        }
-
-        if (Group->isReverse())
-          StridedVec = State.Builder.CreateVectorReverse(StridedVec, "reverse");
-
-        State.set(VPDefs[J], StridedVec);
-        ++J;
-      }
+    auto CreateStridedVector = [&InterleaveFactor, &State,
+                                &NewLoad](unsigned Index) -> Value * {
+      assert(Index < InterleaveFactor && "Illegal group index");
+      if (State.VF.isScalable())
+        return State.Builder.CreateExtractValue(NewLoad, Index);
 
-      return;
-    }
-    assert(!State.VF.isScalable() && "VF is assumed to be non scalable.");
+      // For fixed length VF, use shuffle to extract the sub-vectors from the
+      // wide load.
+      auto StrideMask =
+          createStrideMask(Index, InterleaveFactor, State.VF.getFixedValue());
+      return State.Builder.CreateShuffleVector(NewLoad, StrideMask,
+                                               "strided.vec");
+    };
 
-    // For each member in the group, shuffle out the appropriate data from the
-    // wide loads.
-    unsigned J = 0;
-    for (unsigned I = 0; I < InterleaveFactor; ++I) {
+    for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
       Instruction *Member = Group->getMember(I);
 
       // Skip the gaps in the group.
       if (!Member)
         continue;
 
-      auto StrideMask =
-          createStrideMask(I, InterleaveFactor, State.VF.getFixedValue());
-      Value *StridedVec =
-          State.Builder.CreateShuffleVector(NewLoad, StrideMask, "strided.vec");
+      Value *StridedVec = CreateStridedVector(I);
 
       // If this member has different type, cast the result type.
       if (Member->getType() != ScalarTy) {

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants