[VectorCombine] Support nary operands and intrinsics in scalarizeOpOrCmp #138406

Merged: 4 commits, May 28, 2025
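In brief, this patch renames scalarizeBinopOrCmp to scalarizeOpOrCmp and rewrites its operand handling to be n-ary, so that unary operators (fneg) and unary/ternary trivially vectorizable intrinsics (fabs, fma) are scalarized in addition to the binary ops, compares, and two-argument intrinsics handled before. The shape of the transform, taken from the new tests below:

; before: the op is applied to a vector holding a single inserted scalar
%x.insert = insertelement <4 x float> poison, float %x, i32 0
%v = fneg <4 x float> %x.insert

; after: the op runs on the scalar and the result is reinserted
%v.scalar = fneg float %x
%v = insertelement <4 x float> poison, float %v.scalar, i64 0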
210 changes: 107 additions & 103 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -47,7 +47,7 @@ STATISTIC(NumVecCmp, "Number of vector compares formed");
 STATISTIC(NumVecBO, "Number of vector binops formed");
 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
-STATISTIC(NumScalarBO, "Number of scalar binops formed");
+STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
 STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");

@@ -114,7 +114,7 @@ class VectorCombine {
   bool foldInsExtBinop(Instruction &I);
   bool foldInsExtVectorToShuffle(Instruction &I);
   bool foldBitcastShuffle(Instruction &I);
-  bool scalarizeBinopOrCmp(Instruction &I);
+  bool scalarizeOpOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
@@ -1018,91 +1018,90 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
   return true;
 }
 
-/// Match a vector binop, compare or binop-like intrinsic with at least one
-/// inserted scalar operand and convert to scalar binop/cmp/intrinsic followed
+/// Match a vector op/compare/intrinsic with at least one
+/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
 /// by insertelement.
-bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
-  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
-  Value *Ins0, *Ins1;
-  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
-      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1)))) {
-    // TODO: Allow unary and ternary intrinsics
-    // TODO: Allow intrinsics with different argument types
-    // TODO: Allow intrinsics with scalar arguments
-    if (auto *II = dyn_cast<IntrinsicInst>(&I);
-        II && II->arg_size() == 2 &&
-        isTriviallyVectorizable(II->getIntrinsicID()) &&
-        all_of(II->args(),
-               [&II](Value *Arg) { return Arg->getType() == II->getType(); })) {
-      Ins0 = II->getArgOperand(0);
-      Ins1 = II->getArgOperand(1);
-    } else {
-      return false;
-    }
-  }
+bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
+  auto *UO = dyn_cast<UnaryOperator>(&I);
+  auto *BO = dyn_cast<BinaryOperator>(&I);
+  auto *CI = dyn_cast<CmpInst>(&I);
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  if (!UO && !BO && !CI && !II)
+    return false;
+
+  // TODO: Allow intrinsics with different argument types
+  // TODO: Allow intrinsics with scalar arguments
+  if (II && (!isTriviallyVectorizable(II->getIntrinsicID()) ||
+             !all_of(II->args(), [&II](Value *Arg) {
+               return Arg->getType() == II->getType();
+             })))
+    return false;
 
   // Do not convert the vector condition of a vector select into a scalar
   // condition. That may cause problems for codegen because of differences in
   // boolean formats and register-file transfers.
   // TODO: Can we account for that in the cost model?
-  if (isa<CmpInst>(I))
+  if (CI)
    for (User *U : I.users())
      if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
        return false;
 
-  // Match against one or both scalar values being inserted into constant
-  // vectors:
-  // vec_op VecC0, (inselt VecC1, V1, Index)
-  // vec_op (inselt VecC0, V0, Index), VecC1
-  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
-  // TODO: Deal with mismatched index constants and variable indexes?
-  Constant *VecC0 = nullptr, *VecC1 = nullptr;
-  Value *V0 = nullptr, *V1 = nullptr;
-  uint64_t Index0 = 0, Index1 = 0;
-  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
-                               m_ConstantInt(Index0))) &&
-      !match(Ins0, m_Constant(VecC0)))
-    return false;
-  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
-                               m_ConstantInt(Index1))) &&
-      !match(Ins1, m_Constant(VecC1)))
-    return false;
-
-  bool IsConst0 = !V0;
-  bool IsConst1 = !V1;
-  if (IsConst0 && IsConst1)
-    return false;
-  if (!IsConst0 && !IsConst1 && Index0 != Index1)
-    return false;
+  // Match constant vectors or scalars being inserted into constant vectors:
+  // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
+  SmallVector<Constant *> VecCs;
+  SmallVector<Value *> ScalarOps;
+  std::optional<uint64_t> Index;
+
+  auto Ops = II ? II->args() : I.operand_values();
+  for (Value *Op : Ops) {
+    Constant *VecC;
+    Value *V;
+    uint64_t InsIdx = 0;
+    VectorType *OpTy = cast<VectorType>(Op->getType());
+    if (match(Op, m_InsertElt(m_Constant(VecC), m_Value(V),
+                              m_ConstantInt(InsIdx)))) {
+      // Bail if any inserts are out of bounds.
+      if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
+        return false;
+      // All inserts must have the same index.
+      // TODO: Deal with mismatched index constants and variable indexes?
+      if (!Index)
+        Index = InsIdx;
+      else if (InsIdx != *Index)
+        return false;
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(V);
+    } else if (match(Op, m_Constant(VecC))) {
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(nullptr);
+    } else {
+      return false;
+    }
+  }
 
-  auto *VecTy0 = cast<VectorType>(Ins0->getType());
-  auto *VecTy1 = cast<VectorType>(Ins1->getType());
-  if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
-      VecTy1->getElementCount().getKnownMinValue() <= Index1)
+  // Bail if all operands are constant.
+  if (!Index.has_value())
     return false;
 
-  uint64_t Index = IsConst0 ? Index1 : Index0;
-  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
-  Type *VecTy = I.getType();
+  VectorType *VecTy = cast<VectorType>(I.getType());
+  Type *ScalarTy = VecTy->getScalarType();
   assert(VecTy->isVectorTy() &&
-         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");
 
   unsigned Opcode = I.getOpcode();
   InstructionCost ScalarOpCost, VectorOpCost;
-  if (isa<CmpInst>(I)) {
-    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+  if (CI) {
+    CmpInst::Predicate Pred = CI->getPredicate();
     ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
     VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
-  } else if (isa<BinaryOperator>(I)) {
+  } else if (UO || BO) {
     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
   } else {
-    auto *II = cast<IntrinsicInst>(&I);
     IntrinsicCostAttributes ScalarICA(
        II->getIntrinsicID(), ScalarTy,
        SmallVector<Type *>(II->arg_size(), ScalarTy));
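To make the accepted shapes concrete, here is an illustrative input (not taken from the test suite): each operand must be either a constant vector or a scalar inserted into a constant vector, and all inserts must target the same lane.

; operand 0: scalar %x inserted into a constant (poison) base at lane 0
; operand 1: a plain constant vector
%x.insert = insertelement <4 x i32> poison, i32 %x, i64 0
%v = add <4 x i32> %x.insert, <i32 1, i32 2, i32 3, i32 4>

For this input the matching loop above would collect ScalarOps = {%x, nullptr}, VecCs = {poison, <i32 1, i32 2, i32 3, i32 4>}, and Index = 0.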
@@ -1115,56 +1114,59 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
 
   // Fold the vector constants in the original vectors into a new base vector to
   // get more accurate cost modelling.
-  Value *NewVecC;
-  if (isa<CmpInst>(I))
-    NewVecC = ConstantFoldCompareInstOperands(Pred, VecC0, VecC1, *DL);
-  else if (isa<BinaryOperator>(I))
-    NewVecC = ConstantFoldBinaryOpOperands((Instruction::BinaryOps)Opcode,
-                                           VecC0, VecC1, *DL);
-  else
-    NewVecC = ConstantFoldBinaryIntrinsic(
-        cast<IntrinsicInst>(I).getIntrinsicID(), VecC0, VecC1, I.getType(), &I);
+  Value *NewVecC = nullptr;
+  if (CI)
+    NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
+                                              VecCs[1], *DL);
+  else if (UO)
+    NewVecC = ConstantFoldUnaryOpOperand(Opcode, VecCs[0], *DL);
+  else if (BO)
+    NewVecC = ConstantFoldBinaryOpOperands(Opcode, VecCs[0], VecCs[1], *DL);
+  else if (II->arg_size() == 2)
+    NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
+                                          VecCs[1], II->getType(), II);
 
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.
-  InstructionCost InsertCostNewVecC = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, NewVecC);
-  InstructionCost InsertCostV0 = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, VecC0, V0);
-  InstructionCost InsertCostV1 = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, VecC1, V1);
-  InstructionCost OldCost = (IsConst0 ? 0 : InsertCostV0) +
-                            (IsConst1 ? 0 : InsertCostV1) + VectorOpCost;
-  InstructionCost NewCost = ScalarOpCost + InsertCostNewVecC +
-                            (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCostV0) +
-                            (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCostV1);
+  InstructionCost OldCost = VectorOpCost;
+  InstructionCost NewCost =
+      ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
+                                            CostKind, *Index, NewVecC);
+  for (auto [Op, VecC, Scalar] : zip(Ops, VecCs, ScalarOps)) {
+    if (!Scalar)
+      continue;
+    InstructionCost InsertCost = TTI.getVectorInstrCost(
+        Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
+    OldCost += InsertCost;
+    NewCost += !Op->hasOneUse() * InsertCost;
+  }
 
   // We want to scalarize unless the vector variant actually has lower cost.
   if (OldCost < NewCost || !NewCost.isValid())
     return false;
 
   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
   // inselt NewVecC, (scalar_op V0, V1), Index
-  if (isa<CmpInst>(I))
+  if (CI)
     ++NumScalarCmp;
-  else if (isa<BinaryOperator>(I))
-    ++NumScalarBO;
-  else if (isa<IntrinsicInst>(I))
+  else if (UO || BO)
+    ++NumScalarOps;
+  else
     ++NumScalarIntrinsic;
 
   // For constant cases, extract the scalar element, this should constant fold.
-  if (IsConst0)
-    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
-  if (IsConst1)
-    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
+  for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
+    if (!Scalar)
+      ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
+          cast<Constant>(VecC), Builder.getInt64(*Index));
 
   Value *Scalar;
-  if (isa<CmpInst>(I))
-    Scalar = Builder.CreateCmp(Pred, V0, V1);
-  else if (isa<BinaryOperator>(I))
-    Scalar = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
+  if (CI)
+    Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
+  else if (UO || BO)
+    Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
   else
-    Scalar = Builder.CreateIntrinsic(
-        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), {V0, V1});
+    Scalar = Builder.CreateIntrinsic(ScalarTy, II->getIntrinsicID(), ScalarOps);
 
   Scalar->setName(I.getName() + ".scalar");
 
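Continuing the illustrative add example from above: the lane-0 element of the constant operand is extracted (folding to i32 1), the scalar op is built, and the result is inserted into the constant-folded base vector; here the add over a poison base folds the whole base to poison, so the rewrite would be:

%v.scalar = add i32 %x, 1
%v = insertelement <4 x i32> poison, i32 %v.scalar, i64 0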
@@ -1175,16 +1177,18 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
 
   // Create a new base vector if the constant folding failed.
   if (!NewVecC) {
-    if (isa<CmpInst>(I))
-      NewVecC = Builder.CreateCmp(Pred, VecC0, VecC1);
-    else if (isa<BinaryOperator>(I))
-      NewVecC =
-          Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
+    SmallVector<Value *> VecCValues;
+    VecCValues.reserve(VecCs.size());
+    append_range(VecCValues, VecCs);
+    if (CI)
+      NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
+    else if (UO || BO)
+      NewVecC = Builder.CreateNAryOp(Opcode, VecCValues);
     else
-      NewVecC = Builder.CreateIntrinsic(
-          VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), {VecC0, VecC1});
+      NewVecC =
+          Builder.CreateIntrinsic(VecTy, II->getIntrinsicID(), VecCValues);
   }
-  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
+  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
   replaceValue(I, *Insert);
   return true;
 }
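One shape the function still deliberately skips, per the guard near its top: a compare whose result feeds a vector select, since scalarizing the condition can interact badly with boolean formats and register-file transfers. An illustrative case that is left untouched (%a and %b stand for arbitrary vector values):

%c.insert = insertelement <4 x i32> poison, i32 %c, i64 0
%cmp = icmp eq <4 x i32> %c.insert, zeroinitializer
%r = select <4 x i1> %cmp, <4 x i32> %a, <4 x i32> %b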
@@ -3570,7 +3574,7 @@ bool VectorCombine::run() {
     // This transform works with scalable and fixed vectors
     // TODO: Identify and allow other scalable transforms
     if (IsVectorType) {
-      MadeChange |= scalarizeBinopOrCmp(I);
+      MadeChange |= scalarizeOpOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
56 changes: 56 additions & 0 deletions llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
@@ -96,6 +96,62 @@ define <4 x i32> @non_trivially_vectorizable(i32 %x, i32 %y) {
   ret <4 x i32> %v
 }
 
+define <4 x float> @fabs_fixed(float %x) {
+; CHECK-LABEL: define <4 x float> @fabs_fixed(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.fabs.v4f32(<4 x float> poison)
+; CHECK-NEXT: [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %v = call <4 x float> @llvm.fabs(<4 x float> %x.insert)
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fabs_scalable(float %x) {
+; CHECK-LABEL: define <vscale x 4 x float> @fabs_scalable(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fabs.nxv4f32(<vscale x 4 x float> poison)
+; CHECK-NEXT: [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %v = call <vscale x 4 x float> @llvm.fabs(<vscale x 4 x float> %x.insert)
+  ret <vscale x 4 x float> %v
+}
+
+define <4 x float> @fma_fixed(float %x, float %y, float %z) {
+; CHECK-LABEL: define <4 x float> @fma_fixed(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> poison, <4 x float> poison, <4 x float> poison)
+; CHECK-NEXT: [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %y.insert = insertelement <4 x float> poison, float %y, i32 0
+  %z.insert = insertelement <4 x float> poison, float %z, i32 0
+  %v = call <4 x float> @llvm.fma(<4 x float> %x.insert, <4 x float> %y.insert, <4 x float> %z.insert)
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fma_scalable(float %x, float %y, float %z) {
+; CHECK-LABEL: define <vscale x 4 x float> @fma_scalable(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> poison, <vscale x 4 x float> poison, <vscale x 4 x float> poison)
+; CHECK-NEXT: [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %y.insert = insertelement <vscale x 4 x float> poison, float %y, i32 0
+  %z.insert = insertelement <vscale x 4 x float> poison, float %z, i32 0
+  %v = call <vscale x 4 x float> @llvm.fma(<vscale x 4 x float> %x.insert, <vscale x 4 x float> %y.insert, <vscale x 4 x float> %z.insert)
+  ret <vscale x 4 x float> %v
+}
+
 ; TODO: We should be able to scalarize this if we preserve the scalar argument.
 define <4 x float> @scalar_argument(float %x) {
 ; CHECK-LABEL: define <4 x float> @scalar_argument(
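The TODO above concerns intrinsics that mix vector and scalar argument types, which the eligibility check in scalarizeOpOrCmp still rejects because not every argument type matches the return type. A hypothetical example of that shape, assuming an intrinsic such as llvm.powi whose exponent operand is scalar:

%x.insert = insertelement <4 x float> poison, float %x, i64 0
%v = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> %x.insert, i32 %p)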
26 changes: 26 additions & 0 deletions llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll
@@ -0,0 +1,26 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -p vector-combine | FileCheck %s
+
+define <4 x float> @fneg_fixed(float %x) {
+; CHECK-LABEL: define <4 x float> @fneg_fixed(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = fneg float [[X]]
+; CHECK-NEXT: [[V:%.*]] = insertelement <4 x float> poison, float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %v = fneg <4 x float> %x.insert
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fneg_scalable(float %x) {
+; CHECK-LABEL: define <vscale x 4 x float> @fneg_scalable(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[V_SCALAR:%.*]] = fneg float [[X]]
+; CHECK-NEXT: [[V:%.*]] = insertelement <vscale x 4 x float> poison, float [[V_SCALAR]], i64 0
+; CHECK-NEXT: ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %v = fneg <vscale x 4 x float> %x.insert
+  ret <vscale x 4 x float> %v
+}