Skip to content

Reformat whitespace in dependent dialects codegen #78090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions mlir/tools/mlir-tblgen/DialectGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,8 @@ class {0} : public ::mlir::{2} {

/// Registration for a single dependent dialect: to be inserted in the ctor
/// above for each dependent dialect.
const char *const dialectRegistrationTemplate = R"(
getContext()->loadDialect<{0}>();
)";
const char *const dialectRegistrationTemplate =
"getContext()->loadDialect<{0}>();";

/// The code block for the attribute parser/printer hooks.
static const char *const attrParserDecl = R"(
Expand Down Expand Up @@ -250,8 +249,8 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
/// The code block to generate a dialect constructor definition.
///
/// {0}: The name of the dialect class.
/// {1}: initialization code that is emitted in the ctor body before calling
/// initialize().
/// {1}: Initialization code that is emitted in the ctor body before calling
/// initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
Expand All @@ -261,7 +260,7 @@ static const char *const dialectConstructorStr = R"(
}
)";

/// The code block to generate a default desturctor definition.
/// The code block to generate a default destructor definition.
///
/// {0}: The name of the dialect class.
static const char *const dialectDestructorStr = R"(
Expand All @@ -284,9 +283,13 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
for (StringRef dependentDialect : dialect.getDependentDialects())
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
llvm::interleave(
dialect.getDependentDialects(), dialectsOs,
[&](StringRef dependentDialect) {
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
},
"\n ");
}

// Emit the constructor and destructor.
Expand Down
29 changes: 18 additions & 11 deletions mlir/tools/mlir-tblgen/PassGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2): The command line argument for the pass.
/// {3}: The dependent dialects registration.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
Expand Down Expand Up @@ -221,9 +222,7 @@ class {0}Base : public {1} {

/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the `getDependentDialects` above.
const char *const dialectRegistrationTemplate = R"(
registry.insert<{0}>();
)";
const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";

const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
Expand Down Expand Up @@ -307,9 +306,13 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
for (StringRef dependentDialect : pass.getDependentDialects())
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
llvm::interleave(
pass.getDependentDialects(), dialectsOs,
[&](StringRef dependentDialect) {
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
},
"\n ");
}

os << "namespace impl {\n";
Expand Down Expand Up @@ -402,7 +405,7 @@ class {0}Base : public {1} {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}

/// Return the dialect that must be loaded in the context before this pass.
/// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{4}
}
Expand All @@ -422,9 +425,13 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
for (StringRef dependentDialect : pass.getDependentDialects())
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
llvm::interleave(
pass.getDependentDialects(), dialectsOs,
[&](StringRef dependentDialect) {
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
},
"\n ");
}
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
Expand Down