Commit b88eef9
[DSE] Add predicated vector length store support for masked store elimination (#134175)
In isMaskedStoreOverwrite we process two stores where one fully overwrites the other. This patch adds support for predicated vector length stores (llvm.vp.store) so that DSE will also eliminate this variant of masked stores. This is the follow-up installment mentioned in: https://reviews.llvm.org/D132700
1 parent 78b21dd commit b88eef9
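As a rough sketch of the pattern this enables (mirroring @test1 in the new test below; illustrative IR only, not output captured from the pass):

  ; Before DSE: the first vp.store is fully overwritten by the second
  ; (same value type, pointer, mask, and explicit vector length), so it is dead.
  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)

  ; After DSE only the later store remains:
  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)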

File tree

2 files changed: +122 additions, -14 deletions

llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp

Lines changed: 29 additions & 14 deletions
@@ -248,28 +248,43 @@ static OverwriteResult isMaskedStoreOverwrite(const Instruction *KillingI,
     return OW_Unknown;
   if (KillingII->getIntrinsicID() != DeadII->getIntrinsicID())
     return OW_Unknown;
-  if (KillingII->getIntrinsicID() == Intrinsic::masked_store) {
-    // Type size.
-    VectorType *KillingTy =
-        cast<VectorType>(KillingII->getArgOperand(0)->getType());
-    VectorType *DeadTy = cast<VectorType>(DeadII->getArgOperand(0)->getType());
-    if (KillingTy->getScalarSizeInBits() != DeadTy->getScalarSizeInBits())
+
+  switch (KillingII->getIntrinsicID()) {
+  case Intrinsic::masked_store:
+  case Intrinsic::vp_store: {
+    const DataLayout &DL = KillingII->getDataLayout();
+    auto *KillingTy = KillingII->getArgOperand(0)->getType();
+    auto *DeadTy = DeadII->getArgOperand(0)->getType();
+    if (DL.getTypeSizeInBits(KillingTy) != DL.getTypeSizeInBits(DeadTy))
       return OW_Unknown;
     // Element count.
-    if (KillingTy->getElementCount() != DeadTy->getElementCount())
+    if (cast<VectorType>(KillingTy)->getElementCount() !=
+        cast<VectorType>(DeadTy)->getElementCount())
       return OW_Unknown;
     // Pointers.
-    Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts();
-    Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts();
+    Value *KillingPtr = KillingII->getArgOperand(1);
+    Value *DeadPtr = DeadII->getArgOperand(1);
     if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr))
       return OW_Unknown;
-    // Masks.
-    // TODO: check that KillingII's mask is a superset of the DeadII's mask.
-    if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
-      return OW_Unknown;
+    if (KillingII->getIntrinsicID() == Intrinsic::masked_store) {
+      // Masks.
+      // TODO: check that KillingII's mask is a superset of the DeadII's mask.
+      if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
+        return OW_Unknown;
+    } else if (KillingII->getIntrinsicID() == Intrinsic::vp_store) {
+      // Masks.
+      // TODO: check that KillingII's mask is a superset of the DeadII's mask.
+      if (KillingII->getArgOperand(2) != DeadII->getArgOperand(2))
+        return OW_Unknown;
+      // Lengths.
+      if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
+        return OW_Unknown;
+    }
     return OW_Complete;
   }
-  return OW_Unknown;
+  default:
+    return OW_Unknown;
+  }
 }
 
 /// Return 'OW_Complete' if a store to the 'KillingLoc' location completely
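For reference, the argument-index checks above follow the intrinsic operand layouts documented in the LLVM LangRef (shown here as declarations; not part of this diff):

  ; llvm.masked.store: operand 0 = value, 1 = pointer, 2 = alignment, 3 = mask
  declare void @llvm.masked.store.nxv8i32.p0(<vscale x 8 x i32>, ptr, i32 immarg, <vscale x 8 x i1>)

  ; llvm.vp.store: operand 0 = value, 1 = pointer, 2 = mask, 3 = explicit vector length (EVL)
  declare void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32>, ptr, <vscale x 8 x i1>, i32)

This is why the masked_store case compares getArgOperand(3) (the mask), while the vp_store case compares getArgOperand(2) (the mask) and getArgOperand(3) (the length).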
New test file: Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=dse -S < %s | FileCheck %s
+
+; Test predicated vector length masked stores for elimination
+
+define void @test1(ptr %a, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:    [[VP_OP:%.*]] = call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> [[V1:%.*]], <vscale x 8 x i32> [[V2:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[VP_OP]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  %vp.op = call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %vp.op, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  ret void
+}
+
+; False test for different vector lengths
+
+define void @test2(ptr %a, i32 %vl1, i32 %vl2, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL1:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL2:%.*]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl1)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl2)
+  ret void
+}
+
+; False test for different types
+
+define void @test3(ptr %a, i32 %vl1, i32 %vl2, <vscale x 4 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 4 x i1> splat (i1 true), i32 [[VL1:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL2:%.*]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> %v1, ptr nonnull %a, <vscale x 4 x i1> splat (i1 true), i32 %vl1)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl2)
+  ret void
+}
+
+; False test for different element count
+
+define void @test4(ptr %a, i32 %vl, <vscale x 4 x i64> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test4(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 4 x i1> splat (i1 true), i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64> %v1, ptr nonnull %a, <vscale x 4 x i1> splat (i1 true), i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  ret void
+}
+
+; False test for different masks
+
+define void @test5(ptr %a, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> %m1, <vscale x 8 x i1> %m2) {
+;
+; CHECK-LABEL: @test5(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> [[M1:%.*]], i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> [[M2:%.*]], i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> %m1, i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> %m2, i32 %vl)
+  ret void
+}
+
+; False test for different pointers
+
+define void @test6(ptr %a, ptr %b, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> %m1) {
+;
+; CHECK-LABEL: @test6(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> [[M1:%.*]], i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[B:%.*]], <vscale x 8 x i1> [[M1]], i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> %m1, i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %b, <vscale x 8 x i1> %m1, i32 %vl)
+  ret void
+}
+
+declare <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32>, <vscale x 8 x i32>, <vscale x 8 x i1>, i32)
+declare void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32>, ptr nocapture, <vscale x 8 x i1>, i32)
+declare void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32>, ptr nocapture, <vscale x 4 x i1>, i32)
+declare void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64>, ptr nocapture, <vscale x 4 x i1>, i32)
+
