Skip to content

[VectorCombine][RISCV] Convert VPIntrinsics with splat operands to splats #65706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
14 commits
Select commit Hold shift + click to select a range
032cf54
[InstCombine][RISCV] Convert VPIntrinsics with splat operands to splats
michaelmaitland Sep 8, 2023
618349b
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 8, 2023
b843089
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 8, 2023
52fb71e
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 8, 2023
adab8fa
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 8, 2023
6efc815
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 8, 2023
8c23455
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 8, 2023
2cda6cb
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 11, 2023
aefb961
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 12, 2023
8b30ae5
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 12, 2023
6abb0d3
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 12, 2023
ca48343
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 13, 2023
2b96e3f
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 13, 2023
5cc6e53
fixup! [InstCombine][RISCV] Convert VPIntrinsics with splat operands …
michaelmaitland Sep 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class VectorCombine {
bool foldInsExtFNeg(Instruction &I);
bool foldBitcastShuf(Instruction &I);
bool scalarizeBinopOrCmp(Instruction &I);
bool scalarizeVPIntrinsic(Instruction &I);
bool foldExtractedCmps(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
Expand Down Expand Up @@ -729,6 +730,111 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) {
return true;
}

/// VP Intrinsics whose vector operands are both splat values may be simplified
/// into the scalar version of the operation and the result splatted. This
/// can lead to scalarization down the line.
bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
  if (!isa<VPIntrinsic>(I))
    return false;
  VPIntrinsic &VPI = cast<VPIntrinsic>(I);
  Value *Op0 = VPI.getArgOperand(0);
  Value *Op1 = VPI.getArgOperand(1);

  // Both vector operands must be splats, otherwise there is no scalar
  // operation that is equivalent to the vector one.
  if (!isSplatValue(Op0) || !isSplatValue(Op1))
    return false;

  // For the binary VP intrinsics supported here, the result on disabled lanes
  // is a poison value. For now, only do this simplification if all lanes
  // are active.
  // TODO: Relax the condition that all lanes are active by using insertelement
  // on inactive lanes.
  auto IsAllTrueMask = [](Value *MaskVal) {
    if (Value *SplattedVal = getSplatValue(MaskVal))
      if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
        return ConstValue->isAllOnesValue();
    return false;
  };
  if (!IsAllTrueMask(VPI.getArgOperand(2)))
    return false;

  // Check to make sure we support scalarization of the intrinsic
  Intrinsic::ID IntrID = VPI.getIntrinsicID();
  if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
    return false;

  // Calculate cost of splatting both operands into vectors and the vector
  // intrinsic
  VectorType *VecTy = cast<VectorType>(VPI.getType());
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  InstructionCost SplatCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
      TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy);

  // Calculate the cost of the VP Intrinsic
  SmallVector<Type *, 4> Args;
  for (Value *V : VPI.args())
    Args.push_back(V->getType());
  IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
  InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
  InstructionCost OldCost = 2 * SplatCost + VectorOpCost;

  // Determine scalar opcode. Prefer the LLVM IR opcode if one exists;
  // otherwise fall back to a scalar intrinsic (e.g. for FP operations that
  // only have an intrinsic form).
  std::optional<unsigned> FunctionalOpcode = VPI.getFunctionalOpcode();
  std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
  if (!FunctionalOpcode) {
    ScalarIntrID = VPI.getFunctionalIntrinsicID();
    if (!ScalarIntrID)
      return false;
  }

  // Calculate cost of scalarizing
  InstructionCost ScalarOpCost = 0;
  if (ScalarIntrID) {
    IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
    ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
  } else {
    ScalarOpCost =
        TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
  }

  // The existing splats may be kept around if other instructions use them.
  InstructionCost CostToKeepSplats =
      (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
  InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;

  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
                    << "\n");
  LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
                    << ", Cost of scalarizing:" << NewCost << "\n");

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Scalarize the intrinsic. Division/remainder only trap when actually
  // executed, and a zero EVL means no lane executes, so scalarizing those
  // requires proving the EVL is non-zero.
  ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
  Value *EVL = VPI.getArgOperand(3);
  const DataLayout &DL = VPI.getModule()->getDataLayout();
  bool MustHaveNonZeroVL =
      IntrID == Intrinsic::vp_sdiv || IntrID == Intrinsic::vp_udiv ||
      IntrID == Intrinsic::vp_srem || IntrID == Intrinsic::vp_urem;

  if (!MustHaveNonZeroVL || isKnownNonZero(EVL, DL, 0, &AC, &VPI, &DT)) {
    Value *ScalarOp0 = getSplatValue(Op0);
    Value *ScalarOp1 = getSplatValue(Op1);
    Value *ScalarVal =
        ScalarIntrID
            ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
                                      {ScalarOp0, ScalarOp1})
            : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
                                  ScalarOp0, ScalarOp1);
    replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
    return true;
  }
  return false;
}

/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
Expand Down Expand Up @@ -1737,6 +1843,7 @@ bool VectorCombine::run() {
if (isa<VectorType>(I.getType())) {
MadeChange |= scalarizeBinopOrCmp(I);
MadeChange |= scalarizeLoadExtract(I);
MadeChange |= scalarizeVPIntrinsic(I);
}

if (Opcode == Instruction::Store)
Expand Down
Loading