[SelectionDAG][RISCV] Use VP_STORE to widen MSTORE in type legalization when possible. #140991

Status: Open · wants to merge 1 commit into base: main
Conversation

topperc
Collaborator

@topperc topperc commented May 22, 2025

Widening the mask and padding with zeros doesn't work for scalable vectors. Using VL produces less code for fixed vectors.

Similar was recently done for MLOAD.
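
Conceptually, the new path widens a masked store into a vector-predicated (VP) store whose explicit vector length (EVL) is the original element count, so the mask never needs to be padded with zeros. A hypothetical IR-level sketch of the `v7f32` case from the tests (the actual transform rewrites SelectionDAG nodes during type legalization, not IR; `%wval`/`%wmask` are illustrative names for the widened operands, with the high mask lanes left undef):

```llvm
; Before legalization: a masked store of the illegal type <7 x float>.
call void @llvm.masked.store.v7f32.p0(<7 x float> %val, ptr %a, i32 8, <7 x i1> %mask)

; After widening to the legal <8 x float>, expressed as a VP store with EVL = 7.
; Lane 7 is never written because the EVL bounds the store, so no mask padding
; (the old vmv.s.x/vmand.mm sequence) is required.
call void @llvm.vp.store.v8f32.p0(<8 x float> %wval, ptr %a, <8 x i1> %wmask, i32 7)
```

On RISC-V this lets the EVL fold directly into `vsetivli`, which is why the fixed-vector tests below shrink to a single `vsetivli zero, 7, ...` before the store.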

@topperc topperc requested review from preames and lukel97 May 22, 2025 02:43
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label May 22, 2025
@llvmbot
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-llvm-selectiondag

Author: Craig Topper (topperc)

Changes

Widening the mask and padding with zeros doesn't work for scalable vectors. Using VL produces less code for fixed vectors.

Similar was recently done for MLOAD.


Full diff: https://github.com/llvm/llvm-project/pull/140991.diff

4 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+27-11)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-fp.ll (+1-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll (+1-4)
  • (modified) llvm/test/CodeGen/RISCV/rvv/masked-store-int.ll (+43-61)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1fd171c24a6f0..c011a0a61d698 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -7316,32 +7316,48 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
   SDValue Mask = MST->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue StVal = MST->getValue();
+  EVT VT = StVal.getValueType();
   SDLoc dl(N);
 
+  EVT WideVT, WideMaskVT;
   if (OpNo == 1) {
     // Widen the value.
     StVal = GetWidenedVector(StVal);
 
+    WideVT = StVal.getValueType();
+    WideMaskVT =
+        EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                         WideVT.getVectorElementCount());
+  } else {
+    WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
+
+    EVT ValueVT = StVal.getValueType();
+    WideVT = EVT::getVectorVT(*DAG.getContext(), ValueVT.getVectorElementType(),
+                              WideMaskVT.getVectorElementCount());
+  }
+
+  if (TLI.isOperationLegalOrCustom(ISD::VP_STORE, WideVT) &&
+      TLI.isTypeLegal(WideMaskVT)) {
+    Mask = DAG.getInsertSubvector(dl, DAG.getUNDEF(WideMaskVT), Mask, 0);
+    SDValue EVL = DAG.getElementCount(dl, TLI.getVPExplicitVectorLengthTy(),
+                                      VT.getVectorElementCount());
+    return DAG.getStoreVP(MST->getChain(), dl, StVal, MST->getBasePtr(),
+                          MST->getOffset(), Mask, EVL, MST->getMemoryVT(),
+                          MST->getMemOperand(), MST->getAddressingMode());
+  }
+
+  if (OpNo == 1) {
     // The mask should be widened as well.
-    EVT WideVT = StVal.getValueType();
-    EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
-                                      MaskVT.getVectorElementType(),
-                                      WideVT.getVectorNumElements());
     Mask = ModifyToType(Mask, WideMaskVT, true);
   } else {
     // Widen the mask.
-    EVT WideMaskVT = TLI.getTypeToTransformTo(*DAG.getContext(), MaskVT);
     Mask = ModifyToType(Mask, WideMaskVT, true);
 
-    EVT ValueVT = StVal.getValueType();
-    EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
-                                  ValueVT.getVectorElementType(),
-                                  WideMaskVT.getVectorNumElements());
     StVal = ModifyToType(StVal, WideVT);
   }
 
-  assert(Mask.getValueType().getVectorNumElements() ==
-         StVal.getValueType().getVectorNumElements() &&
+  assert(Mask.getValueType().getVectorElementCount() ==
+             StVal.getValueType().getVectorElementCount() &&
          "Mask and data vectors should have the same number of elements");
   return DAG.getMaskedStore(MST->getChain(), dl, StVal, MST->getBasePtr(),
                             MST->getOffset(), Mask, MST->getMemoryVT(),
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-fp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-fp.ll
index f2ff4b284c293..5f4a039866d7e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-fp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-fp.ll
@@ -325,10 +325,7 @@ define void @masked_store_v128f16(<128 x half> %val, ptr %a, <128 x i1> %mask) {
 define void @masked_store_v7f32(<7 x float> %val, ptr %a, <7 x i1> %mask) {
 ; CHECK-LABEL: masked_store_v7f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a1, 127
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vmv.s.x v10, a1
-; CHECK-NEXT:    vmand.mm v0, v0, v10
+; CHECK-NEXT:    vsetivli zero, 7, e32, m2, ta, ma
 ; CHECK-NEXT:    vse32.v v8, (a0), v0.t
 ; CHECK-NEXT:    ret
   call void @llvm.masked.store.v7f32.p0(<7 x float> %val, ptr %a, i32 8, <7 x i1> %mask)
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll
index a6bbaab97c8bd..67e19e046f331 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll
@@ -333,10 +333,7 @@ define void @masked_store_v256i8(<256 x i8> %val, ptr %a, <256 x i1> %mask) {
 define void @masked_store_v7i8(<7 x i8> %val, ptr %a, <7 x i1> %mask) {
 ; CHECK-LABEL: masked_store_v7i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a1, 127
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vmv.s.x v9, a1
-; CHECK-NEXT:    vmand.mm v0, v0, v9
+; CHECK-NEXT:    vsetivli zero, 7, e8, mf2, ta, ma
 ; CHECK-NEXT:    vse8.v v8, (a0), v0.t
 ; CHECK-NEXT:    ret
   call void @llvm.masked.store.v7i8.p0(<7 x i8> %val, ptr %a, i32 8, <7 x i1> %mask)
diff --git a/llvm/test/CodeGen/RISCV/rvv/masked-store-int.ll b/llvm/test/CodeGen/RISCV/rvv/masked-store-int.ll
index 32414feab722a..92893a7dd463a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/masked-store-int.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/masked-store-int.ll
@@ -1,51 +1,66 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,V
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,V
+; RUN: llc -mtriple=riscv32 -mattr=+zve32x,+zvl128b -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,ZVE32
+; RUN: llc -mtriple=riscv64 -mattr=+zve32x,+zvl128b -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,ZVE32
 
 define void @masked_store_nxv1i8(<vscale x 1 x i8> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
-; CHECK-LABEL: masked_store_nxv1i8:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
-; CHECK-NEXT:    vse8.v v8, (a0), v0.t
-; CHECK-NEXT:    ret
+; V-LABEL: masked_store_nxv1i8:
+; V:       # %bb.0:
+; V-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
+; V-NEXT:    vse8.v v8, (a0), v0.t
+; V-NEXT:    ret
+;
+; ZVE32-LABEL: masked_store_nxv1i8:
+; ZVE32:       # %bb.0:
+; ZVE32-NEXT:    csrr a1, vlenb
+; ZVE32-NEXT:    srli a1, a1, 3
+; ZVE32-NEXT:    vsetvli zero, a1, e8, mf4, ta, ma
+; ZVE32-NEXT:    vse8.v v8, (a0), v0.t
+; ZVE32-NEXT:    ret
   call void @llvm.masked.store.v1i8.p0(<vscale x 1 x i8> %val, ptr %a, i32 1, <vscale x 1 x i1> %mask)
   ret void
 }
 declare void @llvm.masked.store.v1i8.p0(<vscale x 1 x i8>, ptr, i32, <vscale x 1 x i1>)
 
 define void @masked_store_nxv1i16(<vscale x 1 x i16> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
-; CHECK-LABEL: masked_store_nxv1i16:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
-; CHECK-NEXT:    vse16.v v8, (a0), v0.t
-; CHECK-NEXT:    ret
+; V-LABEL: masked_store_nxv1i16:
+; V:       # %bb.0:
+; V-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; V-NEXT:    vse16.v v8, (a0), v0.t
+; V-NEXT:    ret
+;
+; ZVE32-LABEL: masked_store_nxv1i16:
+; ZVE32:       # %bb.0:
+; ZVE32-NEXT:    csrr a1, vlenb
+; ZVE32-NEXT:    srli a1, a1, 3
+; ZVE32-NEXT:    vsetvli zero, a1, e16, mf2, ta, ma
+; ZVE32-NEXT:    vse16.v v8, (a0), v0.t
+; ZVE32-NEXT:    ret
   call void @llvm.masked.store.v1i16.p0(<vscale x 1 x i16> %val, ptr %a, i32 2, <vscale x 1 x i1> %mask)
   ret void
 }
 declare void @llvm.masked.store.v1i16.p0(<vscale x 1 x i16>, ptr, i32, <vscale x 1 x i1>)
 
 define void @masked_store_nxv1i32(<vscale x 1 x i32> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
-; CHECK-LABEL: masked_store_nxv1i32:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vse32.v v8, (a0), v0.t
-; CHECK-NEXT:    ret
+; V-LABEL: masked_store_nxv1i32:
+; V:       # %bb.0:
+; V-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; V-NEXT:    vse32.v v8, (a0), v0.t
+; V-NEXT:    ret
+;
+; ZVE32-LABEL: masked_store_nxv1i32:
+; ZVE32:       # %bb.0:
+; ZVE32-NEXT:    csrr a1, vlenb
+; ZVE32-NEXT:    srli a1, a1, 3
+; ZVE32-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; ZVE32-NEXT:    vse32.v v8, (a0), v0.t
+; ZVE32-NEXT:    ret
   call void @llvm.masked.store.v1i32.p0(<vscale x 1 x i32> %val, ptr %a, i32 4, <vscale x 1 x i1> %mask)
   ret void
 }
 declare void @llvm.masked.store.v1i32.p0(<vscale x 1 x i32>, ptr, i32, <vscale x 1 x i1>)
 
-define void @masked_store_nxv1i64(<vscale x 1 x i64> %val, ptr %a, <vscale x 1 x i1> %mask) nounwind {
-; CHECK-LABEL: masked_store_nxv1i64:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; CHECK-NEXT:    vse64.v v8, (a0), v0.t
-; CHECK-NEXT:    ret
-  call void @llvm.masked.store.v1i64.p0(<vscale x 1 x i64> %val, ptr %a, i32 8, <vscale x 1 x i1> %mask)
-  ret void
-}
-declare void @llvm.masked.store.v1i64.p0(<vscale x 1 x i64>, ptr, i32, <vscale x 1 x i1>)
-
 define void @masked_store_nxv2i8(<vscale x 2 x i8> %val, ptr %a, <vscale x 2 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv2i8:
 ; CHECK:       # %bb.0:
@@ -79,17 +94,6 @@ define void @masked_store_nxv2i32(<vscale x 2 x i32> %val, ptr %a, <vscale x 2 x
 }
 declare void @llvm.masked.store.v2i32.p0(<vscale x 2 x i32>, ptr, i32, <vscale x 2 x i1>)
 
-define void @masked_store_nxv2i64(<vscale x 2 x i64> %val, ptr %a, <vscale x 2 x i1> %mask) nounwind {
-; CHECK-LABEL: masked_store_nxv2i64:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a1, zero, e64, m2, ta, ma
-; CHECK-NEXT:    vse64.v v8, (a0), v0.t
-; CHECK-NEXT:    ret
-  call void @llvm.masked.store.v2i64.p0(<vscale x 2 x i64> %val, ptr %a, i32 8, <vscale x 2 x i1> %mask)
-  ret void
-}
-declare void @llvm.masked.store.v2i64.p0(<vscale x 2 x i64>, ptr, i32, <vscale x 2 x i1>)
-
 define void @masked_store_nxv4i8(<vscale x 4 x i8> %val, ptr %a, <vscale x 4 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv4i8:
 ; CHECK:       # %bb.0:
@@ -123,17 +127,6 @@ define void @masked_store_nxv4i32(<vscale x 4 x i32> %val, ptr %a, <vscale x 4 x
 }
 declare void @llvm.masked.store.v4i32.p0(<vscale x 4 x i32>, ptr, i32, <vscale x 4 x i1>)
 
-define void @masked_store_nxv4i64(<vscale x 4 x i64> %val, ptr %a, <vscale x 4 x i1> %mask) nounwind {
-; CHECK-LABEL: masked_store_nxv4i64:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a1, zero, e64, m4, ta, ma
-; CHECK-NEXT:    vse64.v v8, (a0), v0.t
-; CHECK-NEXT:    ret
-  call void @llvm.masked.store.v4i64.p0(<vscale x 4 x i64> %val, ptr %a, i32 8, <vscale x 4 x i1> %mask)
-  ret void
-}
-declare void @llvm.masked.store.v4i64.p0(<vscale x 4 x i64>, ptr, i32, <vscale x 4 x i1>)
-
 define void @masked_store_nxv8i8(<vscale x 8 x i8> %val, ptr %a, <vscale x 8 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv8i8:
 ; CHECK:       # %bb.0:
@@ -167,17 +160,6 @@ define void @masked_store_nxv8i32(<vscale x 8 x i32> %val, ptr %a, <vscale x 8 x
 }
 declare void @llvm.masked.store.v8i32.p0(<vscale x 8 x i32>, ptr, i32, <vscale x 8 x i1>)
 
-define void @masked_store_nxv8i64(<vscale x 8 x i64> %val, ptr %a, <vscale x 8 x i1> %mask) nounwind {
-; CHECK-LABEL: masked_store_nxv8i64:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
-; CHECK-NEXT:    vse64.v v8, (a0), v0.t
-; CHECK-NEXT:    ret
-  call void @llvm.masked.store.v8i64.p0(<vscale x 8 x i64> %val, ptr %a, i32 8, <vscale x 8 x i1> %mask)
-  ret void
-}
-declare void @llvm.masked.store.v8i64.p0(<vscale x 8 x i64>, ptr, i32, <vscale x 8 x i1>)
-
 define void @masked_store_nxv16i8(<vscale x 16 x i8> %val, ptr %a, <vscale x 16 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_store_nxv16i8:
 ; CHECK:       # %bb.0:

Contributor

@wangpc-pp wangpc-pp left a comment

LGTM.

-  assert(Mask.getValueType().getVectorNumElements() ==
-         StVal.getValueType().getVectorNumElements() &&
+  assert(Mask.getValueType().getVectorElementCount() ==
+             StVal.getValueType().getVectorElementCount() &&
Contributor
Weird indentation here but clang-format doesn't complain?

Labels
llvm:SelectionDAG SelectionDAGISel as well
4 participants