Skip to content

[mlir][ODS][NFC] Deduplicate ref and qualified handling #91080

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 4, 2024

Conversation

zero9178
Copy link
Member

@zero9178 zero9178 commented May 4, 2024

Both the attribute and type format generator and the op format generator independently implemented the parsing and verification of the ref and qualified directives with little to no differences.

This PR moves the implementation of these into the common FormatParser class to deduplicate the implementations.

Both the attribute and type format generator and the op format generator independently implemented the parsing and verification of the `ref` and `qualified` directives with little to no differences.

This PR moves the implementation of these into the common `FormatParser` class to deduplicate the implementations.
@zero9178 zero9178 changed the title [mlir][ODS] Deduplicate ref and qualified handling [mlir][ODS][NFC] Deduplicate ref and qualified handling May 4, 2024
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels May 4, 2024
@llvmbot
Copy link
Member

llvmbot commented May 4, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Markus Böck (zero9178)

Changes

Both the attribute and type format generator and the op format generator independently implemented the parsing and verification of the ref and qualified directives with little to no differences.

This PR moves the implementation of these into the common FormatParser class to deduplicate the implementations.


Full diff: https://github.com/llvm/llvm-project/pull/91080.diff

5 Files Affected:

  • (modified) mlir/test/mlir-tblgen/attr-or-type-format-invalid.td (+1-1)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp (+10-42)
  • (modified) mlir/tools/mlir-tblgen/FormatGen.cpp (+36)
  • (modified) mlir/tools/mlir-tblgen/FormatGen.h (+9-1)
  • (modified) mlir/tools/mlir-tblgen/OpFormatGen.cpp (+5-35)
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index d3be4d8b8022a0..3a57cbca4d7bb7 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -111,7 +111,7 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
 
 def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> {
   let parameters = (ins "int":$a);
-  // CHECK: `ref` is only allowed inside custom directives
+  // CHECK: 'ref' is only valid within a `custom` directive
   let assemblyFormat = "$a ref($a)";
 }
 
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 6098808c646f76..abd1fbdaf8c649 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -940,6 +940,8 @@ class DefFormatParser : public FormatParser {
                                             ArrayRef<FormatElement *> elements,
                                             FormatElement *anchor) override;
 
+  LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
+
   /// Parse an attribute or type variable.
   FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
                                                Context ctx) override;
@@ -950,12 +952,8 @@ class DefFormatParser : public FormatParser {
 private:
   /// Parse a `params` directive.
   FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
-  /// Parse a `qualified` directive.
-  FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
   /// Parse a `struct` directive.
   FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
-  /// Parse a `ref` directive.
-  FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
 
   /// Attribute or type tablegen def.
   const AttrOrTypeDef &def;
@@ -1060,6 +1058,14 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
   return success();
 }
 
+LogicalResult DefFormatParser::markQualified(SMLoc loc,
+                                             FormatElement *element) {
+  if (!isa<ParameterElement>(element))
+    return emitError(loc, "`qualified` argument list expected a variable");
+  cast<ParameterElement>(element)->setShouldBeQualified();
+  return success();
+}
+
 FailureOr<DefFormat> DefFormatParser::parse() {
   FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
   if (failed(elements))
@@ -1107,33 +1113,11 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
     return parseParamsDirective(loc, ctx);
   case FormatToken::kw_struct:
     return parseStructDirective(loc, ctx);
-  case FormatToken::kw_ref:
-    return parseRefDirective(loc, ctx);
-  case FormatToken::kw_custom:
-    return parseCustomDirective(loc, ctx);
-
   default:
     return emitError(loc, "unsupported directive kind");
   }
 }
 
-FailureOr<FormatElement *>
-DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
-  if (failed(parseToken(FormatToken::l_paren,
-                        "expected '(' before argument list")))
-    return failure();
-  FailureOr<FormatElement *> var = parseElement(ctx);
-  if (failed(var))
-    return var;
-  if (!isa<ParameterElement>(*var))
-    return emitError(loc, "`qualified` argument list expected a variable");
-  cast<ParameterElement>(*var)->setShouldBeQualified();
-  if (failed(
-          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
-    return failure();
-  return var;
-}
-
 FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
                                                                  Context ctx) {
   // It doesn't make sense to allow references to all parameters in a custom
@@ -1201,22 +1185,6 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
   return create<StructDirective>(std::move(vars));
 }
 
-FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
-                                                              Context ctx) {
-  if (ctx != CustomDirectiveContext)
-    return emitError(loc, "`ref` is only allowed inside custom directives");
-
-  // Parse the child parameter element.
-  FailureOr<FormatElement *> child;
-  if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
-      failed(child = parseElement(RefDirectiveContext)) ||
-      failed(parseToken(FormatToken::r_paren, "expeced ')'")))
-    return failure();
-
-  // Only parameter elements are allowed to be parsed under a `ref` directive.
-  return create<RefDirective>(*child);
-}
-
 //===----------------------------------------------------------------------===//
 // Interface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index d402748b96ad5f..7540e584b8fac5 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -308,6 +308,10 @@ FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
 
   if (tok.is(FormatToken::kw_custom))
     return parseCustomDirective(loc, ctx);
+  if (tok.is(FormatToken::kw_ref))
+    return parseRefDirective(loc, ctx);
+  if (tok.is(FormatToken::kw_qualified))
+    return parseQualifiedDirective(loc, ctx);
   return parseDirectiveImpl(loc, tok.getKind(), ctx);
 }
 
@@ -430,6 +434,38 @@ FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
   return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
 }
 
+FailureOr<FormatElement *> FormatParser::parseRefDirective(SMLoc loc,
+                                                           Context context) {
+  if (context != CustomDirectiveContext)
+    return emitError(loc, "'ref' is only valid within a `custom` directive");
+
+  FailureOr<FormatElement *> arg;
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")) ||
+      failed(arg = parseElement(RefDirectiveContext)) ||
+      failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
+    return failure();
+
+  return create<RefDirective>(*arg);
+}
+
+FailureOr<FormatElement *> FormatParser::parseQualifiedDirective(SMLoc loc,
+                                                                 Context ctx) {
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")))
+    return failure();
+  FailureOr<FormatElement *> var = parseElement(ctx);
+  if (failed(var))
+    return var;
+  if (failed(markQualified(loc, *var)))
+    return failure();
+  if (failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
+    return failure();
+  return var;
+}
+
 //===----------------------------------------------------------------------===//
 // Utility Functions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 18a410277fc108..b061d4d8ea7f03 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -495,9 +495,12 @@ class FormatParser {
   FailureOr<FormatElement *> parseDirective(Context ctx);
   /// Parse an optional group.
   FailureOr<FormatElement *> parseOptionalGroup(Context ctx);
-
   /// Parse a custom directive.
   FailureOr<FormatElement *> parseCustomDirective(llvm::SMLoc loc, Context ctx);
+  /// Parse a ref directive.
+  FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context context);
+  /// Parse a qualified directive.
+  FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
 
   /// Parse a format-specific variable kind.
   virtual FailureOr<FormatElement *>
@@ -522,6 +525,11 @@ class FormatParser {
                               ArrayRef<FormatElement *> elements,
                               FormatElement *anchor) = 0;
 
+  /// Mark 'element' as qualified. If 'element' cannot be qualified an error
+  /// should be emitted and failure returned.
+  virtual LogicalResult markQualified(llvm::SMLoc loc,
+                                      FormatElement *element) = 0;
+
   //===--------------------------------------------------------------------===//
   // Lexer Utilities
 
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 806991035e6685..f7cc0a292b8c53 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2547,6 +2547,8 @@ class OpFormatParser : public FormatParser {
   LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
                                            bool isAnchor);
 
+  LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
+
   /// Parse an operation variable.
   FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
                                                Context ctx) override;
@@ -2622,10 +2624,6 @@ class OpFormatParser : public FormatParser {
   FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
   LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
   FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
-  FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
-                                                     Context context);
-  FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
-                                                     Context context);
   FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
   FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
   FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
@@ -3224,16 +3222,12 @@ OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
     return parseFunctionalTypeDirective(loc, ctx);
   case FormatToken::kw_operands:
     return parseOperandsDirective(loc, ctx);
-  case FormatToken::kw_qualified:
-    return parseQualifiedDirective(loc, ctx);
   case FormatToken::kw_regions:
     return parseRegionsDirective(loc, ctx);
   case FormatToken::kw_results:
     return parseResultsDirective(loc, ctx);
   case FormatToken::kw_successors:
     return parseSuccessorsDirective(loc, ctx);
-  case FormatToken::kw_ref:
-    return parseReferenceDirective(loc, ctx);
   case FormatToken::kw_type:
     return parseTypeDirective(loc, ctx);
   case FormatToken::kw_oilist:
@@ -3338,22 +3332,6 @@ OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
   return create<OperandsDirective>();
 }
 
-FailureOr<FormatElement *>
-OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
-  if (context != CustomDirectiveContext)
-    return emitError(loc, "'ref' is only valid within a `custom` directive");
-
-  FailureOr<FormatElement *> arg;
-  if (failed(parseToken(FormatToken::l_paren,
-                        "expected '(' before argument list")) ||
-      failed(arg = parseElement(RefDirectiveContext)) ||
-      failed(
-          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
-    return failure();
-
-  return create<RefDirective>(*arg);
-}
-
 FailureOr<FormatElement *>
 OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
   if (context == TypeDirectiveContext)
@@ -3495,19 +3473,11 @@ FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
   return create<TypeDirective>(*operand);
 }
 
-FailureOr<FormatElement *>
-OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
-  FailureOr<FormatElement *> element;
-  if (failed(parseToken(FormatToken::l_paren,
-                        "expected '(' before argument list")) ||
-      failed(element = parseElement(context)) ||
-      failed(
-          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
-    return failure();
-  return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
+LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) {
+  return TypeSwitch<FormatElement *, LogicalResult>(element)
       .Case<AttributeVariable, TypeDirective>([](auto *element) {
         element->setShouldBeQualified();
-        return element;
+        return success();
       })
       .Default([&](auto *element) {
         return this->emitError(

@zero9178 zero9178 merged commit 3cbfc9d into llvm:main May 4, 2024
@zero9178 zero9178 deleted the format-gen-deduplicate branch May 4, 2024 19:57
sookach pushed a commit to sookach/llvm-project that referenced this pull request May 4, 2024
Both the attribute and type format generator and the op format generator
independently implemented the parsing and verification of the `ref` and
`qualified` directives with little to no differences.

This PR moves the implementation of these into the common `FormatParser`
class to deduplicate the implementations.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants