-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[VectorCombine][RISCV] Convert VPIntrinsics with splat operands to splats #65706
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
032cf54
618349b
b843089
52fb71e
adab8fa
6efc815
8c23455
2cda6cb
aefb961
8b30ae5
6abb0d3
ca48343
2b96e3f
5cc6e53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,6 +102,7 @@ class VectorCombine { | |
bool foldInsExtFNeg(Instruction &I); | ||
bool foldBitcastShuf(Instruction &I); | ||
bool scalarizeBinopOrCmp(Instruction &I); | ||
bool scalarizeVPIntrinsic(Instruction &I); | ||
bool foldExtractedCmps(Instruction &I); | ||
bool foldSingleElementStore(Instruction &I); | ||
bool scalarizeLoadExtract(Instruction &I); | ||
|
@@ -729,6 +730,111 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) { | |
return true; | ||
} | ||
|
||
/// VP Intrinsics whose vector operands are both splat values may be simplified
/// into the scalar version of the operation and the result splatted. This
/// can lead to scalarization down the line.
bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
  // Only VP intrinsics are candidates for this transform.
  auto *VPI = dyn_cast<VPIntrinsic>(&I);
  if (!VPI)
    return false;

  // Both vector operands must be splats for the scalar form to be equivalent.
  Value *LHS = VPI->getArgOperand(0);
  Value *RHS = VPI->getArgOperand(1);
  if (!isSplatValue(LHS) || !isSplatValue(RHS))
    return false;

  // For the binary VP intrinsics supported here, the result on disabled lanes
  // is a poison value. For now, only do this simplification if all lanes
  // are active.
  // TODO: Relax the condition that all lanes are active by using insertelement
  // on inactive lanes.
  auto IsAllTrueMask = [](Value *MaskVal) {
    if (Value *SplattedVal = getSplatValue(MaskVal))
      if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
        return ConstValue->isAllOnesValue();
    return false;
  };
  if (!IsAllTrueMask(VPI->getArgOperand(2)))
    return false;

  // Check to make sure we support scalarization of the intrinsic: only binary
  // VP operations are handled.
  Intrinsic::ID IntrID = VPI->getIntrinsicID();
  if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
    return false;

  // Calculate cost of splatting both operands into vectors and the vector
  // intrinsic itself. A splat is modeled as an insertelement plus a broadcast
  // shuffle.
  auto *VecTy = cast<VectorType>(VPI->getType());
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  InstructionCost SplatCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
      TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy);

  // Calculate the cost of the VP Intrinsic using its argument types.
  SmallVector<Type *, 4> ArgTys;
  for (Value *Arg : VPI->args())
    ArgTys.push_back(Arg->getType());
  IntrinsicCostAttributes VecAttrs(IntrID, VecTy, ArgTys);
  InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(VecAttrs, CostKind);
  InstructionCost OldCost = 2 * SplatCost + VectorOpCost;

  // Determine scalar opcode. If there is no functional IR opcode, fall back
  // to a functional (non-VP) intrinsic; bail out if neither exists.
  std::optional<unsigned> FunctionalOpcode = VPI->getFunctionalOpcode();
  std::optional<Intrinsic::ID> ScalarIntrID;
  if (!FunctionalOpcode) {
    ScalarIntrID = VPI->getFunctionalIntrinsicID();
    if (!ScalarIntrID)
      return false;
  }

  // Calculate cost of scalarizing, using the matching scalar intrinsic or
  // scalar arithmetic instruction.
  InstructionCost ScalarOpCost = 0;
  if (ScalarIntrID) {
    IntrinsicCostAttributes ScalarAttrs(*ScalarIntrID, VecTy->getScalarType(),
                                        ArgTys);
    ScalarOpCost = TTI.getIntrinsicInstrCost(ScalarAttrs, CostKind);
  } else {
    ScalarOpCost =
        TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
  }

  // The existing splats may be kept around if other instructions use them.
  InstructionCost CostToKeepSplats =
      (SplatCost * !LHS->hasOneUse()) + (SplatCost * !RHS->hasOneUse());
  InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;

  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << *VPI << "\n");
  LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
                    << ", Cost of scalarizing:" << NewCost << "\n");

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Division/remainder is immediate UB on the scalar path if a zero VL would
  // have disabled every lane; only proceed when the VL is known non-zero.
  Value *EVL = VPI->getArgOperand(3);
  const DataLayout &DL = VPI->getModule()->getDataLayout();
  bool MustHaveNonZeroVL =
      IntrID == Intrinsic::vp_sdiv || IntrID == Intrinsic::vp_udiv ||
      IntrID == Intrinsic::vp_srem || IntrID == Intrinsic::vp_urem;
  if (MustHaveNonZeroVL && !isKnownNonZero(EVL, DL, 0, &AC, VPI, &DT))
    return false;

  // Scalarize the intrinsic: perform the operation on the splatted scalars
  // and broadcast the scalar result back to a vector.
  ElementCount EC = cast<VectorType>(LHS->getType())->getElementCount();
  Value *ScalarLHS = getSplatValue(LHS);
  Value *ScalarRHS = getSplatValue(RHS);
  Value *ScalarVal = nullptr;
  if (ScalarIntrID)
    ScalarVal = Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
                                        {ScalarLHS, ScalarRHS});
  else
    ScalarVal = Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
                                    ScalarLHS, ScalarRHS);
  replaceValue(*VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
  return true;
}
|
||
/// Match a vector binop or compare instruction with at least one inserted | ||
/// scalar operand and convert to scalar binop/cmp followed by insertelement. | ||
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) { | ||
|
@@ -1737,6 +1843,7 @@ bool VectorCombine::run() { | |
if (isa<VectorType>(I.getType())) { | ||
MadeChange |= scalarizeBinopOrCmp(I); | ||
MadeChange |= scalarizeLoadExtract(I); | ||
MadeChange |= scalarizeVPIntrinsic(I); | ||
} | ||
|
||
if (Opcode == Instruction::Store) | ||
|
Uh oh!
There was an error while loading. Please reload this page.