Skip to content

Commit fd66195

Browse files
committed
[VPlan] Manage compare predicates in VPRecipeWithIRFlags.
Extend VPRecipeWithIRFlags to also manage predicates for compares. This allows removing the custom ICmpULE opcode from VPInstruction which was a workaround for missing proper predicate handling. This simplifies the code a bit while also allowing compares with any predicates. It also fixes a case where the compare predixcate wasn't printed properly for VPReplicateRecipes. Discussed/split off from D150398. Reviewed By: Ayal Differential Revision: https://reviews.llvm.org/D158992
1 parent 64d16ef commit fd66195

File tree

5 files changed

+64
-17
lines changed

5 files changed

+64
-17
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,17 @@ class VPBuilder {
4545
VPBasicBlock *BB = nullptr;
4646
VPBasicBlock::iterator InsertPt = VPBasicBlock::iterator();
4747

48+
/// Insert \p VPI in BB at InsertPt if BB is set.
49+
VPInstruction *tryInsertInstruction(VPInstruction *VPI) {
50+
if (BB)
51+
BB->insert(VPI, InsertPt);
52+
return VPI;
53+
}
54+
4855
VPInstruction *createInstruction(unsigned Opcode,
4956
ArrayRef<VPValue *> Operands, DebugLoc DL,
5057
const Twine &Name = "") {
51-
VPInstruction *Instr = new VPInstruction(Opcode, Operands, DL, Name);
52-
if (BB)
53-
BB->insert(Instr, InsertPt);
54-
return Instr;
58+
return tryInsertInstruction(new VPInstruction(Opcode, Operands, DL, Name));
5559
}
5660

5761
VPInstruction *createInstruction(unsigned Opcode,
@@ -152,6 +156,12 @@ class VPBuilder {
152156
Name);
153157
}
154158

159+
/// Create a new ICmp VPInstruction with predicate \p Pred and operands \p A
160+
/// and \p B.
161+
/// TODO: add createFCmp when needed.
162+
VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
163+
DebugLoc DL = {}, const Twine &Name = "");
164+
155165
//===--------------------------------------------------------------------===//
156166
// RAII helpers.
157167
//===--------------------------------------------------------------------===//

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7372,6 +7372,14 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
73727372
}
73737373
}
73747374

7375+
VPValue *VPBuilder::createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
7376+
DebugLoc DL, const Twine &Name) {
7377+
assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE &&
7378+
Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate");
7379+
return tryInsertInstruction(
7380+
new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name));
7381+
}
7382+
73757383
// TODO: we could return a pair of values that specify the max VF and
73767384
// min VF, to be used in `buildVPlans(MinVF, MaxVF)` instead of
73777385
// `buildVPlans(VF, VF)`. We cannot do it because VPLAN at the moment
@@ -8079,7 +8087,7 @@ void VPRecipeBuilder::createHeaderMask(VPlan &Plan) {
80798087
nullptr, "active.lane.mask");
80808088
} else {
80818089
VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
8082-
BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC});
8090+
BlockMask = Builder.createICmp(CmpInst::ICMP_ULE, IV, BTC);
80838091
}
80848092
BlockMaskCache[Header] = BlockMask;
80858093
}

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
814814
/// Class to record LLVM IR flag for a recipe along with it.
815815
class VPRecipeWithIRFlags : public VPRecipeBase {
816816
enum class OperationType : unsigned char {
817+
Cmp,
817818
OverflowingBinOp,
818819
PossiblyExactOp,
819820
GEPOp,
@@ -851,11 +852,12 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
851852
OperationType OpType;
852853

853854
union {
855+
CmpInst::Predicate CmpPredicate;
854856
WrapFlagsTy WrapFlags;
855857
ExactFlagsTy ExactFlags;
856858
GEPFlagsTy GEPFlags;
857859
FastMathFlagsTy FMFs;
858-
unsigned char AllFlags;
860+
unsigned AllFlags;
859861
};
860862

861863
public:
@@ -869,7 +871,10 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
869871
template <typename IterT>
870872
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, Instruction &I)
871873
: VPRecipeWithIRFlags(SC, Operands) {
872-
if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
874+
if (auto *Op = dyn_cast<CmpInst>(&I)) {
875+
OpType = OperationType::Cmp;
876+
CmpPredicate = Op->getPredicate();
877+
} else if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
873878
OpType = OperationType::OverflowingBinOp;
874879
WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
875880
} else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
@@ -884,6 +889,12 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
884889
}
885890
}
886891

892+
template <typename IterT>
893+
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
894+
CmpInst::Predicate Pred)
895+
: VPRecipeBase(SC, Operands), OpType(OperationType::Cmp),
896+
CmpPredicate(Pred) {}
897+
887898
template <typename IterT>
888899
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
889900
WrapFlagsTy WrapFlags)
@@ -922,6 +933,7 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
922933
FMFs.NoNaNs = false;
923934
FMFs.NoInfs = false;
924935
break;
936+
case OperationType::Cmp:
925937
case OperationType::Other:
926938
break;
927939
}
@@ -949,11 +961,18 @@ class VPRecipeWithIRFlags : public VPRecipeBase {
949961
I->setHasAllowContract(FMFs.AllowContract);
950962
I->setHasApproxFunc(FMFs.ApproxFunc);
951963
break;
964+
case OperationType::Cmp:
952965
case OperationType::Other:
953966
break;
954967
}
955968
}
956969

970+
CmpInst::Predicate getPredicate() const {
971+
assert(OpType == OperationType::Cmp &&
972+
"recipe doesn't have a compare predicate");
973+
return CmpPredicate;
974+
}
975+
957976
bool isInBounds() const {
958977
assert(OpType == OperationType::GEPOp &&
959978
"recipe doesn't have inbounds flag");
@@ -996,7 +1015,6 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
9961015
Instruction::OtherOpsEnd + 1, // Combines the incoming and previous
9971016
// values of a first-order recurrence.
9981017
Not,
999-
ICmpULE,
10001018
SLPLoad,
10011019
SLPStore,
10021020
ActiveLaneMask,
@@ -1042,6 +1060,9 @@ class VPInstruction : public VPRecipeWithIRFlags, public VPValue {
10421060
DebugLoc DL = {}, const Twine &Name = "")
10431061
: VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
10441062

1063+
VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
1064+
VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
1065+
10451066
VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
10461067
WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "")
10471068
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags),

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ bool VPRecipeBase::mayHaveSideEffects() const {
116116
return false;
117117
case VPInstructionSC:
118118
switch (cast<VPInstruction>(this)->getOpcode()) {
119+
case Instruction::ICmp:
119120
case VPInstruction::Not:
120-
case VPInstruction::ICmpULE:
121121
case VPInstruction::CalculateTripCountMinusVF:
122122
case VPInstruction::CanonicalIVIncrement:
123123
case VPInstruction::CanonicalIVIncrementForPart:
@@ -246,6 +246,16 @@ FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
246246
return Res;
247247
}
248248

249+
VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
250+
VPValue *A, VPValue *B, DebugLoc DL,
251+
const Twine &Name)
252+
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}),
253+
Pred),
254+
VPValue(this), Opcode(Opcode), DL(DL), Name(Name.str()) {
255+
assert(Opcode == Instruction::ICmp &&
256+
"only ICmp predicates supported at the moment");
257+
}
258+
249259
VPInstruction::VPInstruction(unsigned Opcode,
250260
std::initializer_list<VPValue *> Operands,
251261
FastMathFlags FMFs, DebugLoc DL, const Twine &Name)
@@ -271,10 +281,10 @@ Value *VPInstruction::generateInstruction(VPTransformState &State,
271281
Value *A = State.get(getOperand(0), Part);
272282
return Builder.CreateNot(A, Name);
273283
}
274-
case VPInstruction::ICmpULE: {
284+
case Instruction::ICmp: {
275285
Value *A = State.get(getOperand(0), Part);
276286
Value *B = State.get(getOperand(1), Part);
277-
return Builder.CreateICmpULE(A, B, Name);
287+
return Builder.CreateCmp(getPredicate(), A, B, Name);
278288
}
279289
case Instruction::Select: {
280290
Value *Cond = State.get(getOperand(0), Part);
@@ -444,9 +454,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
444454
case VPInstruction::Not:
445455
O << "not";
446456
break;
447-
case VPInstruction::ICmpULE:
448-
O << "icmp ule";
449-
break;
450457
case VPInstruction::SLPLoad:
451458
O << "combined load";
452459
break;
@@ -618,6 +625,9 @@ VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
618625
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
619626
void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const {
620627
switch (OpType) {
628+
case OperationType::Cmp:
629+
O << " " << CmpInst::getPredicateName(getPredicate());
630+
break;
621631
case OperationType::PossiblyExactOp:
622632
if (ExactFlags.IsExact)
623633
O << " exact";
@@ -741,8 +751,6 @@ void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent,
741751
const Instruction *UI = getUnderlyingInstr();
742752
O << " = " << UI->getOpcodeName();
743753
printFlags(O);
744-
if (auto *Cmp = dyn_cast<CmpInst>(UI))
745-
O << Cmp->getPredicate() << " ";
746754
printOperands(O, SlotTracker);
747755
}
748756
#endif

llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,7 @@ define void @update_multiple_users(ptr noalias %src, ptr noalias %dst, i1 %c) {
906906
; CHECK-NEXT: pred.store.if:
907907
; CHECK-NEXT: REPLICATE ir<%l1> = load ir<%src>
908908
; CHECK-NEXT: REPLICATE ir<%l2> = trunc ir<%l1>
909-
; CHECK-NEXT: REPLICATE ir<%cmp> = icmp ir<%l1>, ir<0>
909+
; CHECK-NEXT: REPLICATE ir<%cmp> = icmp eq ir<%l1>, ir<0>
910910
; CHECK-NEXT: REPLICATE ir<%sel> = select ir<%cmp>, ir<5>, ir<%l2>
911911
; CHECK-NEXT: REPLICATE store ir<%sel>, ir<%dst>
912912
; CHECK-NEXT: Successor(s): pred.store.continue

0 commit comments

Comments
 (0)