Skip to content

[ConstantFold] Support scalable constant splats in ConstantFoldCastInstruction #133207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions llvm/lib/IR/ConstantFold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,9 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
// If the cast operand is a constant vector, perform the cast by
// operating on each element. In the cast of bitcasts, the element
// count may be mismatched; don't attempt to handle that here.
if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
DestTy->isVectorTy() &&
cast<FixedVectorType>(DestTy)->getNumElements() ==
cast<FixedVectorType>(V->getType())->getNumElements()) {
if (DestTy->isVectorTy() && V->getType()->isVectorTy() &&
cast<VectorType>(DestTy)->getElementCount() ==
cast<VectorType>(V->getType())->getElementCount()) {
VectorType *DestVecTy = cast<VectorType>(DestTy);
Type *DstEltTy = DestVecTy->getElementType();
// Fast path for splatted constants.
Expand All @@ -174,6 +173,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
return ConstantVector::getSplat(
cast<VectorType>(DestTy)->getElementCount(), Res);
}
if (isa<ScalableVectorType>(DestTy))
return nullptr;
SmallVector<Constant *, 16> res;
Type *Ty = IntegerType::get(V->getContext(), 32);
for (unsigned i = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@g = global [21 x i32] zeroinitializer
define i32 @test1(i32 %a) {
; CHECK-LABEL: @test1(
; CHECK-NEXT: [[T:%.*]] = sub i32 [[A:%.*]], extractelement (<4 x i32> ptrtoint (<4 x ptr> getelementptr inbounds ([21 x i32], ptr @g, <4 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 17>) to <4 x i32>), i32 3)
; CHECK-NEXT: [[T:%.*]] = sub i32 [[A:%.*]], ptrtoint (ptr getelementptr inbounds ([21 x i32], ptr @g, i32 0, i32 17) to i32)
; CHECK-NEXT: ret i32 [[T]]
;
%t = sub i32 %a, extractelement (<4 x i32> ptrtoint (<4 x ptr> getelementptr inbounds ([21 x i32], ptr @g, <4 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 17>) to <4 x i32>), i32 3)
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Transforms/InstCombine/addrspacecast.ll
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ define ptr addrspace(4) @constant_fold_undef() #0 {

define <4 x ptr addrspace(4)> @constant_fold_null_vector() #0 {
; CHECK-LABEL: @constant_fold_null_vector(
; CHECK-NEXT: ret <4 x ptr addrspace(4)> addrspacecast (<4 x ptr addrspace(3)> zeroinitializer to <4 x ptr addrspace(4)>)
; CHECK-NEXT: ret <4 x ptr addrspace(4)> <ptr addrspace(4) addrspacecast (ptr addrspace(3) null to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(3) null to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(3) null to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(3) null to ptr addrspace(4))>
;
%cast = addrspacecast <4 x ptr addrspace(3)> zeroinitializer to <4 x ptr addrspace(4)>
ret <4 x ptr addrspace(4)> %cast
Expand Down
3 changes: 1 addition & 2 deletions llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = fptrunc <vscale x 2 x double> splat (double -1.000000e+00) to <vscale x 2 x float>
; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], splat (float -1.000000e+00)
; CHECK-NEXT: ret <vscale x 2 x float> [[TMP3]]
;
%2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
Expand Down
15 changes: 15 additions & 0 deletions llvm/test/Transforms/InstCombine/scalable-trunc.ll
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ entry:
ret void
}

define <vscale x 1 x i8> @constant_splat_trunc() {
; CHECK-LABEL: @constant_splat_trunc(
; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
;
%1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
ret <vscale x 1 x i8> %1
}

define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
; CHECK-LABEL: @constant_splat_trunc_constantexpr(
; CHECK-NEXT: ret <vscale x 1 x i8> splat (i8 1)
;
ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
}

declare void @llvm.aarch64.sve.st1.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, ptr)
declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)
4 changes: 2 additions & 2 deletions llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
define <2 x i16> @test1() {
; CHECK-LABEL: @test1(
; CHECK-NEXT: entry:
; CHECK-NEXT: ret <2 x i16> ptrtoint (<2 x ptr> getelementptr inbounds ([10 x i32], ptr null, <2 x i64> zeroinitializer, <2 x i64> <i64 5, i64 7>) to <2 x i16>)
; CHECK-NEXT: ret <2 x i16> <i16 ptrtoint (ptr getelementptr inbounds ([10 x i32], ptr null, i64 0, i64 5) to i16), i16 ptrtoint (ptr getelementptr inbounds ([10 x i32], ptr null, i64 0, i64 7) to i16)>
;
entry:
%gep = getelementptr inbounds [10 x i32], ptr null, i16 0, <2 x i16> <i16 5, i16 7>
Expand All @@ -23,7 +23,7 @@ entry:
define <2 x i16> @test2() {
; CHECK-LABEL: @test2(
; CHECK-NEXT: entry:
; CHECK-NEXT: ret <2 x i16> ptrtoint (<2 x ptr> getelementptr (i32, ptr null, <2 x i64> <i64 5, i64 7>) to <2 x i16>)
; CHECK-NEXT: ret <2 x i16> <i16 ptrtoint (ptr getelementptr (i32, ptr null, i64 5) to i16), i16 ptrtoint (ptr getelementptr (i32, ptr null, i64 7) to i16)>
;
entry:
%gep = getelementptr i32, ptr null, <2 x i16> <i16 5, i16 7>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ define <vscale x 4 x i32> @shufflevector() {

define <vscale x 4 x float> @bitcast() {
; CHECK-LABEL: @bitcast(
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
;
%i1 = insertelement <vscale x 4 x i32> poison, i32 1, i32 0
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ define <vscale x 4 x i32> @shufflevector() {

define <vscale x 4 x float> @bitcast() {
; CHECK-LABEL: @bitcast(
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
;
%i1 = insertelement <vscale x 4 x i32> undef, i32 1, i32 0
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {

define <vscale x 4 x float> @bitcast() {
; CHECK-LABEL: @bitcast(
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
;
%i1 = insertelement <vscale x 4 x i32> poison, i32 1, i32 0
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Transforms/InstSimplify/vscale.ll
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {

define <vscale x 4 x float> @bitcast() {
; CHECK-LABEL: @bitcast(
; CHECK-NEXT: ret <vscale x 4 x float> bitcast (<vscale x 4 x i32> splat (i32 1) to <vscale x 4 x float>)
; CHECK-NEXT: ret <vscale x 4 x float> splat (float 0x36A0000000000000)
;
%i1 = insertelement <vscale x 4 x i32> undef, i32 1, i32 0
%i2 = shufflevector <vscale x 4 x i32> %i1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ define void @iv_casts(ptr %dst, ptr %src, i32 %x, i64 %N) #0 {
; DEFAULT-NEXT: [[TMP31:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD4]] to <vscale x 8 x i16>
; DEFAULT-NEXT: [[TMP32:%.*]] = or <vscale x 8 x i16> [[TMP28]], [[TMP30]]
; DEFAULT-NEXT: [[TMP33:%.*]] = or <vscale x 8 x i16> [[TMP29]], [[TMP31]]
; DEFAULT-NEXT: [[TMP34:%.*]] = lshr <vscale x 8 x i16> [[TMP32]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
; DEFAULT-NEXT: [[TMP35:%.*]] = lshr <vscale x 8 x i16> [[TMP33]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
; DEFAULT-NEXT: [[TMP34:%.*]] = lshr <vscale x 8 x i16> [[TMP32]], splat (i16 1)
; DEFAULT-NEXT: [[TMP35:%.*]] = lshr <vscale x 8 x i16> [[TMP33]], splat (i16 1)
; DEFAULT-NEXT: [[TMP36:%.*]] = trunc <vscale x 8 x i16> [[TMP34]] to <vscale x 8 x i8>
; DEFAULT-NEXT: [[TMP37:%.*]] = trunc <vscale x 8 x i16> [[TMP35]] to <vscale x 8 x i8>
; DEFAULT-NEXT: [[TMP38:%.*]] = getelementptr i8, ptr [[DST]], i64 [[INDEX]]
Expand Down Expand Up @@ -131,7 +131,7 @@ define void @iv_casts(ptr %dst, ptr %src, i32 %x, i64 %N) #0 {
; PRED-NEXT: [[TMP22:%.*]] = mul <vscale x 16 x i16> [[TMP17]], [[TMP16]]
; PRED-NEXT: [[TMP24:%.*]] = zext <vscale x 16 x i8> [[WIDE_MASKED_LOAD]] to <vscale x 16 x i16>
; PRED-NEXT: [[TMP20:%.*]] = or <vscale x 16 x i16> [[TMP22]], [[TMP24]]
; PRED-NEXT: [[TMP21:%.*]] = lshr <vscale x 16 x i16> [[TMP20]], trunc (<vscale x 16 x i32> splat (i32 1) to <vscale x 16 x i16>)
; PRED-NEXT: [[TMP21:%.*]] = lshr <vscale x 16 x i16> [[TMP20]], splat (i16 1)
; PRED-NEXT: [[TMP23:%.*]] = trunc <vscale x 16 x i16> [[TMP21]] to <vscale x 16 x i8>
; PRED-NEXT: [[TMP26:%.*]] = getelementptr i8, ptr [[DST]], i64 [[INDEX]]
; PRED-NEXT: [[TMP27:%.*]] = getelementptr i8, ptr [[TMP26]], i32 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ define void @truncate_to_i1_used_by_branch(i8 %x, ptr %dst) #0 {
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x i8> [[BROADCAST_SPLATINSERT]], <vscale x 4 x i8> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = trunc i32 [[N_VEC]] to i8
; CHECK-NEXT: [[TMP7:%.*]] = trunc <vscale x 4 x i8> [[BROADCAST_SPLAT]] to <vscale x 4 x i1>
; CHECK-NEXT: [[TMP8:%.*]] = or <vscale x 4 x i1> trunc (<vscale x 4 x i8> splat (i8 23) to <vscale x 4 x i1>), [[TMP7]]
; CHECK-NEXT: [[TMP8:%.*]] = or <vscale x 4 x i1> splat (i1 true), [[TMP7]]
; CHECK-NEXT: [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[DST]], i64 0
; CHECK-NEXT: [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 4 x ptr> [[BROADCAST_SPLATINSERT1]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ define void @truncate_to_minimal_bitwidths_widen_cast_recipe(ptr %src) {
; CHECK-NEXT: [[VP_OP_LOAD:%.*]] = call <vscale x 8 x i8> @llvm.vp.load.nxv8i8.p0(ptr align 1 [[TMP6]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP7]])
; CHECK-NEXT: [[TMP8:%.*]] = zext <vscale x 8 x i8> [[VP_OP_LOAD]] to <vscale x 8 x i16>
; CHECK-NEXT: [[TMP12:%.*]] = mul <vscale x 8 x i16> zeroinitializer, [[TMP8]]
; CHECK-NEXT: [[TMP13:%.*]] = lshr <vscale x 8 x i16> [[TMP12]], trunc (<vscale x 8 x i32> splat (i32 1) to <vscale x 8 x i16>)
; CHECK-NEXT: [[TMP13:%.*]] = lshr <vscale x 8 x i16> [[TMP12]], splat (i16 1)
; CHECK-NEXT: [[TMP14:%.*]] = trunc <vscale x 8 x i16> [[TMP13]] to <vscale x 8 x i8>
; CHECK-NEXT: call void @llvm.vp.scatter.nxv8i8.nxv8p0(<vscale x 8 x i8> [[TMP14]], <vscale x 8 x ptr> align 1 zeroinitializer, <vscale x 8 x i1> splat (i1 true), i32 [[TMP7]])
; CHECK-NEXT: [[TMP9:%.*]] = zext i32 [[TMP7]] to i64
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Transforms/MemCpyOpt/crash.ll
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ define void @test2(i32 %cmd) nounwind {

define void @inttoptr_constexpr_crash(ptr %p) {
; CHECK-LABEL: @inttoptr_constexpr_crash(
; CHECK-NEXT: store <1 x ptr> inttoptr (<1 x i16> bitcast (<2 x i8> <i8 ptrtoint (ptr @g to i8), i8 ptrtoint (ptr @g to i8)> to <1 x i16>) to <1 x ptr>), ptr [[P:%.*]], align 1
; CHECK-NEXT: store <1 x ptr> <ptr inttoptr (i16 extractelement (<1 x i16> bitcast (<2 x i8> <i8 ptrtoint (ptr @g to i8), i8 ptrtoint (ptr @g to i8)> to <1 x i16>), i32 0) to ptr)>, ptr [[P:%.*]], align 1
; CHECK-NEXT: ret void
;
store <1 x ptr> inttoptr (<1 x i16> bitcast (<2 x i8> <i8 ptrtoint (ptr @g to i8), i8 ptrtoint (ptr @g to i8)> to <1 x i16>) to <1 x ptr>), ptr %p, align 1
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Transforms/VectorCombine/pr88796.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
define i32 @test() {
; CHECK-LABEL: define i32 @test() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = tail call i16 @llvm.vector.reduce.and.nxv8i16(<vscale x 8 x i16> trunc (<vscale x 8 x i32> splat (i32 268435456) to <vscale x 8 x i16>))
; CHECK-NEXT: [[TMP0:%.*]] = tail call i16 @llvm.vector.reduce.and.nxv8i16(<vscale x 8 x i16> zeroinitializer)
; CHECK-NEXT: ret i32 0
;
entry:
Expand Down
Loading