Skip to content

[VectorCombine] Scalarize bin ops and cmps with two splatted operands #137786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 72 additions & 53 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1035,50 +1035,61 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
return false;

// Match against one or both scalar values being inserted into constant
// vectors:
// vec_op VecC0, (inselt VecC1, V1, Index)
// vec_op (inselt VecC0, V0, Index), VecC1
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
// TODO: Deal with mismatched index constants and variable indexes?
Constant *VecC0 = nullptr, *VecC1 = nullptr;
Value *V0 = nullptr, *V1 = nullptr;
uint64_t Index0 = 0, Index1 = 0;
if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
m_ConstantInt(Index0))) &&
!match(Ins0, m_Constant(VecC0)))
return false;
if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
m_ConstantInt(Index1))) &&
!match(Ins1, m_Constant(VecC1)))
return false;
std::optional<uint64_t> Index;

// Try and match against two splatted operands first.
// vec_op (splat V0), (splat V1)
V0 = getSplatValue(Ins0);
V1 = getSplatValue(Ins1);
if (!V0 || !V1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're going to need this if (!VO) { do something; } if (!V1) { do_something } if (indices don't match) { bail; }

// Match against one or both scalar values being inserted into constant
// vectors:
// vec_op VecC0, (inselt VecC1, V1, Index)
// vec_op (inselt VecC0, V0, Index), VecC1
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
// TODO: Deal with mismatched index constants and variable indexes?
V0 = nullptr, V1 = nullptr;
uint64_t Index0 = 0, Index1 = 0;
if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
m_ConstantInt(Index0))) &&
!match(Ins0, m_Constant(VecC0)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the prior code did handle constant splats here. I think most of the m_Constant dependent bits become dead with your change? Of, this is because of the coupled condition noted just above.

return false;
if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
m_ConstantInt(Index1))) &&
!match(Ins1, m_Constant(VecC1)))
return false;

bool IsConst0 = !V0;
bool IsConst1 = !V1;
if (IsConst0 && IsConst1)
return false;
if (!IsConst0 && !IsConst1 && Index0 != Index1)
return false;
bool IsConst0 = !V0;
bool IsConst1 = !V1;
if (IsConst0 && IsConst1)
return false;
if (!IsConst0 && !IsConst1 && Index0 != Index1)
return false;

auto *VecTy0 = cast<VectorType>(Ins0->getType());
auto *VecTy1 = cast<VectorType>(Ins1->getType());
if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
VecTy1->getElementCount().getKnownMinValue() <= Index1)
return false;
auto *VecTy0 = cast<VectorType>(Ins0->getType());
auto *VecTy1 = cast<VectorType>(Ins1->getType());
if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
VecTy1->getElementCount().getKnownMinValue() <= Index1)
return false;

// Bail for single insertion if it is a load.
// TODO: Handle this once getVectorInstrCost can cost for load/stores.
auto *I0 = dyn_cast_or_null<Instruction>(V0);
auto *I1 = dyn_cast_or_null<Instruction>(V1);
if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
(IsConst1 && I0 && I0->mayReadFromMemory()))
return false;
// Bail for single insertion if it is a load.
// TODO: Handle this once getVectorInstrCost can cost for load/stores.
auto *I0 = dyn_cast_or_null<Instruction>(V0);
auto *I1 = dyn_cast_or_null<Instruction>(V1);
if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
(IsConst1 && I0 && I0->mayReadFromMemory()))
return false;

Index = IsConst0 ? Index1 : Index0;
}

uint64_t Index = IsConst0 ? Index1 : Index0;
Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
Type *VecTy = I.getType();
auto *VecTy = cast<VectorType>(I.getType());
Type *ScalarTy = VecTy->getElementType();
assert(VecTy->isVectorTy() &&
(IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
(isa<Constant>(Ins0) || isa<Constant>(Ins1) ||
V0->getType() == V1->getType()) &&
(ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
ScalarTy->isPointerTy()) &&
"Unexpected types for insert element into binop or cmp");
Expand All @@ -1099,29 +1110,33 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
// Get cost estimate for the insert element. This cost will factor into
// both sequences.
InstructionCost InsertCost = TTI.getVectorInstrCost(
Instruction::InsertElement, VecTy, CostKind, Index);
InstructionCost OldCost =
(IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
InstructionCost NewCost = ScalarOpCost + InsertCost +
(IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
(IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
Instruction::InsertElement, VecTy, CostKind, Index.value_or(0));
InstructionCost OldCost = (isa<Constant>(Ins0) ? 0 : InsertCost) +
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The costing here looks suspect. In particular, I would assume that (splat x) + (splat y) needs to be costed differently that (insertelt undef, x, 0) + (insertelt undef, y, 0). Consider the case where this is a m8 vector type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting that as far as I can tell we don't have a "splat from scalar" cost hook, which is why I originally kept it as just an insert. I think the technically correct cost would be an insert element + broadcast shuffle. But that would overcost it on RISC-V I presume?

(isa<Constant>(Ins1) ? 0 : InsertCost) +
VectorOpCost;
InstructionCost NewCost =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same basic costing problem with the NewCost too.

ScalarOpCost + InsertCost +
(isa<Constant>(Ins0) ? 0 : !Ins0->hasOneUse() * InsertCost) +
(isa<Constant>(Ins1) ? 0 : !Ins1->hasOneUse() * InsertCost);

// We want to scalarize unless the vector variant actually has lower cost.
if (OldCost < NewCost || !NewCost.isValid())
return false;

// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
// inselt NewVecC, (scalar_op V0, V1), Index
//
// vec_op (splat V0), (splat V1) --> splat (scalar_op V0, V1)
if (IsCmp)
++NumScalarCmp;
else
++NumScalarBO;

// For constant cases, extract the scalar element, this should constant fold.
if (IsConst0)
V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
if (IsConst1)
V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
if (Index && isa<Constant>(Ins0))
V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(*Index));
if (Index && isa<Constant>(Ins1))
V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(*Index));

Value *Scalar =
IsCmp ? Builder.CreateCmp(Pred, V0, V1)
Expand All @@ -1134,12 +1149,16 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
ScalarInst->copyIRFlags(&I);

// Fold the vector constants in the original vectors into a new base vector.
Value *NewVecC =
IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
: Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
replaceValue(I, *Insert);
Value *Result;
if (Index) {
// Fold the vector constants in the original vectors into a new base vector.
Value *NewVecC = IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
: Builder.CreateBinOp((Instruction::BinaryOps)Opcode,
VecC0, VecC1);
Result = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
} else
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

matching curly braces

Result = Builder.CreateVectorSplat(VecTy->getElementCount(), Scalar);
replaceValue(I, *Result);
return true;
}

Expand Down
48 changes: 36 additions & 12 deletions llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,24 @@ define <4 x i32> @shuf_icmp_ugt_v4i32_use(<4 x i32> %x, <4 x i32> %y, <4 x i32>
; PR121110 - don't merge equivalent (but not matching) predicates

define <2 x i1> @PR121110() {
; CHECK-LABEL: define <2 x i1> @PR121110(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
; CHECK-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
; CHECK-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[UGT]], <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
; CHECK-NEXT: ret <2 x i1> [[RES]]
; SSE-LABEL: define <2 x i1> @PR121110(
; SSE-SAME: ) #[[ATTR0]] {
; SSE-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
; SSE-NEXT: [[RES:%.*]] = shufflevector <2 x i1> zeroinitializer, <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
; SSE-NEXT: ret <2 x i1> [[RES]]
;
; AVX2-LABEL: define <2 x i1> @PR121110(
; AVX2-SAME: ) #[[ATTR0]] {
; AVX2-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
; AVX2-NEXT: [[RES:%.*]] = shufflevector <2 x i1> zeroinitializer, <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
; AVX2-NEXT: ret <2 x i1> [[RES]]
;
; AVX512-LABEL: define <2 x i1> @PR121110(
; AVX512-SAME: ) #[[ATTR0]] {
; AVX512-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
; AVX512-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
; AVX512-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[UGT]], <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
; AVX512-NEXT: ret <2 x i1> [[RES]]
;
%ugt = icmp samesign ugt <2 x i32> < i32 0, i32 0 >, < i32 0, i32 0 >
%sgt = icmp sgt <2 x i32> < i32 0, i32 0 >, < i32 6, i32 4294967292 >
Expand All @@ -285,12 +297,24 @@ define <2 x i1> @PR121110() {
}

define <2 x i1> @PR121110_commute() {
; CHECK-LABEL: define <2 x i1> @PR121110_commute(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
; CHECK-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
; CHECK-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> [[UGT]], <2 x i32> <i32 0, i32 3>
; CHECK-NEXT: ret <2 x i1> [[RES]]
; SSE-LABEL: define <2 x i1> @PR121110_commute(
; SSE-SAME: ) #[[ATTR0]] {
; SSE-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
; SSE-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> zeroinitializer, <2 x i32> <i32 0, i32 3>
; SSE-NEXT: ret <2 x i1> [[RES]]
;
; AVX2-LABEL: define <2 x i1> @PR121110_commute(
; AVX2-SAME: ) #[[ATTR0]] {
; AVX2-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
; AVX2-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> zeroinitializer, <2 x i32> <i32 0, i32 3>
; AVX2-NEXT: ret <2 x i1> [[RES]]
;
; AVX512-LABEL: define <2 x i1> @PR121110_commute(
; AVX512-SAME: ) #[[ATTR0]] {
; AVX512-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
; AVX512-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
; AVX512-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> [[UGT]], <2 x i32> <i32 0, i32 3>
; AVX512-NEXT: ret <2 x i1> [[RES]]
;
%sgt = icmp sgt <2 x i32> < i32 0, i32 0 >, < i32 6, i32 4294967292 >
%ugt = icmp samesign ugt <2 x i32> < i32 0, i32 0 >, < i32 0, i32 0 >
Expand Down
94 changes: 94 additions & 0 deletions llvm/test/Transforms/VectorCombine/scalarize-binop.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=vector-combine < %s | FileCheck %s

define <4 x i32> @add_v4i32(i32 %x, i32 %y) {
; CHECK-LABEL: define <4 x i32> @add_v4i32(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[RES_SCALAR:%.*]] = add i32 [[X]], 42
; CHECK-NEXT: [[X_HEAD:%.*]] = insertelement <4 x i32> poison, i32 [[RES_SCALAR]], i64 0
; CHECK-NEXT: [[X_SPLAT:%.*]] = shufflevector <4 x i32> [[X_HEAD]], <4 x i32> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: ret <4 x i32> [[X_SPLAT]]
;
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer
%res = add <4 x i32> %x.splat, splat (i32 42)
ret <4 x i32> %res
}

define <vscale x 4 x i32> @add_nxv4i32(i32 %x, i32 %y) {
; CHECK-LABEL: define <vscale x 4 x i32> @add_nxv4i32(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[RES_SCALAR:%.*]] = add i32 [[X]], 42
; CHECK-NEXT: [[Y_HEAD1:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[RES_SCALAR]], i64 0
; CHECK-NEXT: [[Y_SPLAT1:%.*]] = shufflevector <vscale x 4 x i32> [[Y_HEAD1]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: ret <vscale x 4 x i32> [[Y_SPLAT1]]
;
%x.head = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
%x.splat = shufflevector <vscale x 4 x i32> %x.head, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
%res = add <vscale x 4 x i32> %x.splat, splat (i32 42)
ret <vscale x 4 x i32> %res
}

; Make sure that we can scalarize sequences of vector instructions.
define <4 x i32> @add_mul_v4i32(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: define <4 x i32> @add_mul_v4i32(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
; CHECK-NEXT: [[RES0_SCALAR:%.*]] = add i32 [[X]], 42
; CHECK-NEXT: [[RES1_SCALAR:%.*]] = mul i32 [[RES0_SCALAR]], 42
; CHECK-NEXT: [[Z_HEAD1:%.*]] = insertelement <4 x i32> poison, i32 [[RES1_SCALAR]], i64 0
; CHECK-NEXT: [[Z_SPLAT1:%.*]] = shufflevector <4 x i32> [[Z_HEAD1]], <4 x i32> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: ret <4 x i32> [[Z_SPLAT1]]
;
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer
%res0 = add <4 x i32> %x.splat, splat (i32 42)
%res1 = mul <4 x i32> %res0, splat (i32 42)
ret <4 x i32> %res1
}

; Shouldn't be scalarized since %x.splat and %y.splat have other users.
define <4 x i32> @other_users_v4i32(i32 %x, i32 %y, ptr %p, ptr %q) {
; CHECK-LABEL: define <4 x i32> @other_users_v4i32(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], ptr [[P:%.*]], ptr [[Q:%.*]]) {
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[X]], i32 0
; CHECK-NEXT: [[RES:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[RES1:%.*]] = add <4 x i32> [[RES]], splat (i32 42)
; CHECK-NEXT: store <4 x i32> [[RES]], ptr [[P]], align 16
; CHECK-NEXT: store <4 x i32> [[RES]], ptr [[Q]], align 16
; CHECK-NEXT: ret <4 x i32> [[RES1]]
;
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer
%res = add <4 x i32> %x.splat, splat (i32 42)
store <4 x i32> %x.splat, ptr %p
store <4 x i32> %x.splat, ptr %q
ret <4 x i32> %res
}

define <4 x i1> @icmp_v4i32(i32 %x, i32 %y) {
; CHECK-LABEL: define <4 x i1> @icmp_v4i32(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[RES_SCALAR:%.*]] = icmp eq i32 [[X]], 42
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x i1> poison, i1 [[RES_SCALAR]], i64 0
; CHECK-NEXT: [[RES:%.*]] = shufflevector <4 x i1> [[DOTSPLATINSERT]], <4 x i1> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: ret <4 x i1> [[RES]]
;
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer
%res = icmp eq <4 x i32> %x.splat, splat (i32 42)
ret <4 x i1> %res
}

define <vscale x 4 x i1> @icmp_nxv4i32(i32 %x, i32 %y) {
; CHECK-LABEL: define <vscale x 4 x i1> @icmp_nxv4i32(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[RES_SCALAR:%.*]] = icmp eq i32 [[X]], 42
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 4 x i1> poison, i1 [[RES_SCALAR]], i64 0
; CHECK-NEXT: [[RES:%.*]] = shufflevector <vscale x 4 x i1> [[DOTSPLATINSERT]], <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: ret <vscale x 4 x i1> [[RES]]
;
%x.head = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
%x.splat = shufflevector <vscale x 4 x i32> %x.head, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
%res = icmp eq <vscale x 4 x i32> %x.splat, splat (i32 42)
ret <vscale x 4 x i1> %res
}