Skip to content

Commit 3cbfc9d

Browse files
authored
[mlir][ODS][NFC] Deduplicate ref and qualified handling (llvm#91080)
Both the attribute and type format generator and the op format generator independently implemented the parsing and verification of the `ref` and `qualified` directives with little to no differences. This PR moves the implementation of these into the common `FormatParser` class to deduplicate the implementations.
1 parent b54a78d commit 3cbfc9d

File tree

5 files changed

+61
-79
lines changed

5 files changed

+61
-79
lines changed

mlir/test/mlir-tblgen/attr-or-type-format-invalid.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
111111

112112
def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> {
113113
let parameters = (ins "int":$a);
114-
// CHECK: `ref` is only allowed inside custom directives
114+
// CHECK: 'ref' is only valid within a `custom` directive
115115
let assemblyFormat = "$a ref($a)";
116116
}
117117

mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Lines changed: 10 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,8 @@ class DefFormatParser : public FormatParser {
940940
ArrayRef<FormatElement *> elements,
941941
FormatElement *anchor) override;
942942

943+
LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
944+
943945
/// Parse an attribute or type variable.
944946
FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
945947
Context ctx) override;
@@ -950,12 +952,8 @@ class DefFormatParser : public FormatParser {
950952
private:
951953
/// Parse a `params` directive.
952954
FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
953-
/// Parse a `qualified` directive.
954-
FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
955955
/// Parse a `struct` directive.
956956
FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
957-
/// Parse a `ref` directive.
958-
FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
959957

960958
/// Attribute or type tablegen def.
961959
const AttrOrTypeDef &def;
@@ -1060,6 +1058,14 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
10601058
return success();
10611059
}
10621060

1061+
LogicalResult DefFormatParser::markQualified(SMLoc loc,
1062+
FormatElement *element) {
1063+
if (!isa<ParameterElement>(element))
1064+
return emitError(loc, "`qualified` argument list expected a variable");
1065+
cast<ParameterElement>(element)->setShouldBeQualified();
1066+
return success();
1067+
}
1068+
10631069
FailureOr<DefFormat> DefFormatParser::parse() {
10641070
FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
10651071
if (failed(elements))
@@ -1107,33 +1113,11 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
11071113
return parseParamsDirective(loc, ctx);
11081114
case FormatToken::kw_struct:
11091115
return parseStructDirective(loc, ctx);
1110-
case FormatToken::kw_ref:
1111-
return parseRefDirective(loc, ctx);
1112-
case FormatToken::kw_custom:
1113-
return parseCustomDirective(loc, ctx);
1114-
11151116
default:
11161117
return emitError(loc, "unsupported directive kind");
11171118
}
11181119
}
11191120

1120-
FailureOr<FormatElement *>
1121-
DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
1122-
if (failed(parseToken(FormatToken::l_paren,
1123-
"expected '(' before argument list")))
1124-
return failure();
1125-
FailureOr<FormatElement *> var = parseElement(ctx);
1126-
if (failed(var))
1127-
return var;
1128-
if (!isa<ParameterElement>(*var))
1129-
return emitError(loc, "`qualified` argument list expected a variable");
1130-
cast<ParameterElement>(*var)->setShouldBeQualified();
1131-
if (failed(
1132-
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
1133-
return failure();
1134-
return var;
1135-
}
1136-
11371121
FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
11381122
Context ctx) {
11391123
// It doesn't make sense to allow references to all parameters in a custom
@@ -1201,22 +1185,6 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
12011185
return create<StructDirective>(std::move(vars));
12021186
}
12031187

1204-
FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
1205-
Context ctx) {
1206-
if (ctx != CustomDirectiveContext)
1207-
return emitError(loc, "`ref` is only allowed inside custom directives");
1208-
1209-
// Parse the child parameter element.
1210-
FailureOr<FormatElement *> child;
1211-
if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
1212-
failed(child = parseElement(RefDirectiveContext)) ||
1213-
failed(parseToken(FormatToken::r_paren, "expeced ')'")))
1214-
return failure();
1215-
1216-
// Only parameter elements are allowed to be parsed under a `ref` directive.
1217-
return create<RefDirective>(*child);
1218-
}
1219-
12201188
//===----------------------------------------------------------------------===//
12211189
// Interface
12221190
//===----------------------------------------------------------------------===//

mlir/tools/mlir-tblgen/FormatGen.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
308308

309309
if (tok.is(FormatToken::kw_custom))
310310
return parseCustomDirective(loc, ctx);
311+
if (tok.is(FormatToken::kw_ref))
312+
return parseRefDirective(loc, ctx);
313+
if (tok.is(FormatToken::kw_qualified))
314+
return parseQualifiedDirective(loc, ctx);
311315
return parseDirectiveImpl(loc, tok.getKind(), ctx);
312316
}
313317

@@ -430,6 +434,38 @@ FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
430434
return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
431435
}
432436

437+
FailureOr<FormatElement *> FormatParser::parseRefDirective(SMLoc loc,
438+
Context context) {
439+
if (context != CustomDirectiveContext)
440+
return emitError(loc, "'ref' is only valid within a `custom` directive");
441+
442+
FailureOr<FormatElement *> arg;
443+
if (failed(parseToken(FormatToken::l_paren,
444+
"expected '(' before argument list")) ||
445+
failed(arg = parseElement(RefDirectiveContext)) ||
446+
failed(
447+
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
448+
return failure();
449+
450+
return create<RefDirective>(*arg);
451+
}
452+
453+
FailureOr<FormatElement *> FormatParser::parseQualifiedDirective(SMLoc loc,
454+
Context ctx) {
455+
if (failed(parseToken(FormatToken::l_paren,
456+
"expected '(' before argument list")))
457+
return failure();
458+
FailureOr<FormatElement *> var = parseElement(ctx);
459+
if (failed(var))
460+
return var;
461+
if (failed(markQualified(loc, *var)))
462+
return failure();
463+
if (failed(
464+
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
465+
return failure();
466+
return var;
467+
}
468+
433469
//===----------------------------------------------------------------------===//
434470
// Utility Functions
435471
//===----------------------------------------------------------------------===//

mlir/tools/mlir-tblgen/FormatGen.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,9 +495,12 @@ class FormatParser {
495495
FailureOr<FormatElement *> parseDirective(Context ctx);
496496
/// Parse an optional group.
497497
FailureOr<FormatElement *> parseOptionalGroup(Context ctx);
498-
499498
/// Parse a custom directive.
500499
FailureOr<FormatElement *> parseCustomDirective(llvm::SMLoc loc, Context ctx);
500+
/// Parse a ref directive.
501+
FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context context);
502+
/// Parse a qualified directive.
503+
FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
501504

502505
/// Parse a format-specific variable kind.
503506
virtual FailureOr<FormatElement *>
@@ -522,6 +525,11 @@ class FormatParser {
522525
ArrayRef<FormatElement *> elements,
523526
FormatElement *anchor) = 0;
524527

528+
/// Mark 'element' as qualified. If 'element' cannot be qualified an error
529+
/// should be emitted and failure returned.
530+
virtual LogicalResult markQualified(llvm::SMLoc loc,
531+
FormatElement *element) = 0;
532+
525533
//===--------------------------------------------------------------------===//
526534
// Lexer Utilities
527535

mlir/tools/mlir-tblgen/OpFormatGen.cpp

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2547,6 +2547,8 @@ class OpFormatParser : public FormatParser {
25472547
LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
25482548
bool isAnchor);
25492549

2550+
LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
2551+
25502552
/// Parse an operation variable.
25512553
FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
25522554
Context ctx) override;
@@ -2622,10 +2624,6 @@ class OpFormatParser : public FormatParser {
26222624
FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
26232625
LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
26242626
FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
2625-
FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
2626-
Context context);
2627-
FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
2628-
Context context);
26292627
FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
26302628
FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
26312629
FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
@@ -3224,16 +3222,12 @@ OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
32243222
return parseFunctionalTypeDirective(loc, ctx);
32253223
case FormatToken::kw_operands:
32263224
return parseOperandsDirective(loc, ctx);
3227-
case FormatToken::kw_qualified:
3228-
return parseQualifiedDirective(loc, ctx);
32293225
case FormatToken::kw_regions:
32303226
return parseRegionsDirective(loc, ctx);
32313227
case FormatToken::kw_results:
32323228
return parseResultsDirective(loc, ctx);
32333229
case FormatToken::kw_successors:
32343230
return parseSuccessorsDirective(loc, ctx);
3235-
case FormatToken::kw_ref:
3236-
return parseReferenceDirective(loc, ctx);
32373231
case FormatToken::kw_type:
32383232
return parseTypeDirective(loc, ctx);
32393233
case FormatToken::kw_oilist:
@@ -3338,22 +3332,6 @@ OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
33383332
return create<OperandsDirective>();
33393333
}
33403334

3341-
FailureOr<FormatElement *>
3342-
OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
3343-
if (context != CustomDirectiveContext)
3344-
return emitError(loc, "'ref' is only valid within a `custom` directive");
3345-
3346-
FailureOr<FormatElement *> arg;
3347-
if (failed(parseToken(FormatToken::l_paren,
3348-
"expected '(' before argument list")) ||
3349-
failed(arg = parseElement(RefDirectiveContext)) ||
3350-
failed(
3351-
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3352-
return failure();
3353-
3354-
return create<RefDirective>(*arg);
3355-
}
3356-
33573335
FailureOr<FormatElement *>
33583336
OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
33593337
if (context == TypeDirectiveContext)
@@ -3495,19 +3473,11 @@ FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
34953473
return create<TypeDirective>(*operand);
34963474
}
34973475

3498-
FailureOr<FormatElement *>
3499-
OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
3500-
FailureOr<FormatElement *> element;
3501-
if (failed(parseToken(FormatToken::l_paren,
3502-
"expected '(' before argument list")) ||
3503-
failed(element = parseElement(context)) ||
3504-
failed(
3505-
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3506-
return failure();
3507-
return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
3476+
LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) {
3477+
return TypeSwitch<FormatElement *, LogicalResult>(element)
35083478
.Case<AttributeVariable, TypeDirective>([](auto *element) {
35093479
element->setShouldBeQualified();
3510-
return element;
3480+
return success();
35113481
})
35123482
.Default([&](auto *element) {
35133483
return this->emitError(

0 commit comments

Comments
 (0)