Skip to content

Commit 79435de

Browse files
authored
[ConstantFold] Support scalable constant splats in ConstantFoldCastInstruction (#133207)
Previously only fixed vector splats were handled. This adds supports for scalable vectors too by allowing ConstantExpr splats. We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast. By allowing ConstantExprs this also allow fixed vector ConstantExprs to be folded, which causes the diffs in llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll and llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll. I can remove them from this PR if reviewers would prefer. Fixes #132922
1 parent 2080334 commit 79435de

File tree

15 files changed

+36
-21
lines changed

15 files changed

+36
-21
lines changed

llvm/lib/IR/ConstantFold.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,9 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
160160
// If the cast operand is a constant vector, perform the cast by
161161
// operating on each element. In the cast of bitcasts, the element
162162
// count may be mismatched; don't attempt to handle that here.
163-
if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
164-
DestTy->isVectorTy() &&
165-
cast<FixedVectorType>(DestTy)->getNumElements() ==
166-
cast<FixedVectorType>(V->getType())->getNumElements()) {
163+
if (DestTy->isVectorTy() && V->getType()->isVectorTy() &&
164+
cast<VectorType>(DestTy)->getElementCount() ==
165+
cast<VectorType>(V->getType())->getElementCount()) {
167166
VectorType *DestVecTy = cast<VectorType>(DestTy);
168167
Type *DstEltTy = DestVecTy->getElementType();
169168
// Fast path for splatted constants.
@@ -174,6 +173,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
174173
return ConstantVector::getSplat(
175174
cast<VectorType>(DestTy)->getElementCount(), Res);
176175
}
176+
if (isa<ScalableVectorType>(DestTy))
177+
return nullptr;
177178
SmallVector<Constant *, 16> res;
178179
Type *Ty = IntegerType::get(V->getContext(), 32);
179180
for (unsigned i = 0,

llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
@g = global [21 x i32] zeroinitializer
88
define i32 @test1(i32 %a) {
99
; CHECK-LABEL: @test1(
10-
; CHECK-NEXT: [[T:%.*]] = sub i32 [[A:%.*]], extractelement (<4 x i32> ptrtoint (<4 x ptr> getelementptr inbounds ([21 x i32], ptr @g, <4 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 17>) to <4 x i32>), i32 3)
10+
; CHECK-NEXT: [[T:%.*]] = sub i32 [[A:%.*]], ptrtoint (ptr getelementptr inbounds ([21 x i32], ptr @g, i32 0, i32 17) to i32)
1111
; CHECK-NEXT: ret i32 [[T]]
1212
;
1313
%t = sub i32 %a, extractelement (<4 x i32> ptrtoint (<4 x ptr> getelementptr inbounds ([21 x i32], ptr @g, <4 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 17>) to <4 x i32>), i32 3)

llvm/test/Transforms/InstCombine/addrspacecast.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ define ptr addrspace(4) @constant_fold_undef() #0 {
191191

192192
define <4 x ptr addrspace(4)> @constant_fold_null_vector() #0 {
193193
; CHECK-LABEL: @constant_fold_null_vector(
194-
; CHECK-NEXT: ret <4 x ptr addrspace(4)> addrspacecast (<4 x ptr addrspace(3)> zeroinitializer to <4 x ptr addrspace(4)>)
194+
; CHECK-NEXT: ret <4 x ptr addrspace(4)> <ptr addrspace(4) addrspacecast (ptr addrspace(3) null to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(3) null to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(3) null to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(3) null to ptr addrspace(4))>
195195
;
196196
%cast = addrspacecast <4 x ptr addrspace(3)> zeroinitializer to <4 x ptr addrspace(4)>
197197
ret <4 x ptr addrspace(4)> %cast

llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
1717
define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
1818
; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
1919
; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
20-
; CHECK-NEXT: [[TMP1:%.*]] = fptrunc <vscale x 2 x double> splat (double -1.000000e+00) to <vscale x 2 x float>
21-
; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], [[TMP1]]
20+
; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], splat (float -1.000000e+00)
2221
; CHECK-NEXT: ret <vscale x 2 x float> [[TMP3]]
2322
;
2423
%2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>

llvm/test/Transforms/InstCombine/scalable-trunc.ll

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,21 @@ entry:
2020
ret void
2121
}
2222

23+
define <vscale x 1 x i8> @constant_splat_trunc() {
24+
; CHECK-LABEL: @constant_splat_trunc(
25+
; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
26+
;
27+
%1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
28+
ret <vscale x 1 x i8> %1
29+
}
30+
31+
define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
32+
; CHECK-LABEL: @constant_splat_trunc_constantexpr(
33+
; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
34+
;
35+
ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
36+
}
37+
2338
declare void @llvm.aarch64.sve.st1.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, ptr)
2439
declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
2540
declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)

llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
define <2 x i16> @test1() {
99
; CHECK-LABEL: @test1(
1010
; CHECK-NEXT: entry:
11-
; CHECK-NEXT: ret <2 x i16> ptrtoint (<2 x ptr> getelementptr inbounds ([10 x i32], ptr null, <2 x i64> zeroinitializer, <2 x i64> <i64 5, i64 7>) to <2 x i16>)
11+
; CHECK-NEXT: ret <2 x i16> <i16 ptrtoint (ptr getelementptr inbounds ([10 x i32], ptr null, i64 0, i64 5) to i16), i16 ptrtoint (ptr getelementptr inbounds ([10 x i32], ptr null, i64 0, i64 7) to i16)>
1212
;
1313
entry:
1414
%gep = getelementptr inbounds [10 x i32], ptr null, i16 0, <2 x i16> <i16 5, i16 7>
@@ -23,7 +23,7 @@ entry:
2323
define <2 x i16> @test2() {
2424
; CHECK-LABEL: @test2(
2525
; CHECK-NEXT: entry:
26-
; CHECK-NEXT: ret <2 x i16> ptrtoint (<2 x ptr> getelementptr (i32, ptr null, <2 x i64> <i64 5, i64 7>) to <2 x i16>)
26+
; CHECK-NEXT: ret <2 x i16> <i16 ptrtoint (ptr getelementptr (i32, ptr null, i64 5) to i16), i16 ptrtoint (ptr getelementptr (i32, ptr null, i64 7) to i16)>
2727
;
2828
entry:
2929
%gep = getelementptr i32, ptr null, <2 x i16> <i16 5, i16 7>

llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ define <vscale x 4 x i32> @shufflevector() {
208208

209209
define <vscale x 4 x float> @bitcast() {
210210
; CHECK-LABEL: @bitcast(
211-
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
211+
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
212212
;
213213
%i1 = insertelement <vscale x 4 x i32> poison, i32 1, i32 0
214214
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer

llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ define <vscale x 4 x i32> @shufflevector() {
208208

209209
define <vscale x 4 x float> @bitcast() {
210210
; CHECK-LABEL: @bitcast(
211-
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
211+
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
212212
;
213213
%i1 = insertelement <vscale x 4 x i32> undef, i32 1, i32 0
214214
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer

llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {
140140

141141
define <vscale x 4 x float> @bitcast() {
142142
; CHECK-LABEL: @bitcast(
143-
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
143+
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
144144
;
145145
%i1 = insertelement <vscale x 4 x i32> poison, i32 1, i32 0
146146
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer

llvm/test/Transforms/InstSimplify/vscale.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {
152152

153153
define <vscale x 4 x float> @bitcast() {
154154
; CHECK-LABEL: @bitcast(
155-
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
155+
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
156156
;
157157
%i1 = insertelement <vscale x 4 x i32> undef, i32 1, i32 0
158158
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer

llvm/test/Transforms/LoopVectorize/AArch64/induction-costs-sve.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ define void @iv_casts(ptr %dst, ptr %src, i32 %x, i64 %N) #0 {
5151
; DEFAULT-NEXT: [[TMP31:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD4]] to <vscale x 8 x i16>
5252
; DEFAULT-NEXT: [[TMP32:%.*]] = or <vscale x 8 x i16> [[TMP28]], [[TMP30]]
5353
; DEFAULT-NEXT: [[TMP33:%.*]] = or <vscale x 8 x i16> [[TMP29]], [[TMP31]]
54-
; DEFAULT-NEXT: [[TMP34:%.*]] = lshr <vscale x 8 x i16> [[TMP32]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
55-
; DEFAULT-NEXT: [[TMP35:%.*]] = lshr <vscale x 8 x i16> [[TMP33]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
54+
; DEFAULT-NEXT: [[TMP34:%.*]] = lshr <vscale x 8 x i16> [[TMP32]], splat (i16 1)
55+
; DEFAULT-NEXT: [[TMP35:%.*]] = lshr <vscale x 8 x i16> [[TMP33]], splat (i16 1)
5656
; DEFAULT-NEXT: [[TMP36:%.*]] = trunc <vscale x 8 x i16> [[TMP34]] to <vscale x 8 x i8>
5757
; DEFAULT-NEXT: [[TMP37:%.*]] = trunc <vscale x 8 x i16> [[TMP35]] to <vscale x 8 x i8>
5858
; DEFAULT-NEXT: [[TMP38:%.*]] = getelementptr i8, ptr [[DST]], i64 [[INDEX]]
@@ -131,7 +131,7 @@ define void @iv_casts(ptr %dst, ptr %src, i32 %x, i64 %N) #0 {
131131
; PRED-NEXT: [[TMP22:%.*]] = mul <vscale x 16 x i16> [[TMP17]], [[TMP16]]
132132
; PRED-NEXT: [[TMP24:%.*]] = zext <vscale x 16 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 16 x i16>
133133
; PRED-NEXT: [[TMP20:%.*]] = or <vscale x 16 x i16> [[TMP22]], [[TMP24]]
134-
; PRED-NEXT: [[TMP21:%.*]] = lshr <vscale x 16 x i16> [[TMP20]], trunc (<vscale x 16 x i32> splat (i32 1) to <vscale x 16 x i16>)
134+
; PRED-NEXT: [[TMP21:%.*]] = lshr <vscale x 16 x i16> [[TMP20]], splat (i16 1)
135135
; PRED-NEXT: [[TMP23:%.*]] = trunc <vscale x 16 x i16> [[TMP21]] to <vscale x 16 x i8>
136136
; PRED-NEXT: [[TMP26:%.*]] = getelementptr i8, ptr [[DST]], i64 [[INDEX]]
137137
; PRED-NEXT: [[TMP27:%.*]] = getelementptr i8, ptr [[TMP26]], i32 0

llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-cost.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ define void @truncate_to_i1_used_by_branch(i8 %x, ptr %dst) #0 {
169169
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x i8> [[BROADCAST_SPLATINSERT]], <vscale x 4 x i8> poison, <vscale x 4 x i32> zeroinitializer
170170
; CHECK-NEXT: [[TMP6:%.*]] = trunc i32 [[N_VEC]] to i8
171171
; CHECK-NEXT: [[TMP7:%.*]] = trunc <vscale x 4 x i8> [[BROADCAST_SPLAT]] to <vscale x 4 x i1>
172-
; CHECK-NEXT: [[TMP8:%.*]] = or <vscale x 4 x i1> trunc (<vscale x 4 x i8> splat (i8 23) to <vscale x 4 x i1>), [[TMP7]]
172+
; CHECK-NEXT: [[TMP8:%.*]] = or <vscale x 4 x i1> splat (i1 true), [[TMP7]]
173173
; CHECK-NEXT: [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[DST]], i64 0
174174
; CHECK-NEXT: [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 4 x ptr> [[BROADCAST_SPLATINSERT1]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
175175
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]

llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-evl-crash.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ define void @truncate_to_minimal_bitwidths_widen_cast_recipe(ptr %src) {
3030
; CHECK-NEXT: [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i8> @llvm.vp.load.nxv8i8.p0(ptr align 1 [[TMP6]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP7]])
3131
; CHECK-NEXT: [[TMP8:%.*]] = zext <vscale x 8 x i8> [[VP_OP_LOAD]] to <vscale x 8 x i16>
3232
; CHECK-NEXT: [[TMP12:%.*]] = mul <vscale x 8 x i16> zeroinitializer, [[TMP8]]
33-
; CHECK-NEXT: [[TMP13:%.*]] = lshr <vscale x 8 x i16> [[TMP12]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
33+
; CHECK-NEXT: [[TMP13:%.*]] = lshr <vscale x 8 x i16> [[TMP12]], splat (i16 1)
3434
; CHECK-NEXT: [[TMP14:%.*]] = trunc <vscale x 8 x i16> [[TMP13]] to <vscale x 8 x i8>
3535
; CHECK-NEXT: call void @llvm.vp.scatter.nxv8i8.nxv8p0(<vscale x 8 x i8> [[TMP14]], <vscale x 8 x ptr> align 1 zeroinitializer, <vscale x 8 x i1> splat (i1 true), i32 [[TMP7]])
3636
; CHECK-NEXT: [[TMP9:%.*]] = zext i32 [[TMP7]] to i64

llvm/test/Transforms/MemCpyOpt/crash.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ define void @test2(i32 %cmd) nounwind {
8585

8686
define void @inttoptr_constexpr_crash(ptr %p) {
8787
; CHECK-LABEL: @inttoptr_constexpr_crash(
88-
; CHECK-NEXT: store <1 x ptr> inttoptr (<1 x i16> bitcast (<2 x i8> <i8 ptrtoint (ptr @g to i8), i8 ptrtoint (ptr @g to i8)> to <1 x i16>) to <1 x ptr>), ptr [[P:%.*]], align 1
88+
; CHECK-NEXT: store <1 x ptr> <ptr inttoptr (i16 extractelement (<1 x i16> bitcast (<2 x i8> <i8 ptrtoint (ptr @g to i8), i8 ptrtoint (ptr @g to i8)> to <1 x i16>), i32 0) to ptr)>, ptr [[P:%.*]], align 1
8989
; CHECK-NEXT: ret void
9090
;
9191
store <1 x ptr> inttoptr (<1 x i16> bitcast (<2 x i8> <i8 ptrtoint (ptr @g to i8), i8 ptrtoint (ptr @g to i8)> to <1 x i16>) to <1 x ptr>), ptr %p, align 1

llvm/test/Transforms/VectorCombine/pr88796.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
define i32 @test() {
55
; CHECK-LABEL: define i32 @test() {
66
; CHECK-NEXT: entry:
7-
; CHECK-NEXT: [[TMP0:%.*]] = tail call i16 @llvm.vector.reduce.and.nxv8i16(<vscale x 8 x i16> trunc (<vscale x 8 x i32> splat (i32 268435456) to <vscale x 8 x i16>))
7+
; CHECK-NEXT: [[TMP0:%.*]] = tail call i16 @llvm.vector.reduce.and.nxv8i16(<vscale x 8 x i16> zeroinitializer)
88
; CHECK-NEXT: ret i32 0
99
;
1010
entry:

0 commit comments

Comments
 (0)