Skip to content

Commit f391de8

Browse files
committed
[TableGen] Store predicates in PatternToMatch as ListInit *. Add string for HwModeFeatures
This uses to be how predicates were handled prior to HwMode being added. When the Predicates were converted to a std::vector it significantly increased the cost of a compare in GenerateVariants. Since ListInit's are uniquified by tablegen, we can use a simple pointer comparison to check for identical lists. In order to store the HwMode, we now add a separate string to PatternToMatch. This will be appended separately to the predicate string in getPredicateCheck. A new getPredicateRecords is added to allow GlobalISel and getPredicateCheck to both get the sorted list of Records. GlobalISel was ignoring any HwMode predicates before and still is. There is one slight change here, ListInits with different predicate orders aren't sorted so the filtering in GenerateVariants might fail to detect two isomorphic patterns with different predicate orders. This doesn't seem to be happening in tree today. My hope is this will allow us to remove all the BitVector tracking in GenerateVariants that was making up for predicates beeing expensive to compare. There's a decent amount of heap allocations there on large targets like X86, AMDGPU, and RISCV. Differential Revision: https://reviews.llvm.org/D100691
1 parent 01b0980 commit f391de8

File tree

3 files changed

+79
-114
lines changed

3 files changed

+79
-114
lines changed

llvm/utils/TableGen/CodeGenDAGPatterns.cpp

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,24 +1440,50 @@ getPatternComplexity(const CodeGenDAGPatterns &CGP) const {
14401440
return getPatternSize(getSrcPattern(), CGP) + getAddedComplexity();
14411441
}
14421442

1443+
void PatternToMatch::getPredicateRecords(
1444+
SmallVectorImpl<Record *> &PredicateRecs) const {
1445+
for (Init *I : Predicates->getValues()) {
1446+
if (DefInit *Pred = dyn_cast<DefInit>(I)) {
1447+
Record *Def = Pred->getDef();
1448+
if (!Def->isSubClassOf("Predicate")) {
1449+
#ifndef NDEBUG
1450+
Def->dump();
1451+
#endif
1452+
llvm_unreachable("Unknown predicate type!");
1453+
}
1454+
PredicateRecs.push_back(Def);
1455+
}
1456+
}
1457+
// Sort so that different orders get canonicalized to the same string.
1458+
llvm::sort(PredicateRecs, LessRecord());
1459+
}
1460+
14431461
/// getPredicateCheck - Return a single string containing all of this
14441462
/// pattern's predicates concatenated with "&&" operators.
14451463
///
14461464
std::string PatternToMatch::getPredicateCheck() const {
1447-
SmallVector<const Predicate*,4> PredList;
1448-
for (const Predicate &P : Predicates) {
1449-
if (!P.getCondString().empty())
1450-
PredList.push_back(&P);
1465+
SmallVector<Record *, 4> PredicateRecs;
1466+
getPredicateRecords(PredicateRecs);
1467+
1468+
SmallString<128> PredicateCheck;
1469+
for (Record *Pred : PredicateRecs) {
1470+
StringRef CondString = Pred->getValueAsString("CondString");
1471+
if (CondString.empty())
1472+
continue;
1473+
if (!PredicateCheck.empty())
1474+
PredicateCheck += " && ";
1475+
PredicateCheck += "(";
1476+
PredicateCheck += CondString;
1477+
PredicateCheck += ")";
14511478
}
1452-
llvm::sort(PredList, deref<std::less<>>());
14531479

1454-
std::string Check;
1455-
for (unsigned i = 0, e = PredList.size(); i != e; ++i) {
1456-
if (i != 0)
1457-
Check += " && ";
1458-
Check += '(' + PredList[i]->getCondString() + ')';
1480+
if (!HwModeFeatures.empty()) {
1481+
if (!PredicateCheck.empty())
1482+
PredicateCheck += " && ";
1483+
PredicateCheck += HwModeFeatures;
14591484
}
1460-
return Check;
1485+
1486+
return std::string(PredicateCheck);
14611487
}
14621488

14631489
//===----------------------------------------------------------------------===//
@@ -3930,20 +3956,6 @@ static void FindNames(TreePatternNode *P,
39303956
}
39313957
}
39323958

3933-
std::vector<Predicate> CodeGenDAGPatterns::makePredList(ListInit *L) {
3934-
std::vector<Predicate> Preds;
3935-
for (Init *I : L->getValues()) {
3936-
if (DefInit *Pred = dyn_cast<DefInit>(I))
3937-
Preds.push_back(Pred->getDef());
3938-
else
3939-
llvm_unreachable("Non-def on the list");
3940-
}
3941-
3942-
// Sort so that different orders get canonicalized to the same string.
3943-
llvm::sort(Preds);
3944-
return Preds;
3945-
}
3946-
39473959
void CodeGenDAGPatterns::AddPatternToMatch(TreePattern *Pattern,
39483960
PatternToMatch &&PTM) {
39493961
// Do some sanity checking on the pattern we're about to match.
@@ -4254,8 +4266,7 @@ void CodeGenDAGPatterns::ParseOnePattern(Record *TheDef,
42544266
for (const auto &T : Pattern.getTrees())
42554267
if (T->hasPossibleType())
42564268
AddPatternToMatch(&Pattern,
4257-
PatternToMatch(TheDef, makePredList(Preds),
4258-
T, Temp.getOnlyTree(),
4269+
PatternToMatch(TheDef, Preds, T, Temp.getOnlyTree(),
42594270
InstImpResults, Complexity,
42604271
TheDef->getID()));
42614272
}
@@ -4310,20 +4321,17 @@ void CodeGenDAGPatterns::ExpandHwModeBasedTypes() {
43104321
PatternsToMatch.swap(Copy);
43114322

43124323
auto AppendPattern = [this](PatternToMatch &P, unsigned Mode,
4313-
ArrayRef<Predicate> Check) {
4324+
StringRef Check) {
43144325
TreePatternNodePtr NewSrc = P.getSrcPattern()->clone();
43154326
TreePatternNodePtr NewDst = P.getDstPattern()->clone();
43164327
if (!NewSrc->setDefaultMode(Mode) || !NewDst->setDefaultMode(Mode)) {
43174328
return;
43184329
}
43194330

4320-
std::vector<Predicate> Preds = P.getPredicates();
4321-
llvm::append_range(Preds, Check);
4322-
PatternsToMatch.emplace_back(P.getSrcRecord(), std::move(Preds),
4331+
PatternsToMatch.emplace_back(P.getSrcRecord(), P.getPredicates(),
43234332
std::move(NewSrc), std::move(NewDst),
4324-
P.getDstRegs(),
4325-
P.getAddedComplexity(), Record::getNewUID(),
4326-
Mode);
4333+
P.getDstRegs(), P.getAddedComplexity(),
4334+
Record::getNewUID(), Mode, Check);
43274335
};
43284336

43294337
for (PatternToMatch &P : Copy) {
@@ -4354,18 +4362,22 @@ void CodeGenDAGPatterns::ExpandHwModeBasedTypes() {
43544362
// duplicated patterns with different predicate checks, construct the
43554363
// default check as a negation of all predicates that are actually present
43564364
// in the source/destination patterns.
4357-
SmallVector<Predicate, 2> DefaultCheck;
4365+
SmallString<128> DefaultCheck;
43584366

43594367
for (unsigned M : Modes) {
43604368
if (M == DefaultMode)
43614369
continue;
43624370

43634371
// Fill the map entry for this mode.
43644372
const HwMode &HM = CGH.getMode(M);
4365-
AppendPattern(P, M, Predicate(HM.Features, true));
4373+
AppendPattern(P, M, "(MF->getSubtarget().checkFeatures(\"" + HM.Features + "\"))");
43664374

43674375
// Add negations of the HM's predicates to the default predicate.
4368-
DefaultCheck.push_back(Predicate(HM.Features, false));
4376+
if (!DefaultCheck.empty())
4377+
DefaultCheck += " && ";
4378+
DefaultCheck += "(!(MF->getSubtarget().checkFeatures(\"";
4379+
DefaultCheck += HM.Features;
4380+
DefaultCheck += "\")))";
43694381
}
43704382

43714383
bool HasDefault = Modes.count(DefaultMode);
@@ -4685,8 +4697,8 @@ void CodeGenDAGPatterns::GenerateVariants() {
46854697
if (MatchedPatterns[i])
46864698
continue;
46874699

4688-
const std::vector<Predicate> &Predicates =
4689-
PatternsToMatch[i].getPredicates();
4700+
ListInit *Predicates = PatternsToMatch[i].getPredicates();
4701+
StringRef HwModeFeatures = PatternsToMatch[i].getHwModeFeatures();
46904702

46914703
BitVector &Matches = MatchedPredicates[i];
46924704
MatchedPatterns.set(i);
@@ -4695,7 +4707,8 @@ void CodeGenDAGPatterns::GenerateVariants() {
46954707
// Don't test patterns that have already been cached - it won't match.
46964708
for (unsigned p = 0; p != NumOriginalPatterns; ++p)
46974709
if (!MatchedPatterns[p])
4698-
Matches[p] = (Predicates == PatternsToMatch[p].getPredicates());
4710+
Matches[p] = (Predicates == PatternsToMatch[p].getPredicates()) &&
4711+
(HwModeFeatures == PatternsToMatch[p].getHwModeFeatures());
46994712

47004713
// Copy this to all the matching patterns.
47014714
for (int p = Matches.find_first(); p != -1; p = Matches.find_next(p))
@@ -4739,7 +4752,9 @@ void CodeGenDAGPatterns::GenerateVariants() {
47394752
PatternsToMatch[i].getSrcRecord(), PatternsToMatch[i].getPredicates(),
47404753
Variant, PatternsToMatch[i].getDstPatternShared(),
47414754
PatternsToMatch[i].getDstRegs(),
4742-
PatternsToMatch[i].getAddedComplexity(), Record::getNewUID());
4755+
PatternsToMatch[i].getAddedComplexity(), Record::getNewUID(),
4756+
PatternsToMatch[i].getForceMode(),
4757+
PatternsToMatch[i].getHwModeFeatures().str());
47434758
MatchedPredicates.push_back(Matches);
47444759

47454760
// Add a new match the same as this pattern.

llvm/utils/TableGen/CodeGenDAGPatterns.h

Lines changed: 13 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,92 +1047,43 @@ class DAGInstruction {
10471047
TreePatternNodePtr getResultPattern() const { return ResultPattern; }
10481048
};
10491049

1050-
/// This class represents a condition that has to be satisfied for a pattern
1051-
/// to be tried. It is a generalization of a class "Pattern" from Target.td:
1052-
/// in addition to the Target.td's predicates, this class can also represent
1053-
/// conditions associated with HW modes. Both types will eventually become
1054-
/// strings containing C++ code to be executed, the difference is in how
1055-
/// these strings are generated.
1056-
class Predicate {
1057-
public:
1058-
Predicate(Record *R, bool C = true) : Def(R), IfCond(C), IsHwMode(false) {
1059-
assert(R->isSubClassOf("Predicate") &&
1060-
"Predicate objects should only be created for records derived"
1061-
"from Predicate class");
1062-
}
1063-
Predicate(StringRef FS, bool C = true) : Def(nullptr), Features(FS.str()),
1064-
IfCond(C), IsHwMode(true) {}
1065-
1066-
/// Return a string which contains the C++ condition code that will serve
1067-
/// as a predicate during instruction selection.
1068-
std::string getCondString() const {
1069-
// The string will excute in a subclass of SelectionDAGISel.
1070-
// Cast to std::string explicitly to avoid ambiguity with StringRef.
1071-
std::string C = IsHwMode
1072-
? std::string("MF->getSubtarget().checkFeatures(\"" +
1073-
Features + "\")")
1074-
: std::string(Def->getValueAsString("CondString"));
1075-
if (C.empty())
1076-
return "";
1077-
return IfCond ? C : "!("+C+')';
1078-
}
1079-
1080-
bool operator==(const Predicate &P) const {
1081-
return IfCond == P.IfCond && IsHwMode == P.IsHwMode && Def == P.Def &&
1082-
Features == P.Features;
1083-
}
1084-
bool operator<(const Predicate &P) const {
1085-
if (IsHwMode != P.IsHwMode)
1086-
return IsHwMode < P.IsHwMode;
1087-
assert(!Def == !P.Def && "Inconsistency between Def and IsHwMode");
1088-
if (IfCond != P.IfCond)
1089-
return IfCond < P.IfCond;
1090-
if (Def)
1091-
return LessRecord()(Def, P.Def);
1092-
return Features < P.Features;
1093-
}
1094-
Record *Def; ///< Predicate definition from .td file, null for
1095-
///< HW modes.
1096-
std::string Features; ///< Feature string for HW mode.
1097-
bool IfCond; ///< The boolean value that the condition has to
1098-
///< evaluate to for this predicate to be true.
1099-
bool IsHwMode; ///< Does this predicate correspond to a HW mode?
1100-
};
1101-
11021050
/// PatternToMatch - Used by CodeGenDAGPatterns to keep tab of patterns
11031051
/// processed to produce isel.
11041052
class PatternToMatch {
11051053
Record *SrcRecord; // Originating Record for the pattern.
1054+
ListInit *Predicates; // Top level predicate conditions to match.
11061055
TreePatternNodePtr SrcPattern; // Source pattern to match.
11071056
TreePatternNodePtr DstPattern; // Resulting pattern.
1108-
std::vector<Predicate> Predicates; // Top level predicate conditions
1109-
// to match.
11101057
std::vector<Record*> Dstregs; // Physical register defs being matched.
1058+
std::string HwModeFeatures;
11111059
int AddedComplexity; // Add to matching pattern complexity.
11121060
unsigned ID; // Unique ID for the record.
11131061
unsigned ForceMode; // Force this mode in type inference when set.
11141062

11151063
public:
1116-
PatternToMatch(Record *srcrecord, std::vector<Predicate> preds,
1117-
TreePatternNodePtr src, TreePatternNodePtr dst,
1118-
std::vector<Record *> dstregs, int complexity,
1119-
unsigned uid, unsigned setmode = 0)
1120-
: SrcRecord(srcrecord), SrcPattern(src), DstPattern(dst),
1121-
Predicates(std::move(preds)), Dstregs(std::move(dstregs)),
1122-
AddedComplexity(complexity), ID(uid), ForceMode(setmode) {}
1064+
PatternToMatch(Record *srcrecord, ListInit *preds, TreePatternNodePtr src,
1065+
TreePatternNodePtr dst, std::vector<Record *> dstregs,
1066+
int complexity, unsigned uid, unsigned setmode = 0,
1067+
const Twine &hwmodefeatures = "")
1068+
: SrcRecord(srcrecord), Predicates(preds), SrcPattern(src),
1069+
DstPattern(dst), Dstregs(std::move(dstregs)),
1070+
HwModeFeatures(hwmodefeatures.str()), AddedComplexity(complexity),
1071+
ID(uid), ForceMode(setmode) {}
11231072

11241073
Record *getSrcRecord() const { return SrcRecord; }
1074+
ListInit *getPredicates() const { return Predicates; }
11251075
TreePatternNode *getSrcPattern() const { return SrcPattern.get(); }
11261076
TreePatternNodePtr getSrcPatternShared() const { return SrcPattern; }
11271077
TreePatternNode *getDstPattern() const { return DstPattern.get(); }
11281078
TreePatternNodePtr getDstPatternShared() const { return DstPattern; }
11291079
const std::vector<Record*> &getDstRegs() const { return Dstregs; }
1080+
StringRef getHwModeFeatures() const { return HwModeFeatures; }
11301081
int getAddedComplexity() const { return AddedComplexity; }
1131-
const std::vector<Predicate> &getPredicates() const { return Predicates; }
11321082
unsigned getID() const { return ID; }
11331083
unsigned getForceMode() const { return ForceMode; }
11341084

11351085
std::string getPredicateCheck() const;
1086+
void getPredicateRecords(SmallVectorImpl<Record *> &PredicateRecs) const;
11361087

11371088
/// Compute the complexity metric for the input pattern. This roughly
11381089
/// corresponds to the number of nodes that are covered.
@@ -1290,8 +1241,6 @@ class CodeGenDAGPatterns {
12901241
void GenerateVariants();
12911242
void VerifyInstructionFlags();
12921243

1293-
std::vector<Predicate> makePredList(ListInit *L);
1294-
12951244
void ParseOnePattern(Record *TheDef,
12961245
TreePattern &Pattern, TreePattern &Result,
12971246
const std::vector<Record *> &InstImpResults);

llvm/utils/TableGen/GlobalISelEmitter.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3536,7 +3536,7 @@ class GlobalISelEmitter {
35363536
const CodeGenInstruction *getEquivNode(Record &Equiv,
35373537
const TreePatternNode *N) const;
35383538

3539-
Error importRulePredicates(RuleMatcher &M, ArrayRef<Predicate> Predicates);
3539+
Error importRulePredicates(RuleMatcher &M, ArrayRef<Record *> Predicates);
35403540
Expected<InstructionMatcher &>
35413541
createAndImportSelDAGMatcher(RuleMatcher &Rule,
35423542
InstructionMatcher &InsnMatcher,
@@ -3723,14 +3723,13 @@ GlobalISelEmitter::GlobalISelEmitter(RecordKeeper &RK)
37233723

37243724
//===- Emitter ------------------------------------------------------------===//
37253725

3726-
Error
3727-
GlobalISelEmitter::importRulePredicates(RuleMatcher &M,
3728-
ArrayRef<Predicate> Predicates) {
3729-
for (const Predicate &P : Predicates) {
3730-
if (!P.Def || P.getCondString().empty())
3726+
Error GlobalISelEmitter::importRulePredicates(RuleMatcher &M,
3727+
ArrayRef<Record *> Predicates) {
3728+
for (Record *Pred : Predicates) {
3729+
if (Pred->getValueAsString("CondString").empty())
37313730
continue;
3732-
declareSubtargetFeature(P.Def);
3733-
M.addRequiredFeature(P.Def);
3731+
declareSubtargetFeature(Pred);
3732+
M.addRequiredFeature(Pred);
37343733
}
37353734

37363735
return Error::success();
@@ -5042,7 +5041,9 @@ Expected<RuleMatcher> GlobalISelEmitter::runOnPattern(const PatternToMatch &P) {
50425041
" => " +
50435042
llvm::to_string(*P.getDstPattern()));
50445043

5045-
if (auto Error = importRulePredicates(M, P.getPredicates()))
5044+
SmallVector<Record *, 4> Predicates;
5045+
P.getPredicateRecords(Predicates);
5046+
if (auto Error = importRulePredicates(M, Predicates))
50465047
return std::move(Error);
50475048

50485049
// Next, analyze the pattern operators.

0 commit comments

Comments
 (0)