Skip to content

[SelectionDAG] Add space-optimized forms of OPC_CheckPredicate #77763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGISel.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ class SelectionDAGISel : public MachineFunctionPass {
OPC_CheckPatternPredicate7,
OPC_CheckPatternPredicateTwoByte,
OPC_CheckPredicate,
OPC_CheckPredicate0,
OPC_CheckPredicate1,
OPC_CheckPredicate2,
OPC_CheckPredicate3,
OPC_CheckPredicate4,
OPC_CheckPredicate5,
OPC_CheckPredicate6,
OPC_CheckPredicate7,
OPC_CheckPredicateWithOperands,
OPC_CheckOpcode,
OPC_SwitchOpcode,
Expand Down
30 changes: 25 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2712,9 +2712,13 @@ CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,

/// CheckNodePredicate - Implements OP_CheckNodePredicate.
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
CheckNodePredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex,
const SelectionDAGISel &SDISel, SDNode *N) {
return SDISel.CheckNodePredicate(N, MatcherTable[MatcherIndex++]);
CheckNodePredicate(unsigned Opcode, const unsigned char *MatcherTable,
unsigned &MatcherIndex, const SelectionDAGISel &SDISel,
SDNode *N) {
unsigned PredNo = Opcode == SelectionDAGISel::OPC_CheckPredicate
? MatcherTable[MatcherIndex++]
: Opcode - SelectionDAGISel::OPC_CheckPredicate0;
return SDISel.CheckNodePredicate(N, PredNo);
}

LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
Expand Down Expand Up @@ -2868,7 +2872,15 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
Result = !::CheckPatternPredicate(Opcode, Table, Index, SDISel);
return Index;
case SelectionDAGISel::OPC_CheckPredicate:
Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode());
case SelectionDAGISel::OPC_CheckPredicate0:
case SelectionDAGISel::OPC_CheckPredicate1:
case SelectionDAGISel::OPC_CheckPredicate2:
case SelectionDAGISel::OPC_CheckPredicate3:
case SelectionDAGISel::OPC_CheckPredicate4:
case SelectionDAGISel::OPC_CheckPredicate5:
case SelectionDAGISel::OPC_CheckPredicate6:
case SelectionDAGISel::OPC_CheckPredicate7:
Result = !::CheckNodePredicate(Opcode, Table, Index, SDISel, N.getNode());
return Index;
case SelectionDAGISel::OPC_CheckOpcode:
Result = !::CheckOpcode(Table, Index, N.getNode());
Expand Down Expand Up @@ -3359,8 +3371,16 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
if (!::CheckPatternPredicate(Opcode, MatcherTable, MatcherIndex, *this))
break;
continue;
case SelectionDAGISel::OPC_CheckPredicate0:
case SelectionDAGISel::OPC_CheckPredicate1:
case SelectionDAGISel::OPC_CheckPredicate2:
case SelectionDAGISel::OPC_CheckPredicate3:
case SelectionDAGISel::OPC_CheckPredicate4:
case SelectionDAGISel::OPC_CheckPredicate5:
case SelectionDAGISel::OPC_CheckPredicate6:
case SelectionDAGISel::OPC_CheckPredicate7:
case OPC_CheckPredicate:
if (!::CheckNodePredicate(MatcherTable, MatcherIndex, *this,
if (!::CheckNodePredicate(Opcode, MatcherTable, MatcherIndex, *this,
N.getNode()))
break;
continue;
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/TableGen/address-space-patfrags.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def inst_d : Instruction {
let InOperandList = (ins GPR32:$src0, GPR32:$src1);
}

// SDAG: case 2: {
// SDAG: case 1: {
// SDAG-NEXT: // Predicate_pat_frag_b
// SDAG-NEXT: // Predicate_truncstorei16_addrspace
// SDAG-NEXT: SDNode *N = Node;
Expand All @@ -69,7 +69,7 @@ def : Pat <
>;


// SDAG: case 3: {
// SDAG: case 6: {
// SDAG: // Predicate_pat_frag_a
// SDAG-NEXT: SDNode *N = Node;
// SDAG-NEXT: (void)N;
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/TableGen/predicate-patfags.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def TGTmul24_oneuse : PatFrag<
}

// SDAG: OPC_CheckOpcode, TARGET_VAL(ISD::INTRINSIC_W_CHAIN),
// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse

// SDAG: OPC_CheckOpcode, TARGET_VAL(TargetISD::MUL24),
// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse

// GISEL: GIM_CheckOpcode, /*MI*/1, GIMT_Encode2(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS),
// GISEL: GIM_CheckIntrinsicID, /*MI*/1, /*Op*/1, GIMT_Encode2(Intrinsic::tgt_mul24),
Expand Down
96 changes: 59 additions & 37 deletions llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ class MatcherTableEmitter {

SmallVector<unsigned, Matcher::HighestKind+1> OpcodeCounts;

DenseMap<TreePattern *, unsigned> NodePredicateMap;
std::vector<TreePredicateFn> NodePredicates;
std::vector<TreePredicateFn> NodePredicatesWithOperands;
std::vector<TreePattern *> NodePredicates;
std::vector<TreePattern *> NodePredicatesWithOperands;

// We de-duplicate the predicates by code string, and use this map to track
// all the patterns with "identical" predicates.
Expand Down Expand Up @@ -87,7 +86,9 @@ class MatcherTableEmitter {
// Record the usage of ComplexPattern.
MapVector<const ComplexPattern *, unsigned> ComplexPatternUsage;
// Record the usage of PatternPredicate.
std::map<StringRef, unsigned> PatternPredicateUsage;
MapVector<StringRef, unsigned> PatternPredicateUsage;
// Record the usage of Predicate.
MapVector<TreePattern *, unsigned> PredicateUsage;

// Iterate the whole MatcherTable once and do some statistics.
std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
Expand All @@ -105,6 +106,8 @@ class MatcherTableEmitter {
++ComplexPatternUsage[&CPM->getPattern()];
else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
++PatternPredicateUsage[CPPM->getPredicate()];
else if (auto *PM = dyn_cast<CheckPredicateMatcher>(N))
++PredicateUsage[PM->getPredicate().getOrigPatFragRecord()];
N = N->getNext();
}
};
Expand All @@ -127,6 +130,40 @@ class MatcherTableEmitter {
});
for (const auto &PatternPredicate : PatternPredicateList)
PatternPredicates.push_back(PatternPredicate.first);

// Sort Predicates by usage.
// Merge predicates with same code.
for (const auto &Usage : PredicateUsage) {
TreePattern *TP = Usage.first;
TreePredicateFn Pred(TP);
NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()].push_back(TP);
}

std::vector<std::pair<TreePattern *, unsigned>> PredicateList;
// Sum the usage.
for (auto &Predicate : NodePredicatesByCodeToRun) {
TinyPtrVector<TreePattern *> &TPs = Predicate.second;
stable_sort(TPs, [](const auto *A, const auto *B) {
return A->getRecord()->getName() < B->getRecord()->getName();
});
unsigned Uses = 0;
for (TreePattern *TP : TPs)
Uses += PredicateUsage[TP];

// We only add the first predicate here since they are with the same code.
PredicateList.push_back({TPs[0], Uses});
}

stable_sort(PredicateList, [](const auto &A, const auto &B) {
return A.second > B.second;
});
for (const auto &Predicate : PredicateList) {
TreePattern *TP = Predicate.first;
if (TreePredicateFn(TP).usesOperands())
NodePredicatesWithOperands.push_back(TP);
else
NodePredicates.push_back(TP);
}
}

unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
Expand All @@ -141,7 +178,7 @@ class MatcherTableEmitter {
void EmitPatternMatchTable(raw_ostream &OS);

private:
void EmitNodePredicatesFunction(const std::vector<TreePredicateFn> &Preds,
void EmitNodePredicatesFunction(const std::vector<TreePattern *> &Preds,
StringRef Decl, raw_ostream &OS);

unsigned SizeMatcher(Matcher *N, raw_ostream &OS);
Expand All @@ -150,33 +187,13 @@ class MatcherTableEmitter {
raw_ostream &OS);

unsigned getNodePredicate(TreePredicateFn Pred) {
TreePattern *TP = Pred.getOrigPatFragRecord();
unsigned &Entry = NodePredicateMap[TP];
if (Entry == 0) {
TinyPtrVector<TreePattern *> &SameCodePreds =
NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()];
if (SameCodePreds.empty()) {
// We've never seen a predicate with the same code: allocate an entry.
if (Pred.usesOperands()) {
NodePredicatesWithOperands.push_back(Pred);
Entry = NodePredicatesWithOperands.size();
} else {
NodePredicates.push_back(Pred);
Entry = NodePredicates.size();
}
} else {
// We did see an identical predicate: re-use it.
Entry = NodePredicateMap[SameCodePreds.front()];
assert(Entry != 0);
assert(TreePredicateFn(SameCodePreds.front()).usesOperands() ==
Pred.usesOperands() &&
"PatFrags with some code must have same usesOperands setting");
}
// In both cases, we've never seen this particular predicate before, so
// mark it in the list of predicates sharing the same code.
SameCodePreds.push_back(TP);
}
return Entry-1;
// We use the first predicate.
TreePattern *PredPat =
NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()][0];
return Pred.usesOperands()
? llvm::find(NodePredicatesWithOperands, PredPat) -
NodePredicatesWithOperands.begin()
: llvm::find(NodePredicates, PredPat) - NodePredicates.begin();
}

unsigned getPatternPredicate(StringRef PredName) {
Expand Down Expand Up @@ -531,6 +548,7 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
case Matcher::CheckPredicate: {
TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
unsigned OperandBytes = 0;
unsigned PredNo = getNodePredicate(Pred);

if (Pred.usesOperands()) {
unsigned NumOps = cast<CheckPredicateMatcher>(N)->getNumOperands();
Expand All @@ -539,10 +557,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
OS << cast<CheckPredicateMatcher>(N)->getOperandNo(i) << ", ";
OperandBytes = 1 + NumOps;
} else {
OS << "OPC_CheckPredicate, ";
if (PredNo < 8) {
OperandBytes = -1;
OS << "OPC_CheckPredicate" << PredNo << ", ";
} else
OS << "OPC_CheckPredicate, ";
}

OS << getNodePredicate(Pred) << ',';
if (PredNo >= 8 || Pred.usesOperands())
OS << PredNo << ',';
if (!OmitComments)
OS << " // " << Pred.getFnName();
OS << '\n';
Expand Down Expand Up @@ -1031,8 +1054,7 @@ EmitMatcherList(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
}

void MatcherTableEmitter::EmitNodePredicatesFunction(
const std::vector<TreePredicateFn> &Preds, StringRef Decl,
raw_ostream &OS) {
const std::vector<TreePattern *> &Preds, StringRef Decl, raw_ostream &OS) {
if (Preds.empty())
return;

Expand All @@ -1042,7 +1064,7 @@ void MatcherTableEmitter::EmitNodePredicatesFunction(
OS << " default: llvm_unreachable(\"Invalid predicate in table?\");\n";
for (unsigned i = 0, e = Preds.size(); i != e; ++i) {
// Emit the predicate code corresponding to this pattern.
const TreePredicateFn PredFn = Preds[i];
TreePredicateFn PredFn(Preds[i]);
assert(!PredFn.isAlwaysTrue() && "No code in this predicate");
std::string PredFnCodeStr = PredFn.getCodeToRunOnSDNode();

Expand Down