Skip to content

Commit 6abb0d3

Browse files
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands to splats of the scalar operation
1 parent 8b30ae5 commit 6abb0d3

File tree

2 files changed

+430
-46
lines changed

2 files changed

+430
-46
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) {
731731
}
732732

733733
/// VP Intrinsics whose vector operands are both splat values may be simplified
734-
/// into the scalar version of the operation and the result is splatted. This
734+
/// into the scalar version of the operation and the result splatted. This
735735
/// can lead to scalarization down the line.
736736
bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
737737
if (!isa<VPIntrinsic>(I))
@@ -758,15 +758,8 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
758758
return false;
759759

760760
// Check to make sure we support scalarization of the intrinsic
761-
std::set<Intrinsic::ID> SupportedIntrinsics(
762-
{Intrinsic::vp_add, Intrinsic::vp_sub, Intrinsic::vp_mul,
763-
Intrinsic::vp_ashr, Intrinsic::vp_lshr, Intrinsic::vp_shl,
764-
Intrinsic::vp_or, Intrinsic::vp_and, Intrinsic::vp_xor,
765-
Intrinsic::vp_fadd, Intrinsic::vp_fsub, Intrinsic::vp_fmul,
766-
Intrinsic::vp_sdiv, Intrinsic::vp_udiv, Intrinsic::vp_srem,
767-
Intrinsic::vp_urem});
768761
Intrinsic::ID IntrID = VPI.getIntrinsicID();
769-
if (!SupportedIntrinsics.count(IntrID))
762+
if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
770763
return false;
771764

772765
// Calculate cost of splatting both operands into vectors and the vector
@@ -785,14 +778,39 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
785778
InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
786779
InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
787780

781+
// Determine scalar opcode
782+
std::optional<unsigned> FunctionalOpcode =
783+
VPI.getFunctionalOpcode();
784+
bool ScalarIsIntr = false;
785+
Intrinsic::ID ScalarIntrID;
786+
if (!FunctionalOpcode) {
787+
// Explicitly handle supported instructions (i.e. those that pass the
788+
// isVPBinOp check above) that have neither functional nor constrained
789+
// opcodes due to their intrinsic definitions.
790+
DenseMap<Intrinsic::ID, Intrinsic::ID> VPIntrToIntr(
791+
{{Intrinsic::vp_smax, Intrinsic::smax},
792+
{Intrinsic::vp_smin, Intrinsic::smin},
793+
{Intrinsic::vp_umax, Intrinsic::umax},
794+
{Intrinsic::vp_umin, Intrinsic::umin},
795+
{Intrinsic::vp_copysign, Intrinsic::copysign},
796+
{Intrinsic::vp_minnum, Intrinsic::minnum},
797+
{Intrinsic::vp_maxnum, Intrinsic::maxnum}});
798+
ScalarIsIntr = true;
799+
assert(VPIntrToIntr.contains(IntrID) &&
800+
"Unable to determine scalar opcode");
801+
ScalarIntrID = VPIntrToIntr[IntrID];
802+
}
803+
788804
// Calculate cost of scalarizing
789-
std::optional<unsigned> ScalarOpcodeOpt =
790-
VPIntrinsic::getFunctionalOpcodeForVP(IntrID);
791-
assert(ScalarOpcodeOpt && "Unable to determine scalar opcode");
792-
unsigned ScalarOpcode = *ScalarOpcodeOpt;
805+
InstructionCost ScalarOpCost = 0;
806+
if (ScalarIsIntr) {
807+
IntrinsicCostAttributes Attrs(ScalarIntrID, VecTy->getScalarType(), Args);
808+
ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
809+
} else {
810+
ScalarOpCost =
811+
TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
812+
}
793813

794-
InstructionCost ScalarOpCost =
795-
TTI.getArithmeticInstrCost(ScalarOpcode, VecTy->getScalarType());
796814
// The existing splats may be kept around if other instructions use them.
797815
InstructionCost CostToKeepSplats =
798816
(SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
@@ -814,13 +832,19 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
814832
bool IsKnownNonZeroVL = isKnownNonZero(EVL, DL, 0, &AC, &VPI, &DT);
815833
bool MustHaveNonZeroVL =
816834
IntrID == Intrinsic::vp_sdiv || IntrID == Intrinsic::vp_udiv ||
817-
IntrID == Intrinsic::vp_srem || IntrID == Intrinsic::vp_urem;
835+
IntrID == Intrinsic::vp_srem || IntrID == Intrinsic::vp_urem ||
836+
IntrID == Intrinsic::vp_fdiv || IntrID == Intrinsic::vp_frem;
818837

819838
if ((MustHaveNonZeroVL && IsKnownNonZeroVL) || !MustHaveNonZeroVL) {
820-
replaceValue(VPI, *Builder.CreateVectorSplat(
821-
EC, Builder.CreateBinOp(
822-
(Instruction::BinaryOps)ScalarOpcode,
823-
getSplatValue(Op0), getSplatValue(Op1))));
839+
Value *ScalarOp0 = getSplatValue(Op0);
840+
Value *ScalarOp1 = getSplatValue(Op1);
841+
Value *ScalarVal =
842+
ScalarIsIntr
843+
? Builder.CreateIntrinsic(VecTy->getScalarType(), ScalarIntrID,
844+
{ScalarOp0, ScalarOp1})
845+
: Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
846+
ScalarOp0, ScalarOp1);
847+
replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
824848
return true;
825849
}
826850
return false;

0 commit comments

Comments
 (0)