Skip to content

Commit c0506a1

Browse files
authored
[VPlan] Separate out logic to manage IR flags to VPIRFlags (NFC). (#140621)
This patch moves the logic to manage IR flags to a separate VPIRFlags class. For now, VPRecipeWithIRFlags is the only class that inherits VPIRFlags. The new class allows for simpler passing of flags when constructing recipes, simplifying the constructors for various recipes (VPInstruction in particular, which now just has 2 constructors, one taking an extra VPIRFlags argument. This mirrors the approach taken for VPIRMetadata and makes it easier to extend in the future. The patch also adds a unified flagsValidForOpcode to check if the flags in a VPIRFlags match the provided opcode. PR: #140621
1 parent 9a440f8 commit c0506a1

File tree

6 files changed

+171
-205
lines changed

6 files changed

+171
-205
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -164,25 +164,19 @@ class VPBuilder {
164164
DebugLoc DL, const Twine &Name = "") {
165165
return createInstruction(Opcode, Operands, DL, Name);
166166
}
167-
VPInstruction *createNaryOp(unsigned Opcode,
168-
std::initializer_list<VPValue *> Operands,
169-
std::optional<FastMathFlags> FMFs = {},
170-
DebugLoc DL = {}, const Twine &Name = "") {
171-
if (FMFs)
172-
return tryInsertInstruction(
173-
new VPInstruction(Opcode, Operands, *FMFs, DL, Name));
174-
return createInstruction(Opcode, Operands, DL, Name);
167+
VPInstruction *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands,
168+
const VPIRFlags &Flags, DebugLoc DL = {},
169+
const Twine &Name = "") {
170+
return tryInsertInstruction(
171+
new VPInstruction(Opcode, Operands, Flags, DL, Name));
175172
}
173+
176174
VPInstruction *createNaryOp(unsigned Opcode,
177175
std::initializer_list<VPValue *> Operands,
178-
Type *ResultTy,
179-
std::optional<FastMathFlags> FMFs = {},
176+
Type *ResultTy, const VPIRFlags &Flags = {},
180177
DebugLoc DL = {}, const Twine &Name = "") {
181-
if (FMFs)
182-
return tryInsertInstruction(new VPInstructionWithType(
183-
Opcode, Operands, ResultTy, *FMFs, DL, Name));
184178
return tryInsertInstruction(
185-
new VPInstructionWithType(Opcode, Operands, ResultTy, DL, Name));
179+
new VPInstructionWithType(Opcode, Operands, ResultTy, Flags, DL, Name));
186180
}
187181

188182
VPInstruction *createOverflowingOp(unsigned Opcode,
@@ -236,18 +230,20 @@ class VPBuilder {
236230
assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE &&
237231
Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate");
238232
return tryInsertInstruction(
239-
new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name));
233+
new VPInstruction(Instruction::ICmp, {A, B}, Pred, DL, Name));
240234
}
241235

242236
VPInstruction *createPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
243237
const Twine &Name = "") {
244238
return tryInsertInstruction(
245-
new VPInstruction(Ptr, Offset, GEPNoWrapFlags::none(), DL, Name));
239+
new VPInstruction(VPInstruction::PtrAdd, {Ptr, Offset},
240+
GEPNoWrapFlags::none(), DL, Name));
246241
}
247242
VPValue *createInBoundsPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
248243
const Twine &Name = "") {
249244
return tryInsertInstruction(
250-
new VPInstruction(Ptr, Offset, GEPNoWrapFlags::inBounds(), DL, Name));
245+
new VPInstruction(VPInstruction::PtrAdd, {Ptr, Offset},
246+
GEPNoWrapFlags::inBounds(), DL, Name));
251247
}
252248

253249
VPInstruction *createScalarPhi(ArrayRef<VPValue *> IncomingValues,
@@ -269,7 +265,7 @@ class VPBuilder {
269265
VPInstruction *createScalarCast(Instruction::CastOps Opcode, VPValue *Op,
270266
Type *ResultTy, DebugLoc DL) {
271267
return tryInsertInstruction(
272-
new VPInstructionWithType(Opcode, Op, ResultTy, DL));
268+
new VPInstructionWithType(Opcode, Op, ResultTy, {}, DL));
273269
}
274270

275271
VPWidenCastRecipe *createWidenCast(Instruction::CastOps Opcode, VPValue *Op,

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 83 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -577,8 +577,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
577577
#endif
578578
};
579579

580-
/// Class to record LLVM IR flag for a recipe along with it.
581-
class VPRecipeWithIRFlags : public VPSingleDefRecipe {
580+
/// Class to record and manage LLVM IR flags.
581+
class VPIRFlags {
582582
enum class OperationType : unsigned char {
583583
Cmp,
584584
OverflowingBinOp,
@@ -637,23 +637,10 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
637637
unsigned AllFlags;
638638
};
639639

640-
protected:
641-
void transferFlags(VPRecipeWithIRFlags &Other) {
642-
OpType = Other.OpType;
643-
AllFlags = Other.AllFlags;
644-
}
645-
646640
public:
647-
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
648-
DebugLoc DL = {})
649-
: VPSingleDefRecipe(SC, Operands, DL) {
650-
OpType = OperationType::Other;
651-
AllFlags = 0;
652-
}
641+
VPIRFlags() : OpType(OperationType::Other), AllFlags(0) {}
653642

654-
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
655-
Instruction &I)
656-
: VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()) {
643+
VPIRFlags(Instruction &I) {
657644
if (auto *Op = dyn_cast<CmpInst>(&I)) {
658645
OpType = OperationType::Cmp;
659646
CmpPredicate = Op->getPredicate();
@@ -681,63 +668,27 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
681668
}
682669
}
683670

684-
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
685-
CmpInst::Predicate Pred, DebugLoc DL = {})
686-
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::Cmp),
687-
CmpPredicate(Pred) {}
671+
VPIRFlags(CmpInst::Predicate Pred)
672+
: OpType(OperationType::Cmp), CmpPredicate(Pred) {}
688673

689-
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
690-
WrapFlagsTy WrapFlags, DebugLoc DL = {})
691-
: VPSingleDefRecipe(SC, Operands, DL),
692-
OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
674+
VPIRFlags(WrapFlagsTy WrapFlags)
675+
: OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
693676

694-
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
695-
FastMathFlags FMFs, DebugLoc DL = {})
696-
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::FPMathOp),
697-
FMFs(FMFs) {}
677+
VPIRFlags(FastMathFlags FMFs) : OpType(OperationType::FPMathOp), FMFs(FMFs) {}
698678

699-
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
700-
DisjointFlagsTy DisjointFlags, DebugLoc DL = {})
701-
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
702-
DisjointFlags(DisjointFlags) {}
679+
VPIRFlags(DisjointFlagsTy DisjointFlags)
680+
: OpType(OperationType::DisjointOp), DisjointFlags(DisjointFlags) {}
703681

704-
template <typename IterT>
705-
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
706-
NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
707-
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
708-
NonNegFlags(NonNegFlags) {}
682+
VPIRFlags(NonNegFlagsTy NonNegFlags)
683+
: OpType(OperationType::NonNegOp), NonNegFlags(NonNegFlags) {}
709684

710-
protected:
711-
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
712-
GEPNoWrapFlags GEPFlags, DebugLoc DL = {})
713-
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::GEPOp),
714-
GEPFlags(GEPFlags) {}
685+
VPIRFlags(GEPNoWrapFlags GEPFlags)
686+
: OpType(OperationType::GEPOp), GEPFlags(GEPFlags) {}
715687

716688
public:
717-
static inline bool classof(const VPRecipeBase *R) {
718-
return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
719-
R->getVPDefID() == VPRecipeBase::VPWidenSC ||
720-
R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
721-
R->getVPDefID() == VPRecipeBase::VPWidenCallSC ||
722-
R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
723-
R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
724-
R->getVPDefID() == VPRecipeBase::VPReductionSC ||
725-
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
726-
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
727-
R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
728-
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
729-
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
730-
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
731-
}
732-
733-
static inline bool classof(const VPUser *U) {
734-
auto *R = dyn_cast<VPRecipeBase>(U);
735-
return R && classof(R);
736-
}
737-
738-
static inline bool classof(const VPValue *V) {
739-
auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe());
740-
return R && classof(R);
689+
void transferFlags(VPIRFlags &Other) {
690+
OpType = Other.OpType;
691+
AllFlags = Other.AllFlags;
741692
}
742693

743694
/// Drop all poison-generating flags.
@@ -851,11 +802,60 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
851802
return DisjointFlags.IsDisjoint;
852803
}
853804

805+
#if !defined(NDEBUG)
806+
/// Returns true if the set flags are valid for \p Opcode.
807+
bool flagsValidForOpcode(unsigned Opcode) const;
808+
#endif
809+
854810
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
855811
void printFlags(raw_ostream &O) const;
856812
#endif
857813
};
858814

815+
/// A pure-virtual common base class for recipes defining a single VPValue and
816+
/// using IR flags.
817+
struct VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags {
818+
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
819+
DebugLoc DL = {})
820+
: VPSingleDefRecipe(SC, Operands, DL), VPIRFlags() {}
821+
822+
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
823+
Instruction &I)
824+
: VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()), VPIRFlags(I) {}
825+
826+
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
827+
const VPIRFlags &Flags, DebugLoc DL = {})
828+
: VPSingleDefRecipe(SC, Operands, DL), VPIRFlags(Flags) {}
829+
830+
static inline bool classof(const VPRecipeBase *R) {
831+
return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
832+
R->getVPDefID() == VPRecipeBase::VPWidenSC ||
833+
R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
834+
R->getVPDefID() == VPRecipeBase::VPWidenCallSC ||
835+
R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
836+
R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
837+
R->getVPDefID() == VPRecipeBase::VPReductionSC ||
838+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
839+
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
840+
R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
841+
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
842+
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
843+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
844+
}
845+
846+
static inline bool classof(const VPUser *U) {
847+
auto *R = dyn_cast<VPRecipeBase>(U);
848+
return R && classof(R);
849+
}
850+
851+
static inline bool classof(const VPValue *V) {
852+
auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe());
853+
return R && classof(R);
854+
}
855+
856+
void execute(VPTransformState &State) override = 0;
857+
};
858+
859859
/// Helper to access the operand that contains the unroll part for this recipe
860860
/// after unrolling.
861861
template <unsigned PartOpIdx> class VPUnrollPartAccessor {
@@ -958,54 +958,21 @@ class VPInstruction : public VPRecipeWithIRFlags,
958958
/// value for lane \p Lane.
959959
Value *generatePerLane(VPTransformState &State, const VPLane &Lane);
960960

961-
#if !defined(NDEBUG)
962-
/// Return true if the VPInstruction is a floating point math operation, i.e.
963-
/// has fast-math flags.
964-
bool isFPMathOp() const;
965-
#endif
966-
967961
public:
968-
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL,
962+
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {},
969963
const Twine &Name = "")
970964
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
971965
Opcode(Opcode), Name(Name.str()) {}
972966

973-
VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
974-
DebugLoc DL = {}, const Twine &Name = "")
975-
: VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
976-
977-
VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
978-
VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
979-
980-
VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
981-
WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "")
982-
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags, DL),
983-
Opcode(Opcode), Name(Name.str()) {}
984-
985-
VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
986-
DisjointFlagsTy DisjointFlag, DebugLoc DL = {},
987-
const Twine &Name = "")
988-
: VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DisjointFlag, DL),
989-
Opcode(Opcode), Name(Name.str()) {
990-
assert(Opcode == Instruction::Or && "only OR opcodes can be disjoint");
991-
}
992-
993-
VPInstruction(VPValue *Ptr, VPValue *Offset, GEPNoWrapFlags Flags,
994-
DebugLoc DL = {}, const Twine &Name = "")
995-
: VPRecipeWithIRFlags(VPDef::VPInstructionSC,
996-
ArrayRef<VPValue *>({Ptr, Offset}), Flags, DL),
997-
Opcode(VPInstruction::PtrAdd), Name(Name.str()) {}
998-
999-
VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
1000-
FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = "");
967+
VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
968+
const VPIRFlags &Flags, DebugLoc DL = {},
969+
const Twine &Name = "");
1001970

1002971
VP_CLASSOF_IMPL(VPDef::VPInstructionSC)
1003972

1004973
VPInstruction *clone() override {
1005974
SmallVector<VPValue *, 2> Operands(operands());
1006-
auto *New = new VPInstruction(Opcode, Operands, getDebugLoc(), Name);
1007-
New->transferFlags(*this);
1008-
return New;
975+
return new VPInstruction(Opcode, Operands, *this, getDebugLoc(), Name);
1009976
}
1010977

1011978
unsigned getOpcode() const { return Opcode; }
@@ -1082,13 +1049,9 @@ class VPInstructionWithType : public VPInstruction {
10821049

10831050
public:
10841051
VPInstructionWithType(unsigned Opcode, ArrayRef<VPValue *> Operands,
1085-
Type *ResultTy, DebugLoc DL, const Twine &Name = "")
1086-
: VPInstruction(Opcode, Operands, DL, Name), ResultTy(ResultTy) {}
1087-
VPInstructionWithType(unsigned Opcode,
1088-
std::initializer_list<VPValue *> Operands,
1089-
Type *ResultTy, FastMathFlags FMFs, DebugLoc DL = {},
1052+
Type *ResultTy, const VPIRFlags &Flags, DebugLoc DL,
10901053
const Twine &Name = "")
1091-
: VPInstruction(Opcode, Operands, FMFs, DL, Name), ResultTy(ResultTy) {}
1054+
: VPInstruction(Opcode, Operands, Flags, DL, Name), ResultTy(ResultTy) {}
10921055

10931056
static inline bool classof(const VPRecipeBase *R) {
10941057
// VPInstructionWithType are VPInstructions with specific opcodes requiring
@@ -1113,8 +1076,9 @@ class VPInstructionWithType : public VPInstruction {
11131076

11141077
VPInstruction *clone() override {
11151078
SmallVector<VPValue *, 2> Operands(operands());
1116-
auto *New = new VPInstructionWithType(
1117-
getOpcode(), Operands, getResultType(), getDebugLoc(), getName());
1079+
auto *New =
1080+
new VPInstructionWithType(getOpcode(), Operands, getResultType(), *this,
1081+
getDebugLoc(), getName());
11181082
New->setUnderlyingValue(getUnderlyingValue());
11191083
return New;
11201084
}
@@ -1373,15 +1337,12 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
13731337
}
13741338

13751339
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1376-
DebugLoc DL = {})
1377-
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
1378-
Opcode(Opcode), ResultTy(ResultTy) {}
1379-
1380-
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1381-
bool IsNonNeg, DebugLoc DL = {})
1382-
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
1383-
DL),
1384-
Opcode(Opcode), ResultTy(ResultTy) {}
1340+
const VPIRFlags &Flags = {}, DebugLoc DL = {})
1341+
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, Flags, DL),
1342+
VPIRMetadata(), Opcode(Opcode), ResultTy(ResultTy) {
1343+
assert(flagsValidForOpcode(Opcode) &&
1344+
"Set flags not supported for the provided opcode");
1345+
}
13851346

13861347
~VPWidenCastRecipe() override = default;
13871348

0 commit comments

Comments
 (0)