Skip to content

[NFC][MLIR][TableGen] Eliminate llvm:: for common types in LSP Server #110867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 51 additions & 55 deletions mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
#include <optional>

using namespace mlir;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::RecordVal;
using llvm::SourceMgr;

/// Returns the range of a lexical token given a SMLoc corresponding to the
/// start of an token location. The range is computed heuristically, and
Expand All @@ -32,7 +36,7 @@ static SMRange convertTokenLocToRange(SMLoc loc) {

/// Returns a language server uri for the given source location. `mainFileURI`
/// corresponds to the uri for the main file of the source manager.
static lsp::URIForFile getURIFromLoc(const llvm::SourceMgr &mgr, SMLoc loc,
static lsp::URIForFile getURIFromLoc(const SourceMgr &mgr, SMLoc loc,
const lsp::URIForFile &mainFileURI) {
int bufferId = mgr.FindBufferContainingLoc(loc);
if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
Expand All @@ -47,12 +51,12 @@ static lsp::URIForFile getURIFromLoc(const llvm::SourceMgr &mgr, SMLoc loc,
}

/// Returns a language server location from the given source range.
static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange loc,
static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMRange loc,
const lsp::URIForFile &uri) {
return lsp::Location(getURIFromLoc(mgr, loc.Start, uri),
lsp::Range(mgr, loc));
}
static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMLoc loc,
static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMLoc loc,
const lsp::URIForFile &uri) {
return getLocationFromLoc(mgr, convertTokenLocToRange(loc), uri);
}
Expand All @@ -61,7 +65,7 @@ static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMLoc loc,
static std::optional<lsp::Diagnostic>
getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,
const lsp::URIForFile &uri) {
auto *sourceMgr = const_cast<llvm::SourceMgr *>(diag.getSourceMgr());
auto *sourceMgr = const_cast<SourceMgr *>(diag.getSourceMgr());
if (!sourceMgr || !diag.getLoc().isValid())
return std::nullopt;

Expand All @@ -79,17 +83,17 @@ getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,

// Convert the severity for the diagnostic.
switch (diag.getKind()) {
case llvm::SourceMgr::DK_Warning:
case SourceMgr::DK_Warning:
lspDiag.severity = lsp::DiagnosticSeverity::Warning;
break;
case llvm::SourceMgr::DK_Error:
case SourceMgr::DK_Error:
lspDiag.severity = lsp::DiagnosticSeverity::Error;
break;
case llvm::SourceMgr::DK_Note:
case SourceMgr::DK_Note:
// Notes are emitted separately from the main diagnostic, so we just treat
// them as remarks given that we can't determine the diagnostic to relate
// them to.
case llvm::SourceMgr::DK_Remark:
case SourceMgr::DK_Remark:
lspDiag.severity = lsp::DiagnosticSeverity::Information;
break;
}
Expand All @@ -100,16 +104,15 @@ getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,

/// Get the base definition of the given record value, or nullptr if one
/// couldn't be found.
static std::pair<const llvm::Record *, const llvm::RecordVal *>
getBaseValue(const llvm::Record *record, const llvm::RecordVal *value) {
static std::pair<const Record *, const RecordVal *>
getBaseValue(const Record *record, const RecordVal *value) {
if (value->isTemplateArg())
return {nullptr, nullptr};

// Find a base value for the field in the super classes of the given record.
// On success, `record` is updated to the new parent record.
StringRef valueName = value->getName();
auto findValueInSupers =
[&](const llvm::Record *&record) -> llvm::RecordVal * {
auto findValueInSupers = [&](const Record *&record) -> RecordVal * {
for (auto [parentRecord, loc] : record->getSuperClasses()) {
if (auto *newBase = parentRecord->getValue(valueName)) {
record = parentRecord;
Expand All @@ -120,8 +123,8 @@ getBaseValue(const llvm::Record *record, const llvm::RecordVal *value) {
};

// Try to find the lowest definition of the record value.
std::pair<const llvm::Record *, const llvm::RecordVal *> baseValue = {};
while (const llvm::RecordVal *newBase = findValueInSupers(record))
std::pair<const Record *, const RecordVal *> baseValue = {};
while (const RecordVal *newBase = findValueInSupers(record))
baseValue = {record, newBase};

// Check that the base isn't the same as the current value (e.g. if the value
Expand All @@ -140,15 +143,15 @@ namespace {
/// contains the definition of the symbol, the location of the symbol, and any
/// recorded references.
struct TableGenIndexSymbol {
TableGenIndexSymbol(const llvm::Record *record)
TableGenIndexSymbol(const Record *record)
: definition(record),
defLoc(convertTokenLocToRange(record->getLoc().front())) {}
TableGenIndexSymbol(const llvm::RecordVal *value)
TableGenIndexSymbol(const RecordVal *value)
: definition(value), defLoc(convertTokenLocToRange(value->getLoc())) {}
virtual ~TableGenIndexSymbol() = default;

// The main definition of the symbol.
PointerUnion<const llvm::Record *, const llvm::RecordVal *> definition;
PointerUnion<const Record *, const RecordVal *> definition;

/// The source location of the definition.
SMRange defLoc;
Expand All @@ -158,37 +161,33 @@ struct TableGenIndexSymbol {
};
/// This class represents a single record symbol.
struct TableGenRecordSymbol : public TableGenIndexSymbol {
TableGenRecordSymbol(const llvm::Record *record)
: TableGenIndexSymbol(record) {}
TableGenRecordSymbol(const Record *record) : TableGenIndexSymbol(record) {}
~TableGenRecordSymbol() override = default;

static bool classof(const TableGenIndexSymbol *symbol) {
return symbol->definition.is<const llvm::Record *>();
return symbol->definition.is<const Record *>();
}

/// Return the value of this symbol.
const llvm::Record *getValue() const {
return definition.get<const llvm::Record *>();
}
const Record *getValue() const { return definition.get<const Record *>(); }
};
/// This class represents a single record value symbol.
struct TableGenRecordValSymbol : public TableGenIndexSymbol {
TableGenRecordValSymbol(const llvm::Record *record,
const llvm::RecordVal *value)
TableGenRecordValSymbol(const Record *record, const RecordVal *value)
: TableGenIndexSymbol(value), record(record) {}
~TableGenRecordValSymbol() override = default;

static bool classof(const TableGenIndexSymbol *symbol) {
return symbol->definition.is<const llvm::RecordVal *>();
return symbol->definition.is<const RecordVal *>();
}

/// Return the value of this symbol.
const llvm::RecordVal *getValue() const {
return definition.get<const llvm::RecordVal *>();
const RecordVal *getValue() const {
return definition.get<const RecordVal *>();
}

/// The parent record of this symbol.
const llvm::Record *record;
const Record *record;
};

/// This class provides an index for definitions/uses within a TableGen
Expand All @@ -199,7 +198,7 @@ class TableGenIndex {
TableGenIndex() : intervalMap(allocator) {}

/// Initialize the index with the given RecordKeeper.
void initialize(const llvm::RecordKeeper &records);
void initialize(const RecordKeeper &records);

/// Lookup a symbol for the given location. Returns nullptr if no symbol could
/// be found. If provided, `overlappedRange` is set to the range that the
Expand All @@ -217,15 +216,15 @@ class TableGenIndex {
llvm::IntervalMapHalfOpenInfo<const char *>>;

/// Get or insert a symbol for the given record.
TableGenIndexSymbol *getOrInsertDef(const llvm::Record *record) {
TableGenIndexSymbol *getOrInsertDef(const Record *record) {
auto it = defToSymbol.try_emplace(record, nullptr);
if (it.second)
it.first->second = std::make_unique<TableGenRecordSymbol>(record);
return &*it.first->second;
}
/// Get or insert a symbol for the given record value.
TableGenIndexSymbol *getOrInsertDef(const llvm::Record *record,
const llvm::RecordVal *value) {
TableGenIndexSymbol *getOrInsertDef(const Record *record,
const RecordVal *value) {
auto it = defToSymbol.try_emplace(value, nullptr);
if (it.second) {
it.first->second =
Expand All @@ -246,7 +245,7 @@ class TableGenIndex {
};
} // namespace

void TableGenIndex::initialize(const llvm::RecordKeeper &records) {
void TableGenIndex::initialize(const RecordKeeper &records) {
intervalMap.clear();
defToSymbol.clear();

Expand Down Expand Up @@ -282,7 +281,7 @@ void TableGenIndex::initialize(const llvm::RecordKeeper &records) {
llvm::make_pointee_range(llvm::make_second_range(records.getClasses()));
auto defs =
llvm::make_pointee_range(llvm::make_second_range(records.getDefs()));
for (const llvm::Record &def : llvm::concat<llvm::Record>(classes, defs)) {
for (const Record &def : llvm::concat<Record>(classes, defs)) {
auto *sym = getOrInsertDef(&def);
insertRef(sym, sym->defLoc, /*isDef=*/true);

Expand All @@ -293,7 +292,7 @@ void TableGenIndex::initialize(const llvm::RecordKeeper &records) {
insertRef(sym, loc);

// Add definitions for any values.
for (const llvm::RecordVal &value : def.getValues()) {
for (const RecordVal &value : def.getValues()) {
auto *sym = getOrInsertDef(&def, &value);
insertRef(sym, sym->defLoc, /*isDef=*/true);
for (SMRange refLoc : value.getReferenceLocs())
Expand Down Expand Up @@ -359,13 +358,12 @@ class TableGenTextFile {

std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
const lsp::Position &hoverPos);
lsp::Hover buildHoverForRecord(const llvm::Record *record,
lsp::Hover buildHoverForRecord(const Record *record,
const SMRange &hoverRange);
lsp::Hover buildHoverForTemplateArg(const llvm::Record *record,
const llvm::RecordVal *value,
lsp::Hover buildHoverForTemplateArg(const Record *record,
const RecordVal *value,
const SMRange &hoverRange);
lsp::Hover buildHoverForField(const llvm::Record *record,
const llvm::RecordVal *value,
lsp::Hover buildHoverForField(const Record *record, const RecordVal *value,
const SMRange &hoverRange);

private:
Expand All @@ -383,10 +381,10 @@ class TableGenTextFile {
std::vector<std::string> includeDirs;

/// The source manager containing the contents of the input file.
llvm::SourceMgr sourceMgr;
SourceMgr sourceMgr;

/// The record keeper containing the parsed tablegen constructs.
std::unique_ptr<llvm::RecordKeeper> recordKeeper;
std::unique_ptr<RecordKeeper> recordKeeper;

/// The index of the parsed file.
TableGenIndex index;
Expand Down Expand Up @@ -430,8 +428,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
int64_t newVersion,
std::vector<lsp::Diagnostic> &diagnostics) {
version = newVersion;
sourceMgr = llvm::SourceMgr();
recordKeeper = std::make_unique<llvm::RecordKeeper>();
sourceMgr = SourceMgr();
recordKeeper = std::make_unique<RecordKeeper>();

// Build a buffer for this file.
auto memBuffer = llvm::MemoryBuffer::getMemBuffer(contents, uri.file());
Expand All @@ -442,7 +440,7 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
sourceMgr.setIncludeDirs(includeDirs);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());

// This class provides a context argument for the llvm::SourceMgr diagnostic
// This class provides a context argument for the SourceMgr diagnostic
// handler.
struct DiagHandlerContext {
std::vector<lsp::Diagnostic> &diagnostics;
Expand Down Expand Up @@ -543,13 +541,13 @@ TableGenTextFile::findHover(const lsp::URIForFile &uri,
// Build hover for a RecordVal, which is either a template argument or a
// field.
auto *recordVal = cast<TableGenRecordValSymbol>(symbol);
const llvm::RecordVal *value = recordVal->getValue();
const RecordVal *value = recordVal->getValue();
if (value->isTemplateArg())
return buildHoverForTemplateArg(recordVal->record, value, hoverRange);
return buildHoverForField(recordVal->record, value, hoverRange);
}

lsp::Hover TableGenTextFile::buildHoverForRecord(const llvm::Record *record,
lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record,
const SMRange &hoverRange) {
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
{
Expand All @@ -570,7 +568,7 @@ lsp::Hover TableGenTextFile::buildHoverForRecord(const llvm::Record *record,
auto printAndFormatField = [&](StringRef fieldName) {
// Check that the record actually has the given field, and that it's a
// string.
const llvm::RecordVal *value = record->getValue(fieldName);
const RecordVal *value = record->getValue(fieldName);
if (!value || !value->getValue())
return;
auto *stringValue = dyn_cast<llvm::StringInit>(value->getValue());
Expand All @@ -593,10 +591,8 @@ lsp::Hover TableGenTextFile::buildHoverForRecord(const llvm::Record *record,
return hover;
}

lsp::Hover
TableGenTextFile::buildHoverForTemplateArg(const llvm::Record *record,
const llvm::RecordVal *value,
const SMRange &hoverRange) {
lsp::Hover TableGenTextFile::buildHoverForTemplateArg(
const Record *record, const RecordVal *value, const SMRange &hoverRange) {
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
Expand All @@ -609,8 +605,8 @@ TableGenTextFile::buildHoverForTemplateArg(const llvm::Record *record,
return hover;
}

lsp::Hover TableGenTextFile::buildHoverForField(const llvm::Record *record,
const llvm::RecordVal *value,
lsp::Hover TableGenTextFile::buildHoverForField(const Record *record,
const RecordVal *value,
const SMRange &hoverRange) {
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
{
Expand Down
Loading