Skip to content

Commit 07c157a

Browse files
j2kunftynse
andauthored
[mlir] load dialect in parser for optional parameters (#96667)
#96242 fixed an issue where the auto-generated parsers were not loading dialects whose namespaces are not present in the textual IR. This required the attribute parameter to be a tablegen def with its dialect information attached. This fails when using parameter wrapper classes like `OptionalParameter`. This came up because `RingAttr` uses `OptionalParameter` for its second and third attributes. `OptionalParameter` takes as input the C++ type as a string instead of the tablegen def, and so it doesn't have a dialect member value to trigger the fix from #96242. The docs on this topic say the appropriate solution as overloading `FieldParser` for a particular type. This PR updates `FieldParser` for generic attributes to load the dialect on demand. This requires `mlir-tblgen` to emit a `dialectName` static field on the generated attribute class, and check for it with template metaprogramming, since not all attribute types go through `mlir-tblgen`. --------- Co-authored-by: Jeremy Kun <[email protected]> Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
1 parent c65f8d8 commit 07c157a

File tree

7 files changed

+72
-16
lines changed

7 files changed

+72
-16
lines changed

mlir/include/mlir/IR/DialectImplementation.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,22 @@
1515
#define MLIR_IR_DIALECTIMPLEMENTATION_H
1616

1717
#include "mlir/IR/OpImplementation.h"
18+
#include <type_traits>
19+
20+
namespace {
21+
22+
// reference https://stackoverflow.com/a/16000226
23+
template <typename T, typename = void>
24+
struct HasStaticDialectName : std::false_type {};
25+
26+
template <typename T>
27+
struct HasStaticDialectName<
28+
T, typename std::enable_if<
29+
std::is_same<::llvm::StringLiteral,
30+
std::decay_t<decltype(T::dialectName)>>::value,
31+
void>::type> : std::true_type {};
32+
33+
} // namespace
1834

1935
namespace mlir {
2036

@@ -63,6 +79,9 @@ struct FieldParser<
6379
AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
6480
AttributeT>> {
6581
static FailureOr<AttributeT> parse(AsmParser &parser) {
82+
if constexpr (HasStaticDialectName<AttributeT>::value) {
83+
parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
84+
}
6685
AttributeT value;
6786
if (parser.parseCustomAttributeWithFallback(value))
6887
return failure();
@@ -112,6 +131,9 @@ struct FieldParser<
112131
std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
113132
std::optional<AttributeT>>> {
114133
static FailureOr<std::optional<AttributeT>> parse(AsmParser &parser) {
134+
if constexpr (HasStaticDialectName<AttributeT>::value) {
135+
parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
136+
}
115137
AttributeT attr;
116138
OptionalParseResult result = parser.parseOptionalAttribute(attr);
117139
if (result.has_value()) {

mlir/test/IR/parser.mlir

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,15 +1464,3 @@ test.dialect_custom_format_fallback custom_format_fallback
14641464
// Check that an op with an optional result parses f80 as type.
14651465
// CHECK: test.format_optional_result_d_op : f80
14661466
test.format_optional_result_d_op : f80
1467-
1468-
1469-
// -----
1470-
1471-
// This is a testing that a non-qualified attribute in a custom format
1472-
// correctly preload the dialect before creating the attribute.
1473-
#attr = #test.nested_polynomial<<1 + x**2>>
1474-
// CHECK-lABLE: @parse_correctly
1475-
llvm.func @parse_correctly() {
1476-
test.containing_int_polynomial_attr #attr
1477-
llvm.return
1478-
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect --split-input-file %s | FileCheck %s
2+
3+
// This is a testing that a non-qualified attribute in a custom format
4+
// correctly preload the dialect before creating the attribute.
5+
#attr = #test.nested_polynomial<poly=<1 + x**2>>
6+
// CHECK-LABEL: @parse_correctly
7+
llvm.func @parse_correctly() {
8+
test.containing_int_polynomial_attr #attr
9+
llvm.return
10+
}
11+
12+
// -----
13+
14+
#attr2 = #test.nested_polynomial2<poly=<1 + x**2>>
15+
// CHECK-LABEL: @parse_correctly_2
16+
llvm.func @parse_correctly_2() {
17+
test.containing_int_polynomial_attr2 #attr2
18+
llvm.return
19+
}

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,17 @@ def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
356356
let mnemonic = "nested_polynomial";
357357
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
358358
let assemblyFormat = [{
359-
`<` $poly `>`
359+
`<` struct(params) `>`
360360
}];
361361
}
362362

363+
def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
364+
let mnemonic = "nested_polynomial2";
365+
let parameters = (ins OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">:$poly);
366+
let assemblyFormat = [{
367+
`<` struct(params) `>`
368+
}];
369+
}
370+
371+
363372
#endif // TEST_ATTRDEFS

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ def ContainingIntPolynomialAttrOp : TEST_Op<"containing_int_polynomial_attr"> {
237237
let assemblyFormat = "$attr attr-dict";
238238
}
239239

240+
def ContainingIntPolynomialAttr2Op : TEST_Op<"containing_int_polynomial_attr2"> {
241+
let arguments = (ins NestedPolynomialAttr2:$attr);
242+
let assemblyFormat = "$attr attr-dict";
243+
}
244+
240245
// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
241246
// This tests both matching and generating float elements attributes.
242247
def UpdateFloatElementsAttr : Pat<

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class DefGen {
8989
void emitTopLevelDeclarations();
9090
/// Emit the function that returns the type or attribute name.
9191
void emitName();
92+
/// Emit the dialect name as a static member variable.
93+
void emitDialectName();
9294
/// Emit attribute or type builders.
9395
void emitBuilders();
9496
/// Emit a verifier for the def.
@@ -184,6 +186,8 @@ DefGen::DefGen(const AttrOrTypeDef &def)
184186
emitBuilders();
185187
// Emit the type name.
186188
emitName();
189+
// Emit the dialect name.
190+
emitDialectName();
187191
// Emit the verifier.
188192
if (storageCls && def.genVerifyDecl())
189193
emitVerifier();
@@ -281,6 +285,13 @@ void DefGen::emitName() {
281285
defCls.declare<ExtraClassDeclaration>(std::move(nameDecl));
282286
}
283287

288+
void DefGen::emitDialectName() {
289+
std::string decl =
290+
strfmt("static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
291+
def.getDialect().getName());
292+
defCls.declare<ExtraClassDeclaration>(std::move(decl));
293+
}
294+
284295
void DefGen::emitBuilders() {
285296
if (!def.skipDefaultBuilders()) {
286297
emitDefaultBuilder();

mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,9 +423,11 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
423423
Dialect dialect(dialectInit->getDef());
424424
auto cppNamespace = dialect.getCppNamespace();
425425
std::string name = dialect.getCppClassName();
426-
dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
427-
cppNamespace + "::" + name + ">();")
428-
.str();
426+
if (name != "BuiltinDialect" || cppNamespace != "::mlir") {
427+
dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
428+
cppNamespace + "::" + name + ">();")
429+
.str();
430+
}
429431
}
430432
}
431433
}

0 commit comments

Comments
 (0)