[VPlan][LoopVectorize] Truncate min/max intrinsic ops #90643 (Closed)
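
The core property this patch relies on, shown as a minimal standalone C++ sketch rather than LLVM code (the values and types below are made up for illustration): when analysis proves that only the low bits of a umin/umax (or, with sign extension, smin/smax) are live, the intrinsic can be evaluated on truncated operands in a narrower type and the result extended back without changing the answer, which lets the vectorizer pick a smaller vector element width.

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // Both inputs are known to fit in 8 bits, as DemandedBits would establish.
  uint32_t A = 200, B = 57;
  uint32_t Wide = std::min(A, B); // umin computed at the original 32 bits
  // The same umin computed on truncated 8-bit operands.
  uint8_t Narrow = std::min(static_cast<uint8_t>(A), static_cast<uint8_t>(B));
  // Zero-extending the narrow result reproduces the wide result.
  assert(Wide == static_cast<uint32_t>(Narrow));
  return 0;
}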

32 changes: 32 additions & 0 deletions llvm/lib/Analysis/VectorUtils.cpp
@@ -611,6 +611,14 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
!InstructionSet.count(I))
continue;

// Byteswaps require at least 16 bits
if (const auto *II = dyn_cast<IntrinsicInst>(I)) {
if (II->getIntrinsicID() == Intrinsic::bswap) {
DBits[Leader] |= 0xFFFF;
DBits[I] |= 0xFFFF;
}
}

// Unsafe casts terminate a chain unsuccessfully. We can't do anything
// useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to
// transform anything that relies on them.
@@ -687,6 +695,30 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
isa<ShlOperator, LShrOperator, AShrOperator>(U.getUser()) &&
U.getOperandNo() == 1)
return CI->uge(MinBW);
// Ignore the call pointer when considering intrinsics that
// DemandedBits understands.
if (U->getType()->isPointerTy() && isa<CallInst>(U.getUser()) &&
dyn_cast<CallInst>(U.getUser())->getCalledFunction() ==
dyn_cast<Function>(U)) {
if (const auto *II = dyn_cast<IntrinsicInst>(U.getUser())) {
// Only ignore cases that DemandedBits understands.
switch (II->getIntrinsicID()) {
default:
break;
case Intrinsic::umax:
case Intrinsic::umin:
case Intrinsic::smax:
case Intrinsic::smin:
case Intrinsic::fshl:
case Intrinsic::fshr:
case Intrinsic::cttz:
case Intrinsic::ctlz:
case Intrinsic::bitreverse:
case Intrinsic::bswap:
return false;
}
}
}
uint64_t BW = bit_width(DB.getDemandedBits(&U).getZExtValue());
return bit_ceil(BW) > MinBW;
}))
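
A side note on the bswap clamp in the VectorUtils.cpp hunk above, again as a standalone sketch (it uses the GCC/Clang __builtin_bswap* builtins, not the patch's code): llvm.bswap rearranges whole bytes, so even when only a few result bits are demanded the operation needs at least two bytes to be meaningful, which is why the demanded-bits mask is widened to a 16-bit floor.

#include <cassert>
#include <cstdint>

int main() {
  // Byte-swapping moves high input bytes into the low result bytes, so a
  // demanded low result byte depends on input bits far above bit 8.
  uint32_t Swapped = __builtin_bswap32(0x0000ABCDu);
  assert(Swapped == 0xCDAB0000u);
  // 16 bits is the smallest width where a byte swap does anything; an 8-bit
  // bswap would be the identity and is not a legal llvm.bswap type.
  uint16_t Short = __builtin_bswap16(0xABCDu);
  assert(Short == 0xCDABu);
  return 0;
}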
6 changes: 3 additions & 3 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8284,7 +8284,7 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
Range);
if (ShouldUseVectorIntrinsic)
return new VPWidenCallRecipe(CI, make_range(Ops.begin(), Ops.end()), ID,
CI->getDebugLoc());
CI->getType(), CI->getDebugLoc());

Function *Variant = nullptr;
std::optional<unsigned> MaskPos;
@@ -8337,8 +8337,8 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI,
}

return new VPWidenCallRecipe(CI, make_range(Ops.begin(), Ops.end()),
Intrinsic::not_intrinsic, CI->getDebugLoc(),
Variant);
Intrinsic::not_intrinsic, CI->getType(),
CI->getDebugLoc(), Variant);
}

return nullptr;
17 changes: 13 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1455,14 +1455,17 @@ class VPWidenCallRecipe : public VPSingleDefRecipe {
/// chosen vectorized variant, so there will be a different vplan for each
/// VF with a valid variant.
Function *Variant;
/// Result type for the call.
Type *ResultTy;

public:
template <typename IterT>
VPWidenCallRecipe(Value *UV, iterator_range<IterT> CallArguments,
Intrinsic::ID VectorIntrinsicID, DebugLoc DL = {},
Function *Variant = nullptr)
Intrinsic::ID VectorIntrinsicID, Type *ResultTy,
DebugLoc DL = {}, Function *Variant = nullptr)
: VPSingleDefRecipe(VPDef::VPWidenCallSC, CallArguments, UV, DL),
VectorIntrinsicID(VectorIntrinsicID), Variant(Variant) {
VectorIntrinsicID(VectorIntrinsicID), Variant(Variant),
ResultTy(ResultTy) {
assert(
isa<Function>(getOperand(getNumOperands() - 1)->getLiveInIRValue()) &&
"last operand must be the called function");
@@ -1472,7 +1475,8 @@ class VPWidenCallRecipe : public VPSingleDefRecipe {

VPWidenCallRecipe *clone() override {
return new VPWidenCallRecipe(getUnderlyingValue(), operands(),
VectorIntrinsicID, getDebugLoc(), Variant);
VectorIntrinsicID, ResultTy, getDebugLoc(),
Variant);
}

VP_CLASSOF_IMPL(VPDef::VPWidenCallSC)
@@ -1496,6 +1500,11 @@ class VPWidenCallRecipe : public VPSingleDefRecipe {
void print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const override;
#endif

/// Returns the result type of the call.
Type *getResultType() const { return ResultTy; }

void setResultType(Type *NewResTy) { ResultTy = NewResTy; }
};

/// A recipe for widening select instructions.
9 changes: 3 additions & 6 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -110,11 +110,6 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
llvm_unreachable("Unhandled opcode!");
}

Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) {
auto &CI = *cast<CallInst>(R->getUnderlyingInstr());
return CI.getType();
}

Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R) {
assert((isa<VPWidenLoadRecipe>(R) || isa<VPWidenLoadEVLRecipe>(R)) &&
"Store recipes should not define any values");
@@ -238,7 +233,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
return inferScalarType(R->getOperand(0));
})
.Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>(
VPWidenMemoryRecipe, VPWidenSelectRecipe>(
[this](const auto *R) { return inferScalarTypeForRecipe(R); })
.Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
// TODO: Use info from interleave group.
@@ -248,6 +243,8 @@
[](const VPWidenCastRecipe *R) { return R->getResultType(); })
.Case<VPScalarCastRecipe>(
[](const VPScalarCastRecipe *R) { return R->getResultType(); })
.Case<VPWidenCallRecipe>(
[](const VPWidenCallRecipe *R) { return R->getResultType(); })
.Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
return R->getSCEV()->getType();
});
1 change: 0 additions & 1 deletion llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -43,7 +43,6 @@ class VPTypeAnalysis {

Type *inferScalarTypeForRecipe(const VPBlendRecipe *R);
Type *inferScalarTypeForRecipe(const VPInstruction *R);
Type *inferScalarTypeForRecipe(const VPWidenCallRecipe *R);
Type *inferScalarTypeForRecipe(const VPWidenRecipe *R);
Type *inferScalarTypeForRecipe(const VPWidenIntOrFpInductionRecipe *R);
Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
8 changes: 4 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -723,8 +723,8 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
// Add return type if intrinsic is overloaded on it.
if (UseIntrinsic &&
isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1))
TysForDecl.push_back(VectorType::get(
CalledScalarFn->getReturnType()->getScalarType(), State.VF));
TysForDecl.push_back(
VectorType::get(getResultType()->getScalarType(), State.VF));
SmallVector<Value *, 4> Args;
for (const auto &I : enumerate(arg_operands())) {
// Some intrinsics have a scalar argument - don't replace it with a
@@ -780,14 +780,14 @@ void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << "WIDEN-CALL ";

Function *CalledFn = getCalledScalarFunction();
if (CalledFn->getReturnType()->isVoidTy())
if (getResultType()->isVoidTy())
O << "void ";
else {
printAsOperand(O, SlotTracker);
O << " = ";
}

Function *CalledFn = getCalledScalarFunction();
O << "call @" << CalledFn->getName() << "(";
interleaveComma(arg_operands(), O, [&O, &SlotTracker](VPValue *Op) {
Op->printAsOperand(O, SlotTracker);
16 changes: 12 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -76,7 +76,7 @@ void VPlanTransforms::VPInstructionsToVPRecipes(
} else if (CallInst *CI = dyn_cast<CallInst>(Inst)) {
NewRecipe = new VPWidenCallRecipe(
CI, Ingredient.operands(), getVectorIntrinsicIDForCall(CI, &TLI),
CI->getDebugLoc());
CI->getType(), CI->getDebugLoc());
} else if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) {
NewRecipe = new VPWidenSelectRecipe(*SI, Ingredient.operands());
} else if (auto *CI = dyn_cast<CastInst>(Inst)) {
@@ -971,8 +971,8 @@ void VPlanTransforms::truncateToMinimalBitwidths(
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
vp_depth_first_deep(Plan.getVectorLoopRegion()))) {
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
if (!isa<VPWidenRecipe, VPWidenCastRecipe, VPReplicateRecipe,
VPWidenSelectRecipe, VPWidenLoadRecipe>(&R))
if (!isa<VPWidenRecipe, VPWidenCallRecipe, VPWidenCastRecipe,
VPReplicateRecipe, VPWidenSelectRecipe, VPWidenLoadRecipe>(&R))
continue;

VPValue *ResultVPV = R.getVPSingleValue();
@@ -1049,7 +1049,9 @@

// Shrink operands by introducing truncates as needed.
unsigned StartIdx = isa<VPWidenSelectRecipe>(&R) ? 1 : 0;
for (unsigned Idx = StartIdx; Idx != R.getNumOperands(); ++Idx) {
unsigned EndIdx =
R.getNumOperands() - (isa<VPWidenCallRecipe>(&R) ? 1 : 0);
for (unsigned Idx = StartIdx; Idx != EndIdx; ++Idx) {
auto *Op = R.getOperand(Idx);
unsigned OpSizeInBits =
TypeInfo.inferScalarType(Op)->getScalarSizeInBits();
@@ -1078,6 +1080,12 @@
}
}

// If this was a WIDEN-CALL (intrinsic) then we need to update the return
// type so it's compatible with the new args.
if (auto *WidenCallR = dyn_cast<VPWidenCallRecipe>(&R))
  WidenCallR->setResultType(NewResTy);
}
}
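
For the truncateToMinimalBitwidths changes just above, a small standalone sketch of the operand-range logic (FakeCallRecipe and its fields are hypothetical stand-ins, not LLVM API): a widened call carries the called function as its last operand, so the narrowing loop stops one operand early, and the recipe's stored result type is then updated to match the truncated operands.

#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a widened call recipe: data operands followed by
// the callee as the final operand, plus an explicitly stored result type.
struct FakeCallRecipe {
  std::vector<std::string> Operands{"a", "b", "llvm.umin"};
  std::string ResultTy = "i32";
};

int main() {
  FakeCallRecipe R;
  bool IsWidenCall = true;
  // Skip the trailing callee operand, mirroring the EndIdx computation above.
  std::size_t EndIdx = R.Operands.size() - (IsWidenCall ? 1 : 0);
  for (std::size_t Idx = 0; Idx != EndIdx; ++Idx)
    R.Operands[Idx] = "trunc(" + R.Operands[Idx] + ")"; // only data operands
  R.ResultTy = "i16"; // keep the recipe's result type in sync with operands
  return 0;
}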
