diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 06511b61a67c3..2fc2447deb3fb 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -3442,7 +3442,6 @@ void VPInterleaveRecipe::execute(VPTransformState &State) { VPValue *BlockInMask = getMask(); VPValue *Addr = getAddr(); Value *ResAddr = State.get(Addr, VPLane(0)); - Value *PoisonVec = PoisonValue::get(VecTy); auto CreateGroupMask = [&BlockInMask, &State, &InterleaveFactor](Value *MaskForGaps) -> Value * { @@ -3481,6 +3480,7 @@ void VPInterleaveRecipe::execute(VPTransformState &State) { Instruction *NewLoad; if (BlockInMask || MaskForGaps) { Value *GroupMask = CreateGroupMask(MaskForGaps); + Value *PoisonVec = PoisonValue::get(VecTy); NewLoad = State.Builder.CreateMaskedLoad(VecTy, ResAddr, Group->getAlign(), GroupMask, PoisonVec, "wide.masked.vec"); @@ -3490,57 +3490,39 @@ void VPInterleaveRecipe::execute(VPTransformState &State) { Group->addMetadata(NewLoad); ArrayRef VPDefs = definedValues(); - const DataLayout &DL = State.CFG.PrevBB->getDataLayout(); if (VecTy->isScalableTy()) { // Scalable vectors cannot use arbitrary shufflevectors (only splats), // so must use intrinsics to deinterleave. assert(InterleaveFactor <= 8 && "Unsupported deinterleave factor for scalable vectors"); - Value *Deinterleave = State.Builder.CreateIntrinsic( + NewLoad = State.Builder.CreateIntrinsic( getDeinterleaveIntrinsicID(InterleaveFactor), NewLoad->getType(), NewLoad, /*FMFSource=*/nullptr, "strided.vec"); + } - for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) { - Instruction *Member = Group->getMember(I); - Value *StridedVec = State.Builder.CreateExtractValue(Deinterleave, I); - if (!Member) { - // This value is not needed as it's not used - cast(StridedVec)->eraseFromParent(); - continue; - } - // If this member has different type, cast the result type. - if (Member->getType() != ScalarTy) { - VectorType *OtherVTy = VectorType::get(Member->getType(), State.VF); - StridedVec = - createBitOrPointerCast(State.Builder, StridedVec, OtherVTy, DL); - } - - if (Group->isReverse()) - StridedVec = State.Builder.CreateVectorReverse(StridedVec, "reverse"); - - State.set(VPDefs[J], StridedVec); - ++J; - } + auto CreateStridedVector = [&InterleaveFactor, &State, + &NewLoad](unsigned Index) -> Value * { + assert(Index < InterleaveFactor && "Illegal group index"); + if (State.VF.isScalable()) + return State.Builder.CreateExtractValue(NewLoad, Index); - return; - } - assert(!State.VF.isScalable() && "VF is assumed to be non scalable."); + // For fixed length VF, use shuffle to extract the sub-vectors from the + // wide load. + auto StrideMask = + createStrideMask(Index, InterleaveFactor, State.VF.getFixedValue()); + return State.Builder.CreateShuffleVector(NewLoad, StrideMask, + "strided.vec"); + }; - // For each member in the group, shuffle out the appropriate data from the - // wide loads. - unsigned J = 0; - for (unsigned I = 0; I < InterleaveFactor; ++I) { + for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) { Instruction *Member = Group->getMember(I); // Skip the gaps in the group. if (!Member) continue; - auto StrideMask = - createStrideMask(I, InterleaveFactor, State.VF.getFixedValue()); - Value *StridedVec = - State.Builder.CreateShuffleVector(NewLoad, StrideMask, "strided.vec"); + Value *StridedVec = CreateStridedVector(I); // If this member has different type, cast the result type. if (Member->getType() != ScalarTy) {