Skip to content

Commit 4957518

Browse files
committed
[mlir][ods] Simplify useDefaultType/AttributePrinterParser
The current behaviour of `useDefaultTypePrinterParser` and `useDefaultAttributePrinterParser` is that they are set by default, but the dialect generator only generates the declarations for the parsing and printing hooks if it sees dialect types and attributes. Same goes for the definitions generated by the AttrOrTypeDef generator. This can lead to confusing and undesirable behaviour if the dialect generator doesn't see the definitions of the attributes and types, for example, if they are sensibly separated into different files: `Dialect.td`, `Ops.td`, `Attributes.td`, and `Types.td`. Now, these bits are unset by default. Setting them will always result in the dialect generator emitting the declarations for the parsing hooks. And if the AttrOrTypeDef generator sees it set, it will generate the default implementations. Reviewed By: rriddle, stellaraccident Differential Revision: https://reviews.llvm.org/D125809
1 parent 91a8caa commit 4957518

File tree

19 files changed

+64
-51
lines changed

19 files changed

+64
-51
lines changed

mlir/examples/toy/Ch7/include/toy/Ops.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def Toy_Dialect : Dialect {
3030
// We set this bit to generate a declaration of the `materializeConstant`
3131
// method so that we can materialize constants for our toy operations.
3232
let hasConstantMaterializer = 1;
33+
34+
// We set this bit to generate the declarations for the dialect's type parsing
35+
// and printing hooks.
36+
let useDefaultTypePrinterParser = 1;
3337
}
3438

3539
// Base class for toy dialect operations. This operation inherits from the base

mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,22 @@ include "mlir/IR/OpBase.td"
2121

2222
def AsyncDialect : Dialect {
2323
let name = "async";
24+
let cppNamespace = "::mlir::async";
2425

2526
let summary = "Types and operations for async dialect";
2627
let description = [{
2728
This dialect contains operations for modeling asynchronous execution.
2829
}];
2930

30-
let cppNamespace = "::mlir::async";
31+
let useDefaultTypePrinterParser = 1;
3132

3233
let extraClassDeclaration = [{
33-
// The name of a unit attribute on funcs that are allowed to have a blocking
34-
// async.runtime.await ops. Only useful in combination with
35-
// 'eliminate-blocking-await-ops' option, which in absence of this attribute
36-
// might convert a func to a coroutine.
37-
static constexpr StringRef kAllowedToBlockAttrName = "async.allowed_to_block";
34+
/// The name of a unit attribute on funcs that are allowed to have a
35+
/// blocking async.runtime.await ops. Only useful in combination with
36+
/// 'eliminate-blocking-await-ops' option, which in absence of this
37+
/// attribute might convert a func to a coroutine.
38+
static constexpr StringRef kAllowedToBlockAttrName =
39+
"async.allowed_to_block";
3840
}];
3941

4042
}

mlir/include/mlir/Dialect/DLTI/DLTIBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def DLTI_Dialect : Dialect {
3535
constexpr const static ::llvm::StringLiteral
3636
kDataLayoutEndiannessLittle = "little";
3737
}];
38+
39+
let useDefaultAttributePrinterParser = 1;
3840
}
3941

4042
def DLTI_DataLayoutEntryAttr : DialectAttr<

mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def EmitC_Dialect : Dialect {
3030

3131
let hasConstantMaterializer = 1;
3232
let useDefaultTypePrinterParser = 1;
33+
let useDefaultAttributePrinterParser = 1;
3334
}
3435

3536
#endif // MLIR_DIALECT_EMITC_IR_EMITCBASE

mlir/include/mlir/Dialect/GPU/GPUBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def GPU_Dialect : Dialect {
5454

5555
let dependentDialects = ["arith::ArithmeticDialect"];
5656
let useDefaultAttributePrinterParser = 1;
57+
let useDefaultTypePrinterParser = 1;
5758
}
5859

5960
def GPU_AsyncToken : DialectType<

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@ def LLVM_Dialect : Dialect {
2626
let name = "llvm";
2727
let cppNamespace = "::mlir::LLVM";
2828

29+
let useDefaultTypePrinterParser = 1;
30+
let useDefaultAttributePrinterParser = 1;
2931
let hasRegionArgAttrVerify = 1;
3032
let hasRegionResultAttrVerify = 1;
3133
let hasOperationAttrVerify = 1;
34+
3235
let extraClassDeclaration = [{
3336
/// Name of the data layout attributes.
3437
static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; }

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def Linalg_Dialect : Dialect {
4040
"memref::MemRefDialect",
4141
"tensor::TensorDialect",
4242
];
43+
let useDefaultAttributePrinterParser = 1;
4344
let hasCanonicalizer = 1;
4445
let hasOperationAttrVerify = 1;
4546
let hasConstantMaterializer = 1;

mlir/include/mlir/Dialect/NVGPU/NVGPU.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def NVGPU_Dialect : Dialect {
3232
representing PTX specific operations while using MLIR high level concepts
3333
like memref and 2-D vector.
3434
}];
35-
let useDefaultAttributePrinterParser = 1;
35+
36+
let useDefaultTypePrinterParser = 1;
3637
}
3738

3839
/// Device-side synchronization token.

mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ include "mlir/IR/OpBase.td"
2020
//===----------------------------------------------------------------------===//
2121

2222
def PDL_Dialect : Dialect {
23-
string summary = "High level pattern definition dialect";
24-
string description = [{
23+
let summary = "High level pattern definition dialect";
24+
let description = [{
2525
PDL presents a high level abstraction for the rewrite pattern infrastructure
2626
available in MLIR. This abstraction allows for representing patterns
2727
transforming MLIR, as MLIR. This allows for applying all of the benefits
@@ -64,6 +64,8 @@ def PDL_Dialect : Dialect {
6464

6565
let name = "pdl";
6666
let cppNamespace = "::mlir::pdl";
67+
68+
let useDefaultTypePrinterParser = 1;
6769
let extraClassDeclaration = [{
6870
void registerTypes();
6971
}];

mlir/include/mlir/Dialect/Quant/QuantOpsBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ include "mlir/IR/OpBase.td"
1818
def Quantization_Dialect : Dialect {
1919
let name = "quant";
2020
let cppNamespace = "::mlir::quant";
21+
22+
let useDefaultTypePrinterParser = 1;
2123
}
2224

2325
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def SPIRV_Dialect : Dialect {
4747
}];
4848

4949
let cppNamespace = "::mlir::spirv";
50+
let useDefaultTypePrinterParser = 1;
51+
let useDefaultAttributePrinterParser = 1;
5052
let hasConstantMaterializer = 1;
5153
let hasOperationAttrVerify = 1;
5254
let hasRegionArgAttrVerify = 1;

mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def ShapeDialect : Dialect {
3737
let cppNamespace = "::mlir::shape";
3838
let dependentDialects = ["arith::ArithmeticDialect", "tensor::TensorDialect"];
3939

40+
let useDefaultTypePrinterParser = 1;
4041
let hasConstantMaterializer = 1;
4142
let hasOperationAttrVerify = 1;
4243
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- Shape.td - Shape operations definition --------------*- tablegen -*-===//
1+
//===- ShapeOps.td - Shape operations definition -----------*- tablegen -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def SparseTensor_Dialect : Dialect {
7272
* [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
7373
PhD thesis, MIT, February, 2020.
7474
}];
75+
76+
let useDefaultAttributePrinterParser = 1;
7577
}
7678

7779
#endif // SPARSETENSOR_BASE

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ include "mlir/Interfaces/ViewLikeInterface.td"
2323
def Vector_Dialect : Dialect {
2424
let name = "vector";
2525
let cppNamespace = "::mlir::vector";
26+
27+
let useDefaultAttributePrinterParser = 1;
2628
let hasConstantMaterializer = 1;
2729
let dependentDialects = ["arith::ArithmeticDialect"];
2830
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;

mlir/include/mlir/IR/DialectBase.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,17 @@ class Dialect {
7373
// If this dialect overrides the hook for op interface fallback.
7474
bit hasOperationInterfaceFallback = 0;
7575

76-
// If this dialect should use default generated attribute parser boilerplate:
77-
// it'll dispatch the parsing to every individual attributes directly.
78-
bit useDefaultAttributePrinterParser = 1;
76+
// If this dialect should use default generated attribute parser boilerplate.
77+
// When set, ODS will generate declarations for the attribute parsing and
78+
// printing hooks in the dialect and default implementations that dispatch to
79+
// each individual attribute directly.
80+
bit useDefaultAttributePrinterParser = 0;
7981

8082
// If this dialect should use default generated type parser boilerplate:
81-
// it'll dispatch the parsing to every individual types directly.
82-
bit useDefaultTypePrinterParser = 1;
83+
// When set, ODS will generate declarations for the type parsing and printing
84+
// hooks in the dialect and default implementations that dispatch to each
85+
// individual type directly.
86+
bit useDefaultTypePrinterParser = 0;
8387

8488
// If this dialect overrides the hook for canonicalization patterns.
8589
bit hasCanonicalizer = 0;

mlir/test/python/python_test_ops.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
1717
def Python_Test_Dialect : Dialect {
1818
let name = "python_test";
1919
let cppNamespace = "python_test";
20+
21+
let useDefaultTypePrinterParser = 1;
22+
let useDefaultAttributePrinterParser = 1;
2023
}
2124

2225
class TestType<string name, string typeMnemonic>

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -574,11 +574,9 @@ class DefGenerator {
574574

575575
protected:
576576
DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
577-
StringRef defType, StringRef valueType, bool isAttrGenerator,
578-
bool needsDialectParserPrinter)
577+
StringRef defType, StringRef valueType, bool isAttrGenerator)
579578
: defRecords(std::move(defs)), os(os), defType(defType),
580-
valueType(valueType), isAttrGenerator(isAttrGenerator),
581-
needsDialectParserPrinter(needsDialectParserPrinter) {}
579+
valueType(valueType), isAttrGenerator(isAttrGenerator) {}
582580

583581
/// Emit the list of def type names.
584582
void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
@@ -597,30 +595,19 @@ class DefGenerator {
597595
/// Flag indicating if this generator is for Attributes. False if the
598596
/// generator is for types.
599597
bool isAttrGenerator;
600-
/// Track if we need to emit the printAttribute/parseAttribute
601-
/// implementations.
602-
bool needsDialectParserPrinter;
603598
};
604599

605600
/// A specialized generator for AttrDefs.
606601
struct AttrDefGenerator : public DefGenerator {
607602
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
608603
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
609-
"Attr", "Attribute",
610-
/*isAttrGenerator=*/true,
611-
/*needsDialectParserPrinter=*/
612-
!records.getAllDerivedDefinitions("DialectAttr").empty()) {
613-
}
604+
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
614605
};
615606
/// A specialized generator for TypeDefs.
616607
struct TypeDefGenerator : public DefGenerator {
617608
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
618609
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
619-
"Type", "Type",
620-
/*isAttrGenerator=*/false,
621-
/*needsDialectParserPrinter=*/
622-
!records.getAllDerivedDefinitions("DialectType").empty()) {
623-
}
610+
"Type", "Type", /*isAttrGenerator=*/false) {}
624611
};
625612
} // namespace
626613

@@ -879,10 +866,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
879866
}
880867

881868
Dialect firstDialect = defs.front().getDialect();
882-
// Emit the default parser/printer for Attributes if the dialect asked for
883-
// it.
884-
if (valueType == "Attribute" && needsDialectParserPrinter &&
885-
firstDialect.useDefaultAttributePrinterParser()) {
869+
870+
// Emit the default parser/printer for Attributes if the dialect asked for it.
871+
if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) {
886872
NamespaceEmitter nsEmitter(os, firstDialect);
887873
if (firstDialect.isExtensible()) {
888874
os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
@@ -896,8 +882,7 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
896882
}
897883

898884
// Emit the default parser/printer for Types if the dialect asked for it.
899-
if (valueType == "Type" && needsDialectParserPrinter &&
900-
firstDialect.useDefaultTypePrinterParser()) {
885+
if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) {
901886
NamespaceEmitter nsEmitter(os, firstDialect);
902887
if (firstDialect.isExtensible()) {
903888
os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,

mlir/tools/mlir-tblgen/DialectGen.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,7 @@ static const char *const operationInterfaceFallbackDecl = R"(
182182
)";
183183

184184
/// Generate the declaration for the given dialect class.
185-
static void
186-
emitDialectDecl(Dialect &dialect,
187-
const iterator_range<DialectFilterIterator> &dialectAttrs,
188-
const iterator_range<DialectFilterIterator> &dialectTypes,
189-
raw_ostream &os) {
185+
static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
190186
// Emit all nested namespaces.
191187
{
192188
NamespaceEmitter nsEmitter(os, dialect);
@@ -198,11 +194,13 @@ emitDialectDecl(Dialect &dialect,
198194
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
199195
superClassName);
200196

201-
// Check for any attributes/types registered to this dialect. If there are,
202-
// add the hooks for parsing/printing.
203-
if (!dialectAttrs.empty() && dialect.useDefaultAttributePrinterParser())
197+
// If the dialect requested the default attribute printer and parser, emit
198+
// the declarations for the hooks.
199+
if (dialect.useDefaultAttributePrinterParser())
204200
os << attrParserDecl;
205-
if (!dialectTypes.empty() && dialect.useDefaultTypePrinterParser())
201+
// If the dialect requested the default type printer and parser, emit the
202+
// delcarations for the hooks.
203+
if (dialect.useDefaultTypePrinterParser())
206204
os << typeParserDecl;
207205

208206
// Add the decls for the various features of the dialect.
@@ -242,10 +240,7 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
242240
Optional<Dialect> dialect = findDialectToGenerate(dialects);
243241
if (!dialect)
244242
return true;
245-
auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr");
246-
auto typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
247-
emitDialectDecl(*dialect, filterForDialect<Attribute>(attrDefs, *dialect),
248-
filterForDialect<Type>(typeDefs, *dialect), os);
243+
emitDialectDecl(*dialect, os);
249244
return false;
250245
}
251246

0 commit comments

Comments
 (0)