Skip to content

Commit 08b81a8

Browse files
committed
[ConstantFold] Support scalable constant splats in ConstantFoldCastInstruction
Previously only fixed vector splats were handled. This adds supports for scalable vectors too by allowing ConstantExpr splats. We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast. By allowing ConstantExprs this also allow fixed vector ConstantExprs to be folded, which causes the diffs in llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll and llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll. I can remove them from this PR if reviewers would prefer. Fixes #132922
1 parent 11b2bb1 commit 08b81a8

File tree

12 files changed

+21
-20
lines changed

12 files changed

+21
-20
lines changed

llvm/lib/IR/ConstantFold.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
160160
// If the cast operand is a constant vector, perform the cast by
161161
// operating on each element. In the cast of bitcasts, the element
162162
// count may be mismatched; don't attempt to handle that here.
163-
if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
164-
DestTy->isVectorTy() &&
165-
cast<FixedVectorType>(DestTy)->getNumElements() ==
166-
cast<FixedVectorType>(V->getType())->getNumElements()) {
163+
if ((isa<ConstantVector, ConstantDataVector, ConstantExpr>(V)) &&
164+
DestTy->isVectorTy() && V->getType()->isVectorTy() &&
165+
cast<VectorType>(DestTy)->getElementCount() ==
166+
cast<VectorType>(V->getType())->getElementCount()) {
167167
VectorType *DestVecTy = cast<VectorType>(DestTy);
168168
Type *DstEltTy = DestVecTy->getElementType();
169169
// Fast path for splatted constants.
@@ -174,6 +174,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
174174
return ConstantVector::getSplat(
175175
cast<VectorType>(DestTy)->getElementCount(), Res);
176176
}
177+
if (isa<ScalableVectorType>(DestTy))
178+
return nullptr;
177179
SmallVector<Constant *, 16> res;
178180
Type *Ty = IntegerType::get(V->getContext(), 32);
179181
for (unsigned i = 0,

llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
@g = global [21 x i32] zeroinitializer
88
define i32 @test1(i32 %a) {
99
; CHECK-LABEL: @test1(
10-
; CHECK-NEXT: [[T:%.*]] = sub i32 [[A:%.*]], extractelement (<4 x i32> ptrtoint (<4 x ptr> getelementptr inbounds ([21 x i32], ptr @g, <4 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 17>) to <4 x i32>), i32 3)
10+
; CHECK-NEXT: [[T:%.*]] = sub i32 [[A:%.*]], ptrtoint (ptr getelementptr inbounds ([21 x i32], ptr @g, i32 0, i32 17) to i32)
1111
; CHECK-NEXT: ret i32 [[T]]
1212
;
1313
%t = sub i32 %a, extractelement (<4 x i32> ptrtoint (<4 x ptr> getelementptr inbounds ([21 x i32], ptr @g, <4 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 17>) to <4 x i32>), i32 3)

llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
1717
define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
1818
; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
1919
; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
20-
; CHECK-NEXT: [[TMP1:%.*]] = fptrunc <vscale x 2 x double> splat (double -1.000000e+00) to <vscale x 2 x float>
21-
; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], [[TMP1]]
20+
; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], splat (float -1.000000e+00)
2221
; CHECK-NEXT: ret <vscale x 2 x float> [[TMP3]]
2322
;
2423
%2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>

llvm/test/Transforms/InstCombine/scalable-trunc.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ entry:
2222

2323
define <vscale x 1 x i8> @constant_splat_trunc() {
2424
; CHECK-LABEL: @constant_splat_trunc(
25-
; CHECK-NEXT: ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
25+
; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
2626
;
2727
%1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
2828
ret <vscale x 1 x i8> %1
2929
}
3030

3131
define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
3232
; CHECK-LABEL: @constant_splat_trunc_constantexpr(
33-
; CHECK-NEXT: ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
33+
; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
3434
;
3535
ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
3636
}

llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
define <2 x i16> @test1() {
99
; CHECK-LABEL: @test1(
1010
; CHECK-NEXT: entry:
11-
; CHECK-NEXT: ret <2 x i16> ptrtoint (<2 x ptr> getelementptr inbounds ([10 x i32], ptr null, <2 x i64> zeroinitializer, <2 x i64> <i64 5, i64 7>) to <2 x i16>)
11+
; CHECK-NEXT: ret <2 x i16> <i16 ptrtoint (ptr getelementptr inbounds ([10 x i32], ptr null, i64 0, i64 5) to i16), i16 ptrtoint (ptr getelementptr inbounds ([10 x i32], ptr null, i64 0, i64 7) to i16)>
1212
;
1313
entry:
1414
%gep = getelementptr inbounds [10 x i32], ptr null, i16 0, <2 x i16> <i16 5, i16 7>
@@ -23,7 +23,7 @@ entry:
2323
define <2 x i16> @test2() {
2424
; CHECK-LABEL: @test2(
2525
; CHECK-NEXT: entry:
26-
; CHECK-NEXT: ret <2 x i16> ptrtoint (<2 x ptr> getelementptr (i32, ptr null, <2 x i64> <i64 5, i64 7>) to <2 x i16>)
26+
; CHECK-NEXT: ret <2 x i16> <i16 ptrtoint (ptr getelementptr (i32, ptr null, i64 5) to i16), i16 ptrtoint (ptr getelementptr (i32, ptr null, i64 7) to i16)>
2727
;
2828
entry:
2929
%gep = getelementptr i32, ptr null, <2 x i16> <i16 5, i16 7>

llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ define <vscale x 4 x i32> @shufflevector() {
208208

209209
define <vscale x 4 x float> @bitcast() {
210210
; CHECK-LABEL: @bitcast(
211-
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
211+
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
212212
;
213213
%i1 = insertelement <vscale x 4 x i32> poison, i32 1, i32 0
214214
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer

llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ define <vscale x 4 x i32> @shufflevector() {
208208

209209
define <vscale x 4 x float> @bitcast() {
210210
; CHECK-LABEL: @bitcast(
211-
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
211+
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
212212
;
213213
%i1 = insertelement <vscale x 4 x i32> undef, i32 1, i32 0
214214
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer

llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {
140140

141141
define <vscale x 4 x float> @bitcast() {
142142
; CHECK-LABEL: @bitcast(
143-
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
143+
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
144144
;
145145
%i1 = insertelement <vscale x 4 x i32> poison, i32 1, i32 0
146146
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer

llvm/test/Transforms/InstSimplify/vscale.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {
152152

153153
define <vscale x 4 x float> @bitcast() {
154154
; CHECK-LABEL: @bitcast(
155-
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
155+
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
156156
;
157157
%i1 = insertelement <vscale x 4 x i32> undef, i32 1, i32 0
158158
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer

llvm/test/Transforms/LoopVectorize/AArch64/induction-costs-sve.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ define void @iv_casts(ptr %dst, ptr %src, i32 %x, i64 %N) #0 {
5151
; DEFAULT-NEXT: [[TMP31:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD4]] to <vscale x 8 x i16>
5252
; DEFAULT-NEXT: [[TMP32:%.*]] = or <vscale x 8 x i16> [[TMP28]], [[TMP30]]
5353
; DEFAULT-NEXT: [[TMP33:%.*]] = or <vscale x 8 x i16> [[TMP29]], [[TMP31]]
54-
; DEFAULT-NEXT: [[TMP34:%.*]] = lshr <vscale x 8 x i16> [[TMP32]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
55-
; DEFAULT-NEXT: [[TMP35:%.*]] = lshr <vscale x 8 x i16> [[TMP33]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
54+
; DEFAULT-NEXT: [[TMP34:%.*]] = lshr <vscale x 8 x i16> [[TMP32]], splat (i16 1)
55+
; DEFAULT-NEXT: [[TMP35:%.*]] = lshr <vscale x 8 x i16> [[TMP33]], splat (i16 1)
5656
; DEFAULT-NEXT: [[TMP36:%.*]] = trunc <vscale x 8 x i16> [[TMP34]] to <vscale x 8 x i8>
5757
; DEFAULT-NEXT: [[TMP37:%.*]] = trunc <vscale x 8 x i16> [[TMP35]] to <vscale x 8 x i8>
5858
; DEFAULT-NEXT: [[TMP38:%.*]] = getelementptr i8, ptr [[DST]], i64 [[INDEX]]
@@ -131,7 +131,7 @@ define void @iv_casts(ptr %dst, ptr %src, i32 %x, i64 %N) #0 {
131131
; PRED-NEXT: [[TMP22:%.*]] = mul <vscale x 16 x i16> [[TMP17]], [[TMP16]]
132132
; PRED-NEXT: [[TMP24:%.*]] = zext <vscale x 16 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 16 x i16>
133133
; PRED-NEXT: [[TMP20:%.*]] = or <vscale x 16 x i16> [[TMP22]], [[TMP24]]
134-
; PRED-NEXT: [[TMP21:%.*]] = lshr <vscale x 16 x i16> [[TMP20]], trunc (<vscale x 16 x i32> splat (i32 1) to <vscale x 16 x i16>)
134+
; PRED-NEXT: [[TMP21:%.*]] = lshr <vscale x 16 x i16> [[TMP20]], splat (i16 1)
135135
; PRED-NEXT: [[TMP23:%.*]] = trunc <vscale x 16 x i16> [[TMP21]] to <vscale x 16 x i8>
136136
; PRED-NEXT: [[TMP26:%.*]] = getelementptr i8, ptr [[DST]], i64 [[INDEX]]
137137
; PRED-NEXT: [[TMP27:%.*]] = getelementptr i8, ptr [[TMP26]], i32 0

llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-evl-crash.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ define void @truncate_to_minimal_bitwidths_widen_cast_recipe(ptr %src) {
3030
; CHECK-NEXT: [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i8> @llvm.vp.load.nxv8i8.p0(ptr align 1 [[TMP6]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP7]])
3131
; CHECK-NEXT: [[TMP8:%.*]] = zext <vscale x 8 x i8> [[VP_OP_LOAD]] to <vscale x 8 x i16>
3232
; CHECK-NEXT: [[TMP12:%.*]] = mul <vscale x 8 x i16> zeroinitializer, [[TMP8]]
33-
; CHECK-NEXT: [[TMP13:%.*]] = lshr <vscale x 8 x i16> [[TMP12]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
33+
; CHECK-NEXT: [[TMP13:%.*]] = lshr <vscale x 8 x i16> [[TMP12]], splat (i16 1)
3434
; CHECK-NEXT: [[TMP14:%.*]] = trunc <vscale x 8 x i16> [[TMP13]] to <vscale x 8 x i8>
3535
; CHECK-NEXT: call void @llvm.vp.scatter.nxv8i8.nxv8p0(<vscale x 8 x i8> [[TMP14]], <vscale x 8 x ptr> align 1 zeroinitializer, <vscale x 8 x i1> splat (i1 true), i32 [[TMP7]])
3636
; CHECK-NEXT: [[TMP9:%.*]] = zext i32 [[TMP7]] to i64

llvm/test/Transforms/VectorCombine/pr88796.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
define i32 @test() {
55
; CHECK-LABEL: define i32 @test() {
66
; CHECK-NEXT: entry:
7-
; CHECK-NEXT: [[TMP0:%.*]] = tail call i16 @llvm.vector.reduce.and.nxv8i16(<vscale x 8 x i16> trunc (<vscale x 8 x i32> splat (i32 268435456) to <vscale x 8 x i16>))
7+
; CHECK-NEXT: [[TMP0:%.*]] = tail call i16 @llvm.vector.reduce.and.nxv8i16(<vscale x 8 x i16> zeroinitializer)
88
; CHECK-NEXT: ret i32 0
99
;
1010
entry:

0 commit comments

Comments
 (0)