Skip to content

[mlir][ODS][NFC] Deduplicate ref and qualified handling #91080

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {

def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> {
let parameters = (ins "int":$a);
// CHECK: `ref` is only allowed inside custom directives
// CHECK: 'ref' is only valid within a `custom` directive
let assemblyFormat = "$a ref($a)";
}

Expand Down
52 changes: 10 additions & 42 deletions mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,8 @@ class DefFormatParser : public FormatParser {
ArrayRef<FormatElement *> elements,
FormatElement *anchor) override;

LogicalResult markQualified(SMLoc loc, FormatElement *element) override;

/// Parse an attribute or type variable.
FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
Context ctx) override;
Expand All @@ -950,12 +952,8 @@ class DefFormatParser : public FormatParser {
private:
/// Parse a `params` directive.
FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
/// Parse a `qualified` directive.
FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
/// Parse a `struct` directive.
FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
/// Parse a `ref` directive.
FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);

/// Attribute or type tablegen def.
const AttrOrTypeDef &def;
Expand Down Expand Up @@ -1060,6 +1058,14 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
return success();
}

LogicalResult DefFormatParser::markQualified(SMLoc loc,
FormatElement *element) {
if (!isa<ParameterElement>(element))
return emitError(loc, "`qualified` argument list expected a variable");
cast<ParameterElement>(element)->setShouldBeQualified();
return success();
}

FailureOr<DefFormat> DefFormatParser::parse() {
FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
if (failed(elements))
Expand Down Expand Up @@ -1107,33 +1113,11 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
return parseParamsDirective(loc, ctx);
case FormatToken::kw_struct:
return parseStructDirective(loc, ctx);
case FormatToken::kw_ref:
return parseRefDirective(loc, ctx);
case FormatToken::kw_custom:
return parseCustomDirective(loc, ctx);

default:
return emitError(loc, "unsupported directive kind");
}
}

FailureOr<FormatElement *>
DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")))
return failure();
FailureOr<FormatElement *> var = parseElement(ctx);
if (failed(var))
return var;
if (!isa<ParameterElement>(*var))
return emitError(loc, "`qualified` argument list expected a variable");
cast<ParameterElement>(*var)->setShouldBeQualified();
if (failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return failure();
return var;
}

FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
Context ctx) {
// It doesn't make sense to allow references to all parameters in a custom
Expand Down Expand Up @@ -1201,22 +1185,6 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
return create<StructDirective>(std::move(vars));
}

FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
Context ctx) {
if (ctx != CustomDirectiveContext)
return emitError(loc, "`ref` is only allowed inside custom directives");

// Parse the child parameter element.
FailureOr<FormatElement *> child;
if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
failed(child = parseElement(RefDirectiveContext)) ||
failed(parseToken(FormatToken::r_paren, "expeced ')'")))
return failure();

// Only parameter elements are allowed to be parsed under a `ref` directive.
return create<RefDirective>(*child);
}

//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 36 additions & 0 deletions mlir/tools/mlir-tblgen/FormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {

if (tok.is(FormatToken::kw_custom))
return parseCustomDirective(loc, ctx);
if (tok.is(FormatToken::kw_ref))
return parseRefDirective(loc, ctx);
if (tok.is(FormatToken::kw_qualified))
return parseQualifiedDirective(loc, ctx);
return parseDirectiveImpl(loc, tok.getKind(), ctx);
}

Expand Down Expand Up @@ -430,6 +434,38 @@ FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
}

FailureOr<FormatElement *> FormatParser::parseRefDirective(SMLoc loc,
Context context) {
if (context != CustomDirectiveContext)
return emitError(loc, "'ref' is only valid within a `custom` directive");

FailureOr<FormatElement *> arg;
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
failed(arg = parseElement(RefDirectiveContext)) ||
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return failure();

return create<RefDirective>(*arg);
}

FailureOr<FormatElement *> FormatParser::parseQualifiedDirective(SMLoc loc,
Context ctx) {
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")))
return failure();
FailureOr<FormatElement *> var = parseElement(ctx);
if (failed(var))
return var;
if (failed(markQualified(loc, *var)))
return failure();
if (failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return failure();
return var;
}

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 9 additions & 1 deletion mlir/tools/mlir-tblgen/FormatGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,12 @@ class FormatParser {
FailureOr<FormatElement *> parseDirective(Context ctx);
/// Parse an optional group.
FailureOr<FormatElement *> parseOptionalGroup(Context ctx);

/// Parse a custom directive.
FailureOr<FormatElement *> parseCustomDirective(llvm::SMLoc loc, Context ctx);
/// Parse a ref directive.
FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context context);
/// Parse a qualified directive.
FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);

/// Parse a format-specific variable kind.
virtual FailureOr<FormatElement *>
Expand All @@ -522,6 +525,11 @@ class FormatParser {
ArrayRef<FormatElement *> elements,
FormatElement *anchor) = 0;

/// Mark 'element' as qualified. If 'element' cannot be qualified an error
/// should be emitted and failure returned.
virtual LogicalResult markQualified(llvm::SMLoc loc,
FormatElement *element) = 0;

//===--------------------------------------------------------------------===//
// Lexer Utilities

Expand Down
40 changes: 5 additions & 35 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2547,6 +2547,8 @@ class OpFormatParser : public FormatParser {
LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
bool isAnchor);

LogicalResult markQualified(SMLoc loc, FormatElement *element) override;

/// Parse an operation variable.
FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
Context ctx) override;
Expand Down Expand Up @@ -2622,10 +2624,6 @@ class OpFormatParser : public FormatParser {
FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
Context context);
FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
Context context);
FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
Expand Down Expand Up @@ -3224,16 +3222,12 @@ OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
return parseFunctionalTypeDirective(loc, ctx);
case FormatToken::kw_operands:
return parseOperandsDirective(loc, ctx);
case FormatToken::kw_qualified:
return parseQualifiedDirective(loc, ctx);
case FormatToken::kw_regions:
return parseRegionsDirective(loc, ctx);
case FormatToken::kw_results:
return parseResultsDirective(loc, ctx);
case FormatToken::kw_successors:
return parseSuccessorsDirective(loc, ctx);
case FormatToken::kw_ref:
return parseReferenceDirective(loc, ctx);
case FormatToken::kw_type:
return parseTypeDirective(loc, ctx);
case FormatToken::kw_oilist:
Expand Down Expand Up @@ -3338,22 +3332,6 @@ OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
return create<OperandsDirective>();
}

FailureOr<FormatElement *>
OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
if (context != CustomDirectiveContext)
return emitError(loc, "'ref' is only valid within a `custom` directive");

FailureOr<FormatElement *> arg;
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
failed(arg = parseElement(RefDirectiveContext)) ||
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return failure();

return create<RefDirective>(*arg);
}

FailureOr<FormatElement *>
OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
if (context == TypeDirectiveContext)
Expand Down Expand Up @@ -3495,19 +3473,11 @@ FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
return create<TypeDirective>(*operand);
}

FailureOr<FormatElement *>
OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
FailureOr<FormatElement *> element;
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
failed(element = parseElement(context)) ||
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return failure();
return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) {
return TypeSwitch<FormatElement *, LogicalResult>(element)
.Case<AttributeVariable, TypeDirective>([](auto *element) {
element->setShouldBeQualified();
return element;
return success();
})
.Default([&](auto *element) {
return this->emitError(
Expand Down