Skip to content

Commit 3dff20c

Browse files
author
mlevesquedion
authored
[mlir] Reformat whitespace in dependent dialects codegen (#78090)
The generated code for dependent dialects is awkwardly formatted, making the code harder to read. This change reformats the whitespace to align code in its context and avoid unnecessary empty lines. Also included are some typo fixes. Below are examples of the codegen for a dialect before and after the change. Before: ``` GPUDialect::GPUDialect(::mlir::MLIRContext *context) : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) { getContext()->loadDialect<arith::ArithDialect>(); initialize(); } ``` After: ``` GPUDialect::GPUDialect(::mlir::MLIRContext *context) : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) { getContext()->loadDialect<arith::ArithDialect>(); initialize(); } ``` Below are examples of the codegen for a pass before and after the change. Before: ``` /// Return the dialect that must be loaded in the context before this pass. void getDependentDialects(::mlir::DialectRegistry &registry) const override { registry.insert<func::FuncDialect>(); registry.insert<tensor::TensorDialect>(); registry.insert<tosa::TosaDialect>(); } ``` After: ``` /// Register the dialects that must be loaded in the context before this pass. void getDependentDialects(::mlir::DialectRegistry &registry) const override { registry.insert<func::FuncDialect>(); registry.insert<tensor::TensorDialect>(); registry.insert<tosa::TosaDialect>(); } ```
1 parent b61e5b0 commit 3dff20c

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

mlir/tools/mlir-tblgen/DialectGen.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,8 @@ class {0} : public ::mlir::{2} {
106106

107107
/// Registration for a single dependent dialect: to be inserted in the ctor
108108
/// above for each dependent dialect.
109-
const char *const dialectRegistrationTemplate = R"(
110-
getContext()->loadDialect<{0}>();
111-
)";
109+
const char *const dialectRegistrationTemplate =
110+
"getContext()->loadDialect<{0}>();";
112111

113112
/// The code block for the attribute parser/printer hooks.
114113
static const char *const attrParserDecl = R"(
@@ -250,8 +249,8 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
250249
/// The code block to generate a dialect constructor definition.
251250
///
252251
/// {0}: The name of the dialect class.
253-
/// {1}: initialization code that is emitted in the ctor body before calling
254-
/// initialize().
252+
/// {1}: Initialization code that is emitted in the ctor body before calling
253+
/// initialize(), such as dependent dialect registration.
255254
/// {2}: The dialect parent class.
256255
static const char *const dialectConstructorStr = R"(
257256
{0}::{0}(::mlir::MLIRContext *context)
@@ -261,7 +260,7 @@ static const char *const dialectConstructorStr = R"(
261260
}
262261
)";
263262

264-
/// The code block to generate a default desturctor definition.
263+
/// The code block to generate a default destructor definition.
265264
///
266265
/// {0}: The name of the dialect class.
267266
static const char *const dialectDestructorStr = R"(
@@ -284,9 +283,13 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
284283
std::string dependentDialectRegistrations;
285284
{
286285
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
287-
for (StringRef dependentDialect : dialect.getDependentDialects())
288-
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
289-
dependentDialect);
286+
llvm::interleave(
287+
dialect.getDependentDialects(), dialectsOs,
288+
[&](StringRef dependentDialect) {
289+
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
290+
dependentDialect);
291+
},
292+
"\n ");
290293
}
291294

292295
// Emit the constructor and destructor.

mlir/tools/mlir-tblgen/PassGen.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
173173
/// {0}: The def name of the pass record.
174174
/// {1}: The base class for the pass.
175175
/// {2): The command line argument for the pass.
176-
/// {3}: The dependent dialects registration.
176+
/// {3}: The summary for the pass.
177+
/// {4}: The dependent dialects registration.
177178
const char *const baseClassBegin = R"(
178179
template <typename DerivedT>
179180
class {0}Base : public {1} {
@@ -221,9 +222,7 @@ class {0}Base : public {1} {
221222

222223
/// Registration for a single dependent dialect, to be inserted for each
223224
/// dependent dialect in the `getDependentDialects` above.
224-
const char *const dialectRegistrationTemplate = R"(
225-
registry.insert<{0}>();
226-
)";
225+
const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
227226

228227
const char *const friendDefaultConstructorDeclTemplate = R"(
229228
namespace impl {{
@@ -307,9 +306,13 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
307306
std::string dependentDialectRegistrations;
308307
{
309308
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
310-
for (StringRef dependentDialect : pass.getDependentDialects())
311-
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
312-
dependentDialect);
309+
llvm::interleave(
310+
pass.getDependentDialects(), dialectsOs,
311+
[&](StringRef dependentDialect) {
312+
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
313+
dependentDialect);
314+
},
315+
"\n ");
313316
}
314317

315318
os << "namespace impl {\n";
@@ -402,7 +405,7 @@ class {0}Base : public {1} {
402405
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
403406
}
404407
405-
/// Return the dialect that must be loaded in the context before this pass.
408+
/// Register the dialects that must be loaded in the context before this pass.
406409
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
407410
{4}
408411
}
@@ -422,9 +425,13 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
422425
std::string dependentDialectRegistrations;
423426
{
424427
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
425-
for (StringRef dependentDialect : pass.getDependentDialects())
426-
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
427-
dependentDialect);
428+
llvm::interleave(
429+
pass.getDependentDialects(), dialectsOs,
430+
[&](StringRef dependentDialect) {
431+
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
432+
dependentDialect);
433+
},
434+
"\n ");
428435
}
429436
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
430437
pass.getArgument(), pass.getSummary(),

0 commit comments

Comments
 (0)