-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[VectorCombine] Scalarize bin ops and cmps with two splatted operands #137786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1035,50 +1035,61 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) { | |
if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value()))) | ||
return false; | ||
|
||
// Match against one or both scalar values being inserted into constant | ||
// vectors: | ||
// vec_op VecC0, (inselt VecC1, V1, Index) | ||
// vec_op (inselt VecC0, V0, Index), VecC1 | ||
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) | ||
// TODO: Deal with mismatched index constants and variable indexes? | ||
Constant *VecC0 = nullptr, *VecC1 = nullptr; | ||
Value *V0 = nullptr, *V1 = nullptr; | ||
uint64_t Index0 = 0, Index1 = 0; | ||
if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0), | ||
m_ConstantInt(Index0))) && | ||
!match(Ins0, m_Constant(VecC0))) | ||
return false; | ||
if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1), | ||
m_ConstantInt(Index1))) && | ||
!match(Ins1, m_Constant(VecC1))) | ||
return false; | ||
std::optional<uint64_t> Index; | ||
|
||
// Try and match against two splatted operands first. | ||
// vec_op (splat V0), (splat V1) | ||
V0 = getSplatValue(Ins0); | ||
V1 = getSplatValue(Ins1); | ||
if (!V0 || !V1) { | ||
// Match against one or both scalar values being inserted into constant | ||
// vectors: | ||
// vec_op VecC0, (inselt VecC1, V1, Index) | ||
// vec_op (inselt VecC0, V0, Index), VecC1 | ||
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) | ||
// TODO: Deal with mismatched index constants and variable indexes? | ||
V0 = nullptr, V1 = nullptr; | ||
uint64_t Index0 = 0, Index1 = 0; | ||
if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0), | ||
m_ConstantInt(Index0))) && | ||
!match(Ins0, m_Constant(VecC0))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the prior code did handle constant splats here. I think most of the m_Constant dependent bits become dead with your change? Of, this is because of the coupled condition noted just above. |
||
return false; | ||
if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1), | ||
m_ConstantInt(Index1))) && | ||
!match(Ins1, m_Constant(VecC1))) | ||
return false; | ||
|
||
bool IsConst0 = !V0; | ||
bool IsConst1 = !V1; | ||
if (IsConst0 && IsConst1) | ||
return false; | ||
if (!IsConst0 && !IsConst1 && Index0 != Index1) | ||
return false; | ||
bool IsConst0 = !V0; | ||
bool IsConst1 = !V1; | ||
if (IsConst0 && IsConst1) | ||
return false; | ||
if (!IsConst0 && !IsConst1 && Index0 != Index1) | ||
return false; | ||
|
||
auto *VecTy0 = cast<VectorType>(Ins0->getType()); | ||
auto *VecTy1 = cast<VectorType>(Ins1->getType()); | ||
if (VecTy0->getElementCount().getKnownMinValue() <= Index0 || | ||
VecTy1->getElementCount().getKnownMinValue() <= Index1) | ||
return false; | ||
auto *VecTy0 = cast<VectorType>(Ins0->getType()); | ||
auto *VecTy1 = cast<VectorType>(Ins1->getType()); | ||
if (VecTy0->getElementCount().getKnownMinValue() <= Index0 || | ||
VecTy1->getElementCount().getKnownMinValue() <= Index1) | ||
return false; | ||
|
||
// Bail for single insertion if it is a load. | ||
// TODO: Handle this once getVectorInstrCost can cost for load/stores. | ||
auto *I0 = dyn_cast_or_null<Instruction>(V0); | ||
auto *I1 = dyn_cast_or_null<Instruction>(V1); | ||
if ((IsConst0 && I1 && I1->mayReadFromMemory()) || | ||
(IsConst1 && I0 && I0->mayReadFromMemory())) | ||
return false; | ||
// Bail for single insertion if it is a load. | ||
// TODO: Handle this once getVectorInstrCost can cost for load/stores. | ||
auto *I0 = dyn_cast_or_null<Instruction>(V0); | ||
auto *I1 = dyn_cast_or_null<Instruction>(V1); | ||
if ((IsConst0 && I1 && I1->mayReadFromMemory()) || | ||
(IsConst1 && I0 && I0->mayReadFromMemory())) | ||
return false; | ||
|
||
Index = IsConst0 ? Index1 : Index0; | ||
} | ||
|
||
uint64_t Index = IsConst0 ? Index1 : Index0; | ||
Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType(); | ||
Type *VecTy = I.getType(); | ||
auto *VecTy = cast<VectorType>(I.getType()); | ||
Type *ScalarTy = VecTy->getElementType(); | ||
assert(VecTy->isVectorTy() && | ||
(IsConst0 || IsConst1 || V0->getType() == V1->getType()) && | ||
(isa<Constant>(Ins0) || isa<Constant>(Ins1) || | ||
V0->getType() == V1->getType()) && | ||
(ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() || | ||
ScalarTy->isPointerTy()) && | ||
"Unexpected types for insert element into binop or cmp"); | ||
|
@@ -1099,29 +1110,33 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) { | |
// Get cost estimate for the insert element. This cost will factor into | ||
// both sequences. | ||
InstructionCost InsertCost = TTI.getVectorInstrCost( | ||
Instruction::InsertElement, VecTy, CostKind, Index); | ||
InstructionCost OldCost = | ||
(IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost; | ||
InstructionCost NewCost = ScalarOpCost + InsertCost + | ||
(IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) + | ||
(IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost); | ||
Instruction::InsertElement, VecTy, CostKind, Index.value_or(0)); | ||
InstructionCost OldCost = (isa<Constant>(Ins0) ? 0 : InsertCost) + | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The costing here looks suspect. In particular, I would assume that (splat x) + (splat y) needs to be costed differently that (insertelt undef, x, 0) + (insertelt undef, y, 0). Consider the case where this is a m8 vector type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just noting that as far as I can tell we don't have a "splat from scalar" cost hook, which is why I originally kept it as just an insert. I think the technically correct cost would be an insert element + broadcast shuffle. But that would overcost it on RISC-V I presume? |
||
(isa<Constant>(Ins1) ? 0 : InsertCost) + | ||
VectorOpCost; | ||
InstructionCost NewCost = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same basic costing problem with the NewCost too. |
||
ScalarOpCost + InsertCost + | ||
(isa<Constant>(Ins0) ? 0 : !Ins0->hasOneUse() * InsertCost) + | ||
(isa<Constant>(Ins1) ? 0 : !Ins1->hasOneUse() * InsertCost); | ||
|
||
// We want to scalarize unless the vector variant actually has lower cost. | ||
if (OldCost < NewCost || !NewCost.isValid()) | ||
return false; | ||
|
||
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) --> | ||
// inselt NewVecC, (scalar_op V0, V1), Index | ||
// | ||
// vec_op (splat V0), (splat V1) --> splat (scalar_op V0, V1) | ||
if (IsCmp) | ||
++NumScalarCmp; | ||
else | ||
++NumScalarBO; | ||
|
||
// For constant cases, extract the scalar element, this should constant fold. | ||
if (IsConst0) | ||
V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index)); | ||
if (IsConst1) | ||
V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index)); | ||
if (Index && isa<Constant>(Ins0)) | ||
V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(*Index)); | ||
if (Index && isa<Constant>(Ins1)) | ||
V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(*Index)); | ||
|
||
Value *Scalar = | ||
IsCmp ? Builder.CreateCmp(Pred, V0, V1) | ||
|
@@ -1134,12 +1149,16 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) { | |
if (auto *ScalarInst = dyn_cast<Instruction>(Scalar)) | ||
ScalarInst->copyIRFlags(&I); | ||
|
||
// Fold the vector constants in the original vectors into a new base vector. | ||
Value *NewVecC = | ||
IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1) | ||
: Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1); | ||
Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index); | ||
replaceValue(I, *Insert); | ||
Value *Result; | ||
if (Index) { | ||
// Fold the vector constants in the original vectors into a new base vector. | ||
Value *NewVecC = IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1) | ||
: Builder.CreateBinOp((Instruction::BinaryOps)Opcode, | ||
VecC0, VecC1); | ||
Result = Builder.CreateInsertElement(NewVecC, Scalar, *Index); | ||
} else | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. matching curly braces |
||
Result = Builder.CreateVectorSplat(VecTy->getElementCount(), Scalar); | ||
replaceValue(I, *Result); | ||
return true; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 | ||
; RUN: opt -S -passes=vector-combine < %s | FileCheck %s | ||
|
||
define <4 x i32> @add_v4i32(i32 %x, i32 %y) { | ||
; CHECK-LABEL: define <4 x i32> @add_v4i32( | ||
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) { | ||
; CHECK-NEXT: [[RES_SCALAR:%.*]] = add i32 [[X]], 42 | ||
; CHECK-NEXT: [[X_HEAD:%.*]] = insertelement <4 x i32> poison, i32 [[RES_SCALAR]], i64 0 | ||
; CHECK-NEXT: [[X_SPLAT:%.*]] = shufflevector <4 x i32> [[X_HEAD]], <4 x i32> poison, <4 x i32> zeroinitializer | ||
; CHECK-NEXT: ret <4 x i32> [[X_SPLAT]] | ||
; | ||
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0 | ||
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer | ||
%res = add <4 x i32> %x.splat, splat (i32 42) | ||
ret <4 x i32> %res | ||
} | ||
|
||
define <vscale x 4 x i32> @add_nxv4i32(i32 %x, i32 %y) { | ||
; CHECK-LABEL: define <vscale x 4 x i32> @add_nxv4i32( | ||
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) { | ||
; CHECK-NEXT: [[RES_SCALAR:%.*]] = add i32 [[X]], 42 | ||
; CHECK-NEXT: [[Y_HEAD1:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[RES_SCALAR]], i64 0 | ||
; CHECK-NEXT: [[Y_SPLAT1:%.*]] = shufflevector <vscale x 4 x i32> [[Y_HEAD1]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer | ||
; CHECK-NEXT: ret <vscale x 4 x i32> [[Y_SPLAT1]] | ||
; | ||
%x.head = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0 | ||
%x.splat = shufflevector <vscale x 4 x i32> %x.head, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer | ||
%res = add <vscale x 4 x i32> %x.splat, splat (i32 42) | ||
ret <vscale x 4 x i32> %res | ||
} | ||
|
||
; Make sure that we can scalarize sequences of vector instructions. | ||
define <4 x i32> @add_mul_v4i32(i32 %x, i32 %y, i32 %z) { | ||
; CHECK-LABEL: define <4 x i32> @add_mul_v4i32( | ||
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) { | ||
; CHECK-NEXT: [[RES0_SCALAR:%.*]] = add i32 [[X]], 42 | ||
; CHECK-NEXT: [[RES1_SCALAR:%.*]] = mul i32 [[RES0_SCALAR]], 42 | ||
; CHECK-NEXT: [[Z_HEAD1:%.*]] = insertelement <4 x i32> poison, i32 [[RES1_SCALAR]], i64 0 | ||
; CHECK-NEXT: [[Z_SPLAT1:%.*]] = shufflevector <4 x i32> [[Z_HEAD1]], <4 x i32> poison, <4 x i32> zeroinitializer | ||
; CHECK-NEXT: ret <4 x i32> [[Z_SPLAT1]] | ||
; | ||
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0 | ||
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer | ||
%res0 = add <4 x i32> %x.splat, splat (i32 42) | ||
%res1 = mul <4 x i32> %res0, splat (i32 42) | ||
ret <4 x i32> %res1 | ||
} | ||
|
||
; Shouldn't be scalarized since %x.splat and %y.splat have other users. | ||
define <4 x i32> @other_users_v4i32(i32 %x, i32 %y, ptr %p, ptr %q) { | ||
; CHECK-LABEL: define <4 x i32> @other_users_v4i32( | ||
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], ptr [[P:%.*]], ptr [[Q:%.*]]) { | ||
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[X]], i32 0 | ||
; CHECK-NEXT: [[RES:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer | ||
; CHECK-NEXT: [[RES1:%.*]] = add <4 x i32> [[RES]], splat (i32 42) | ||
; CHECK-NEXT: store <4 x i32> [[RES]], ptr [[P]], align 16 | ||
; CHECK-NEXT: store <4 x i32> [[RES]], ptr [[Q]], align 16 | ||
; CHECK-NEXT: ret <4 x i32> [[RES1]] | ||
; | ||
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0 | ||
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer | ||
%res = add <4 x i32> %x.splat, splat (i32 42) | ||
store <4 x i32> %x.splat, ptr %p | ||
store <4 x i32> %x.splat, ptr %q | ||
ret <4 x i32> %res | ||
} | ||
|
||
define <4 x i1> @icmp_v4i32(i32 %x, i32 %y) { | ||
; CHECK-LABEL: define <4 x i1> @icmp_v4i32( | ||
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) { | ||
; CHECK-NEXT: [[RES_SCALAR:%.*]] = icmp eq i32 [[X]], 42 | ||
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x i1> poison, i1 [[RES_SCALAR]], i64 0 | ||
; CHECK-NEXT: [[RES:%.*]] = shufflevector <4 x i1> [[DOTSPLATINSERT]], <4 x i1> poison, <4 x i32> zeroinitializer | ||
; CHECK-NEXT: ret <4 x i1> [[RES]] | ||
; | ||
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0 | ||
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer | ||
%res = icmp eq <4 x i32> %x.splat, splat (i32 42) | ||
ret <4 x i1> %res | ||
} | ||
|
||
define <vscale x 4 x i1> @icmp_nxv4i32(i32 %x, i32 %y) { | ||
; CHECK-LABEL: define <vscale x 4 x i1> @icmp_nxv4i32( | ||
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) { | ||
; CHECK-NEXT: [[RES_SCALAR:%.*]] = icmp eq i32 [[X]], 42 | ||
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 4 x i1> poison, i1 [[RES_SCALAR]], i64 0 | ||
; CHECK-NEXT: [[RES:%.*]] = shufflevector <vscale x 4 x i1> [[DOTSPLATINSERT]], <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer | ||
; CHECK-NEXT: ret <vscale x 4 x i1> [[RES]] | ||
; | ||
%x.head = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0 | ||
%x.splat = shufflevector <vscale x 4 x i32> %x.head, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer | ||
%res = icmp eq <vscale x 4 x i32> %x.splat, splat (i32 42) | ||
ret <vscale x 4 x i1> %res | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you're going to need this if (!VO) { do something; } if (!V1) { do_something } if (indices don't match) { bail; }