Skip to content

Commit 41f9586

Browse files
committed
[ConstantFold] Support scalable constant splats in ConstantFoldCastInstruction
Stacked on #132960 to prevent a regression Previously only fixed vector splats were handled. This adds supports for scalable vectors too by allowing ConstantExpr splats. We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast. I believe this will also allow casts of fixed vector ConstantExprs to be folded but I couldn't come up with a test case for this, the ConstantExprs seem to be folded away before reaching InstCombine. Fixes #132922
1 parent 3758a21 commit 41f9586

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

llvm/lib/IR/ConstantFold.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
160160
// If the cast operand is a constant vector, perform the cast by
161161
// operating on each element. In the cast of bitcasts, the element
162162
// count may be mismatched; don't attempt to handle that here.
163-
if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
164-
DestTy->isVectorTy() &&
165-
cast<FixedVectorType>(DestTy)->getNumElements() ==
166-
cast<FixedVectorType>(V->getType())->getNumElements()) {
163+
if ((isa<ConstantVector, ConstantDataVector, ConstantExpr>(V)) &&
164+
DestTy->isVectorTy() && V->getType()->isVectorTy() &&
165+
cast<VectorType>(DestTy)->getElementCount() ==
166+
cast<VectorType>(V->getType())->getElementCount()) {
167167
VectorType *DestVecTy = cast<VectorType>(DestTy);
168168
Type *DstEltTy = DestVecTy->getElementType();
169169
// Fast path for splatted constants.
@@ -174,6 +174,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
174174
return ConstantVector::getSplat(
175175
cast<VectorType>(DestTy)->getElementCount(), Res);
176176
}
177+
if (isa<ScalableVectorType>(DestTy))
178+
return nullptr;
177179
SmallVector<Constant *, 16> res;
178180
Type *Ty = IntegerType::get(V->getContext(), 32);
179181
for (unsigned i = 0,

llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
1717
define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
1818
; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
1919
; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
20-
; CHECK-NEXT: [[TMP1:%.*]] = fptrunc <vscale x 2 x double> splat (double -1.000000e+00) to <vscale x 2 x float>
21-
; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], [[TMP1]]
20+
; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], splat (float -1.000000e+00)
2221
; CHECK-NEXT: ret <vscale x 2 x float> [[TMP3]]
2322
;
2423
%2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>

llvm/test/Transforms/InstCombine/scalable-trunc.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ entry:
2222

2323
define <vscale x 1 x i8> @constant_splat_trunc() {
2424
; CHECK-LABEL: @constant_splat_trunc(
25-
; CHECK-NEXT: ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
25+
; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
2626
;
2727
%1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
2828
ret <vscale x 1 x i8> %1
2929
}
3030

3131
define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
3232
; CHECK-LABEL: @constant_splat_trunc_constantexpr(
33-
; CHECK-NEXT: ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
33+
; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
3434
;
3535
ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
3636
}

0 commit comments

Comments
 (0)