Skip to content

[TableGen] Add const variants of accessors for backend #106658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clang/utils/TableGen/ClangAttrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ static StringRef NormalizeGNUAttrSpelling(StringRef AttrSpelling) {

typedef std::vector<std::pair<std::string, const Record *>> ParsedAttrMap;

static ParsedAttrMap getParsedAttrList(const RecordKeeper &Records,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expected to see more const instead of less. Why is this here (and below) passes as non-const?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see the discourse thread: https://discourse.llvm.org/t/changing-tablegen-getallderiveddefinitions-to-return-arrayref-const-record/80586/5

Changing all backends to use const in a single go will be too large a change, so this is a way to stage it. Over time, all backends should move to const variant, and then the non-const ones can be deleted.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to have those conversions queued up to see how this progresses towards that goal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will work on some of them, but not sure if I have the bandwidth to address all of them (Maybe the first few might be more involved and then we will have a better idea of the scope). These will be in similar vein to what I have already done for the IntrinsicEmitter and JSON/Detailed record emitters.

Copy link
Contributor Author

@jurahul jurahul Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose the following next steps (once this is committed):

  1. We need to send a PSA as current downstream backends may break due to this (fix is easy)
  2. I'll work on migrating one of each MLIR. Clang, and may be another LLVM backend to use const.
  3. For the rest, we need to chip away. May be we try to recruit volunteers?
  4. Before we can deprecate the non-const functions completely, give enough time for downstream backends to change. Maybe this should be a part of the PSA in (1).
  5. Delete the non-const functions here, and any other const vs non-const difference we may have introduced during the transition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's an example of migrating one MLIR backend: 8ffa5e9

Copy link
Contributor Author

@jurahul jurahul Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here's for the entire MLIR tablegen code: 061076d

So may be its not that involved. Note that this does not address all the const correctness issues, as const Record* member function may themselves return non-const pointers. So that needs to be fixed gradually as well. So I propose we get this commit in, and then the MLIR one as a follow on. I will see if I can do the clang one as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And for all clang-tablegen its here: abf19e5

So looks like we may be able to migrate quickly, may be wait for a few days (weeks?) for downstream adoption and delete the non-const overloads.

static ParsedAttrMap getParsedAttrList(RecordKeeper &Records,
ParsedAttrMap *Dupes = nullptr,
bool SemaOnly = true) {
std::vector<Record *> Attrs = Records.getAllDerivedDefinitions("Attr");
Expand Down Expand Up @@ -4344,7 +4344,7 @@ static void GenerateAppertainsTo(const Record &Attr, raw_ostream &OS) {
// written into OS and the checks for merging declaration attributes are
// written into MergeOS.
static void GenerateMutualExclusionsChecks(const Record &Attr,
const RecordKeeper &Records,
RecordKeeper &Records,
raw_ostream &OS,
raw_ostream &MergeDeclOS,
raw_ostream &MergeStmtOS) {
Expand Down
2 changes: 1 addition & 1 deletion clang/utils/TableGen/ClangSyntaxEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ using llvm::formatv;
// stable and useful way, where abstract Node subclasses correspond to ranges.
class Hierarchy {
public:
Hierarchy(const llvm::RecordKeeper &Records) {
Hierarchy(llvm::RecordKeeper &Records) {
for (llvm::Record *T : Records.getAllDerivedDefinitions("NodeType"))
add(T);
for (llvm::Record *Derived : Records.getAllDerivedDefinitions("NodeType"))
Expand Down
5 changes: 2 additions & 3 deletions llvm/include/llvm/TableGen/DirectiveEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ namespace llvm {
// DirectiveBase.td and provides helper methods for accessing it.
class DirectiveLanguage {
public:
explicit DirectiveLanguage(const llvm::RecordKeeper &Records)
: Records(Records) {
explicit DirectiveLanguage(llvm::RecordKeeper &Records) : Records(Records) {
const auto &DirectiveLanguages = getDirectiveLanguages();
Def = DirectiveLanguages[0];
}
Expand Down Expand Up @@ -71,7 +70,7 @@ class DirectiveLanguage {

private:
const llvm::Record *Def;
const llvm::RecordKeeper &Records;
llvm::RecordKeeper &Records;

std::vector<Record *> getDirectiveLanguages() const {
return Records.getAllDerivedDefinitions("DirectiveLanguage");
Expand Down
34 changes: 29 additions & 5 deletions llvm/include/llvm/TableGen/Record.h
Original file line number Diff line number Diff line change
Expand Up @@ -2057,19 +2057,28 @@ class RecordKeeper {
//===--------------------------------------------------------------------===//
// High-level helper methods, useful for tablegen backends.

// Non-const methods return std::vector<Record *> by value or reference.
// Const methods return std::vector<const Record *> by value or
// ArrayRef<const Record *>.

/// Get all the concrete records that inherit from the one specified
/// class. The class must be defined.
std::vector<Record *> getAllDerivedDefinitions(StringRef ClassName) const;
ArrayRef<const Record *> getAllDerivedDefinitions(StringRef ClassName) const;
const std::vector<Record *> &getAllDerivedDefinitions(StringRef ClassName);

/// Get all the concrete records that inherit from all the specified
/// classes. The classes must be defined.
std::vector<Record *> getAllDerivedDefinitions(
ArrayRef<StringRef> ClassNames) const;
std::vector<const Record *>
getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames) const;
std::vector<Record *>
getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames);

/// Get all the concrete records that inherit from specified class, if the
/// class is defined. Returns an empty vector if the class is not defined.
std::vector<Record *>
ArrayRef<const Record *>
getAllDerivedDefinitionsIfDefined(StringRef ClassName) const;
const std::vector<Record *> &
getAllDerivedDefinitionsIfDefined(StringRef ClassName);

void dump() const;

Expand All @@ -2081,9 +2090,24 @@ class RecordKeeper {
RecordKeeper &operator=(RecordKeeper &&) = delete;
RecordKeeper &operator=(const RecordKeeper &) = delete;

// Helper template functions for backend accessors.
template <typename VecTy>
const VecTy &
getAllDerivedDefinitionsImpl(StringRef ClassName,
std::map<std::string, VecTy> &Cache) const;

template <typename VecTy>
VecTy getAllDerivedDefinitionsImpl(ArrayRef<StringRef> ClassNames) const;

template <typename VecTy>
const VecTy &getAllDerivedDefinitionsIfDefinedImpl(
StringRef ClassName, std::map<std::string, VecTy> &Cache) const;

std::string InputFilename;
RecordMap Classes, Defs;
mutable StringMap<std::vector<Record *>> ClassRecordsMap;
mutable std::map<std::string, std::vector<const Record *>>
ClassRecordsMapConst;
mutable std::map<std::string, std::vector<Record *>> ClassRecordsMap;
GlobalMap ExtraGlobals;

// These members are for the phase timing feature. We need a timer group,
Expand Down
65 changes: 51 additions & 14 deletions llvm/lib/TableGen/Record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3248,46 +3248,83 @@ void RecordKeeper::stopBackendTimer() {
}
}

std::vector<Record *>
RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) const {
template <typename VecTy>
const VecTy &RecordKeeper::getAllDerivedDefinitionsImpl(
StringRef ClassName, std::map<std::string, VecTy> &Cache) const {
// We cache the record vectors for single classes. Many backends request
// the same vectors multiple times.
auto Pair = ClassRecordsMap.try_emplace(ClassName);
auto Pair = Cache.try_emplace(ClassName.str());
if (Pair.second)
Pair.first->second = getAllDerivedDefinitions(ArrayRef(ClassName));
Pair.first->second =
getAllDerivedDefinitionsImpl<VecTy>(ArrayRef(ClassName));

return Pair.first->second;
}

std::vector<Record *> RecordKeeper::getAllDerivedDefinitions(
template <typename VecTy>
VecTy RecordKeeper::getAllDerivedDefinitionsImpl(
ArrayRef<StringRef> ClassNames) const {
SmallVector<Record *, 2> ClassRecs;
std::vector<Record *> Defs;
SmallVector<const Record *, 2> ClassRecs;
VecTy Defs;

assert(ClassNames.size() > 0 && "At least one class must be passed.");
for (const auto &ClassName : ClassNames) {
Record *Class = getClass(ClassName);
const Record *Class = getClass(ClassName);
if (!Class)
PrintFatalError("The class '" + ClassName + "' is not defined\n");
ClassRecs.push_back(Class);
}

for (const auto &OneDef : getDefs()) {
if (all_of(ClassRecs, [&OneDef](const Record *Class) {
return OneDef.second->isSubClassOf(Class);
}))
return OneDef.second->isSubClassOf(Class);
}))
Defs.push_back(OneDef.second.get());
}

llvm::sort(Defs, LessRecord());

return Defs;
}

template <typename VecTy>
const VecTy &RecordKeeper::getAllDerivedDefinitionsIfDefinedImpl(
StringRef ClassName, std::map<std::string, VecTy> &Cache) const {
return getClass(ClassName)
? getAllDerivedDefinitionsImpl<VecTy>(ClassName, Cache)
: Cache[""];
}

ArrayRef<const Record *>
RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) const {
return getAllDerivedDefinitionsImpl<std::vector<const Record *>>(
ClassName, ClassRecordsMapConst);
}

const std::vector<Record *> &
RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) {
return getAllDerivedDefinitionsImpl<std::vector<Record *>>(ClassName,
ClassRecordsMap);
}

std::vector<const Record *>
RecordKeeper::getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames) const {
return getAllDerivedDefinitionsImpl<std::vector<const Record *>>(ClassNames);
}

std::vector<Record *>
RecordKeeper::getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames) {
return getAllDerivedDefinitionsImpl<std::vector<Record *>>(ClassNames);
}

ArrayRef<const Record *>
RecordKeeper::getAllDerivedDefinitionsIfDefined(StringRef ClassName) const {
return getClass(ClassName) ? getAllDerivedDefinitions(ClassName)
: std::vector<Record *>();
return getAllDerivedDefinitionsIfDefinedImpl<std::vector<const Record *>>(
ClassName, ClassRecordsMapConst);
}

const std::vector<Record *> &
RecordKeeper::getAllDerivedDefinitionsIfDefined(StringRef ClassName) {
return getAllDerivedDefinitionsIfDefinedImpl<std::vector<Record *>>(
ClassName, ClassRecordsMap);
}

void RecordKeeper::dumpAllocationStats(raw_ostream &OS) const {
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ CodeGenIntrinsicContext::CodeGenIntrinsicContext(const RecordKeeper &RC) {
CodeGenIntrinsicTable::CodeGenIntrinsicTable(const RecordKeeper &RC) {
CodeGenIntrinsicContext Ctx(RC);

std::vector<Record *> Defs = RC.getAllDerivedDefinitions("Intrinsic");
ArrayRef<const Record *> Defs = RC.getAllDerivedDefinitions("Intrinsic");
Intrinsics.reserve(Defs.size());

for (const Record *Def : Defs)
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ LLVM_DUMP_METHOD void SubtargetFeatureInfo::dump() const {
#endif

std::vector<std::pair<Record *, SubtargetFeatureInfo>>
SubtargetFeatureInfo::getAll(const RecordKeeper &Records) {
SubtargetFeatureInfo::getAll(RecordKeeper &Records) {
std::vector<std::pair<Record *, SubtargetFeatureInfo>> SubtargetFeatures;
std::vector<Record *> AllPredicates =
Records.getAllDerivedDefinitions("Predicate");
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/Common/SubtargetFeatureInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct SubtargetFeatureInfo {

void dump() const;
static std::vector<std::pair<Record *, SubtargetFeatureInfo>>
getAll(const RecordKeeper &Records);
getAll(RecordKeeper &Records);

/// Emit the subtarget feature flag definitions.
///
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/ExegesisEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ExegesisEmitter {
};

static std::map<llvm::StringRef, unsigned>
collectPfmCounters(const RecordKeeper &Records) {
collectPfmCounters(RecordKeeper &Records) {
std::map<llvm::StringRef, unsigned> PfmCounterNameTable;
const auto AddPfmCounterName = [&PfmCounterNameTable](
const Record *PfmCounterDef) {
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/GlobalISelEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ class GlobalISelEmitter final : public GlobalISelMatchTableExecutorEmitter {
private:
std::string ClassName;

const RecordKeeper &RK;
RecordKeeper &RK;
const CodeGenDAGPatterns CGP;
const CodeGenTarget &Target;
CodeGenRegBank &CGRegs;
Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/SubtargetEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1545,7 +1545,7 @@ void SubtargetEmitter::EmitSchedModel(raw_ostream &OS) {
EmitProcessorModels(OS);
}

static void emitPredicateProlog(const RecordKeeper &Records, raw_ostream &OS) {
static void emitPredicateProlog(RecordKeeper &Records, raw_ostream &OS) {
std::string Buffer;
raw_string_ostream Stream(Buffer);

Expand Down
2 changes: 1 addition & 1 deletion llvm/utils/TableGen/TableGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ static void PrintEnums(RecordKeeper &Records, raw_ostream &OS) {
static void PrintSets(const RecordKeeper &Records, raw_ostream &OS) {
SetTheory Sets;
Sets.addFieldExpander("Set", "Elements");
for (Record *Rec : Records.getAllDerivedDefinitions("Set")) {
for (const Record *Rec : Records.getAllDerivedDefinitions("Set")) {
OS << Rec->getName() << " = [";
const std::vector<Record *> *Elts = Sets.expand(Rec);
assert(Elts && "Couldn't expand Set instance");
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/TableGen/GenInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class RecordKeeper;
namespace mlir {

/// Generator function to invoke.
using GenFunction = std::function<bool(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os)>;
using GenFunction =
std::function<bool(llvm::RecordKeeper &recordKeeper, raw_ostream &os)>;

/// Structure to group information about a generator (argument to invoke via
/// mlir-tblgen, description, and generator function).
Expand All @@ -34,7 +34,7 @@ class GenInfo {
: arg(arg), description(description), generator(std::move(generator)) {}

/// Invokes the generator and returns whether the generator failed.
bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
bool invoke(llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
assert(generator && "Cannot call generator with null generator");
return generator(recordKeeper, os);
}
Expand Down
28 changes: 14 additions & 14 deletions mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,10 +690,10 @@ class DefGenerator {
bool emitDefs(StringRef selectedDialect);

protected:
DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
DefGenerator(const std::vector<llvm::Record *> &defs, raw_ostream &os,
StringRef defType, StringRef valueType, bool isAttrGenerator)
: defRecords(std::move(defs)), os(os), defType(defType),
valueType(valueType), isAttrGenerator(isAttrGenerator) {
: defRecords(defs), os(os), defType(defType), valueType(valueType),
isAttrGenerator(isAttrGenerator) {
// Sort by occurrence in file.
llvm::sort(defRecords, [](llvm::Record *lhs, llvm::Record *rhs) {
return lhs->getID() < rhs->getID();
Expand Down Expand Up @@ -721,13 +721,13 @@ class DefGenerator {

/// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator {
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
AttrDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
};
/// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator {
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
TypeDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
"Type", "Type", /*isAttrGenerator=*/false) {}
};
Expand Down Expand Up @@ -1029,7 +1029,7 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {

/// Find all type constraints for which a C++ function should be generated.
static std::vector<Constraint>
getAllTypeConstraints(const llvm::RecordKeeper &records) {
getAllTypeConstraints(llvm::RecordKeeper &records) {
std::vector<Constraint> result;
for (llvm::Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
Expand All @@ -1046,7 +1046,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
return result;
}

static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
static void emitTypeConstraintDecls(llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDecl = R"(
bool {0}(::mlir::Type type);
Expand All @@ -1056,7 +1056,7 @@ bool {0}(::mlir::Type type);
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
}

static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
static void emitTypeConstraintDefs(llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDef = R"(
bool {0}(::mlir::Type type) {
Expand Down Expand Up @@ -1087,13 +1087,13 @@ static llvm::cl::opt<std::string>

static mlir::GenRegistration
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDefs(attrDialect);
});
static mlir::GenRegistration
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDecls(attrDialect);
});
Expand All @@ -1109,28 +1109,28 @@ static llvm::cl::opt<std::string>

static mlir::GenRegistration
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDefs(typeDialect);
});
static mlir::GenRegistration
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});

static mlir::GenRegistration
genTypeConstrDefs("gen-type-constraint-defs",
"Generate type constraint definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDefs(records, os);
return false;
});
static mlir::GenRegistration
genTypeConstrDecls("gen-type-constraint-decls",
"Generate type constraint declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDecls(records, os);
return false;
});
Loading
Loading