[mlir] load dialects for non-namespaced attrs #94838

Closed
j2kun wants to merge 6 commits from the attribute-parser-fix branch

j2kun
Contributor

@j2kun j2kun commented Jun 8, 2024

The mlir-translate tool calls into the parser without loading registered dependent dialects, and the parser only loads attributes if the fully-namespaced attribute is present in the textual IR. This causes parsing to break when an op has an attribute that prints/parses without the namespaced attribute.

@j2kun j2kun marked this pull request as draft June 8, 2024 04:32
@llvmbot
Member

llvmbot commented Jun 8, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Jeremy Kun (j2kun)

Full diff: https://github.com/llvm/llvm-project/pull/94838.diff

4 Files Affected:

  • (modified) mlir/test/Target/LLVMIR/test.mlir (+15)
  • (modified) mlir/test/lib/Dialect/Test/CMakeLists.txt (+1)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+7-1)
  • (modified) mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp (+2)
diff --git a/mlir/test/Target/LLVMIR/test.mlir b/mlir/test/Target/LLVMIR/test.mlir
index 0ab1b7267d959..b21254b9f3464 100644
--- a/mlir/test/Target/LLVMIR/test.mlir
+++ b/mlir/test/Target/LLVMIR/test.mlir
@@ -40,3 +40,18 @@ llvm.func @dialect_attr_translation_multi(%a: i64, %b: i64, %c: i64) -> i64 {
 // CHECK-DAG: ![[MD_ID_ADD]] = !{!"annotation_from_test: add"}
 // CHECK-DAG: ![[MD_ID_MUL]] = !{!"annotation_from_test: mul"}
 // CHECK-DAG: ![[MD_ID_RET]] = !{!"annotation_from_test: ret"}
+
+
+// -----
+
+// This is a regression test for a bug where, during an mlir-translate call the
+// parser would only load the dialect if the fully namespaced attribute was
+// present in the IR.
+func.func @parse_correctly() {
+  "test.containing_int_polynomial_attr"() {
+    // CHECK: <1 + x**2>
+    attr = <1 + x**2>
+  } : () -> ()
+  return
+}
+
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index fab8937809332..e50e17741e39e 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -86,6 +86,7 @@ add_mlir_library(MLIRTestDialect
   MLIRLinalgTransforms
   MLIRLLVMDialect
   MLIRPass
+  MLIRPolynomialDialect
   MLIRReduce
   MLIRTensorDialect
   MLIRTransformUtils
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9d7e0a7928ab8..0db092aed4c22 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -13,6 +13,7 @@ include "TestDialect.td"
 include "TestInterfaces.td"
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -232,6 +233,11 @@ def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
   );
 }
 
+def ContainingIntPolynomialAttrOp : TEST_Op<"containing_int_polynomial_attr"> {
+  let arguments = (ins Polynomial_IntPolynomialAttr:$attr);
+  let assemblyFormat = "$attr attr-dict";
+}
+
 // A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
 // This tests both matching and generating float elements attributes.
 def UpdateFloatElementsAttr : Pat<
@@ -2204,7 +2210,7 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
 def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
   let description = [{
     Reify a bound for the given index-typed value or dimension size of a shaped
-    value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set, 
+    value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set,
     `vscale_min` and `vscale_max` must be provided, which allows computing
     a bound in terms of "vector.vscale" for a given range of vscale.
   }];
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index 57e7d658fb501..c4e5d259adc24 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
@@ -127,6 +128,7 @@ void registerTestToLLVMIR() {
       },
       [](DialectRegistry &registry) {
         registry.insert<test::TestDialect>();
+        registry.insert<polynomial::PolynomialDialect>();
         registerBuiltinDialectTranslation(registry);
         registerLLVMDialectTranslation(registry);
         registry.addExtension(

github-actions bot commented Jun 8, 2024

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

github-actions bot commented Jun 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@j2kun j2kun force-pushed the attribute-parser-fix branch 3 times, most recently from 8537219 to 9f8f0d7 Compare June 8, 2024 05:51
@j2kun
Contributor Author

j2kun commented Jun 8, 2024

As of now this PR just reproduces the bug. Will work on a fix soon.

@j2kun
Contributor Author

j2kun commented Jun 12, 2024

Hi @ftynse @joker-eph I was able to reproduce this segfault and found a fix, but I'm not sure if it's the right fix, and I'm also not sure if attaching this to the LLVM translation is the best way to fix it. Could you please advise?

@j2kun j2kun marked this pull request as ready for review June 12, 2024 20:59
@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Jun 12, 2024
@j2kun j2kun requested review from ftynse and joker-eph June 12, 2024 21:07
// Ensure all registered dialects are loaded
for (const auto &dialectName : registry.getDialectNames()) {
  context->getOrLoadDialect(dialectName);
}
Collaborator

Indeed, that does not seem like the right fix; the parser should load dialects on demand.
I need to reproduce this to figure out how we get to the crash.

Contributor Author

This PR contains a reproducer. Remove those lines and the test dialect to LLVM translation will segfault.

@ftynse
Member

ftynse commented Jun 14, 2024

Yeah, we shouldn't just load all available dialects. This defeats the point of separating registration from loading.

It is unclear to me why #attr = #test.nested_polynomial<<1 + x**2>> is the accepted syntax. It looks ambiguous. Looking at it in isolation, I would expect this to be just a NestedPolynomialAttr from the test dialect with <1 + x**2> being its syntax. IMO, we shouldn't elide the attribute mnemonic or the dialect prefix for the nested attribute at all (why/where do we?). We could elide the dialect prefix if the attribute is immediately nested in another attribute from the same dialect, but the rest sounds ambiguous to me.

@joker-eph
Collaborator

joker-eph commented Jun 14, 2024

Why do you think it is ambiguous? The content of this attribute is always a Polynomial_IntPolynomialAttr, which removes all ambiguity as far as I can tell.
I believe the problem here is rather that the generated parser does not load the dialect on the fly before calling poly::IntPolynomialAttr::parse.
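
The registration-versus-loading distinction and the on-demand load suggested here can be sketched with a toy model. This is not MLIR's API: ToyRegistry, ToyContext, parseAttrPrefixOnly, and parseAttrLoadFirst are hypothetical names, and the toy returns false where the real tool segfaults.

```cpp
#include <cassert>
#include <set>
#include <string>

// Toy stand-ins for illustration; these are NOT MLIR's DialectRegistry or
// MLIRContext.
struct ToyRegistry {
  std::set<std::string> registered; // dialects that may be loaded later
};

struct ToyContext {
  ToyRegistry registry;
  std::set<std::string> loaded; // dialects that have actually been loaded

  // Load a registered dialect on first use; returns false if it is unknown.
  bool getOrLoad(const std::string &name) {
    if (registry.registered.count(name) == 0)
      return false;
    loaded.insert(name);
    return true;
  }

  bool isLoaded(const std::string &name) const {
    return loaded.count(name) != 0;
  }
};

// In this toy, parsing the polynomial body only succeeds once the owning
// dialect has been loaded.
bool parsePolynomialBody(const ToyContext &ctx) {
  return ctx.isLoaded("polynomial");
}

// Problematic path: loading is keyed on a textual "#polynomial." prefix, so
// an elided prefix means the dialect is never loaded.
bool parseAttrPrefixOnly(ToyContext &ctx, const std::string &text) {
  const std::string prefix = "#polynomial.";
  if (text.compare(0, prefix.size(), prefix) == 0)
    ctx.getOrLoad("polynomial");
  return parsePolynomialBody(ctx);
}

// On-demand path: the field's type is known statically, so its owning
// dialect is loaded before the body is parsed, whatever the text looks like.
bool parseAttrLoadFirst(ToyContext &ctx, const std::string &text) {
  (void)text; // the text is not inspected in this sketch
  if (!ctx.getOrLoad("polynomial"))
    return false;
  return parsePolynomialBody(ctx);
}
```

With both "test" and "polynomial" registered, parseAttrPrefixOnly fails on the bare text <1 + x**2>, while parseAttrLoadFirst succeeds and loads only the polynomial dialect, so nothing else is loaded eagerly.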

@j2kun
Contributor Author

j2kun commented Jun 16, 2024

generated parser not loading the dialect on the fly

@joker-eph could you point me to where this parser is generated? Should it be loading the dialect in the parser for IntPolynomialAttr, or for the container attribute?

I tried this and it didn't seem to do anything

--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -711,12 +711,16 @@ void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
 /// {0}: the dialect fully qualified class name.
 /// {1}: the optional code for the dynamic attribute parser dispatch.
 /// {2}: the optional code for the dynamic attribute printer dispatch.
 static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
 /// Parse an attribute registered to this dialect.
 ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
                                       ::mlir::Type type) const {{
   ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
   ::llvm::StringRef attrTag;
+  getContext()->getOrLoadDialect<{0}>();
   {{
     ::mlir::Attribute attr;
     auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);

@j2kun
Contributor Author

j2kun commented Jun 18, 2024

Logging my attempts


Changing FieldParser::parse for a custom attribute:

diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 1e4f7f787a1..9d51994b012 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -64,6 +64,11 @@ struct FieldParser<
                                  AttributeT>> {
   static FailureOr<AttributeT> parse(AsmParser &parser) {
     AttributeT value;
+
+    SmallVector<StringRef, 2> vec;
+    StringRef(AttributeT::name).split(vec, '.');
+    parser.getContext()->getOrLoadDialect(vec[0]);
+
     if (parser.parseCustomAttributeWithFallback(value))
       return failure();
     return value;

This doesn't work because not all attributes have the static name field. The string parsing is also somewhat messy. Doing this lower in the call stack, in parseCustomAttributeWithFallback, runs into a similar obstacle.
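
For illustration only, the "not every attribute has a name field" obstacle can be sidestepped in plain C++17 with the detection idiom, which makes the lookup compile for every type but still yields nothing useful for attributes without the field. The sketch below uses std::string_view instead of LLVM's string types, and HasNameMember, dialectPrefix, owningDialectOrEmpty, NamedAttr, and UnnamedAttr are hypothetical names, not part of MLIR.

```cpp
#include <cassert>
#include <string_view>
#include <type_traits>

// True when the expression `T::name` is well-formed for T.
template <typename T, typename = void>
struct HasNameMember : std::false_type {};
template <typename T>
struct HasNameMember<T, std::void_t<decltype(T::name)>> : std::true_type {};

// Returns the text before the first '.', or an empty view if there is none.
constexpr std::string_view dialectPrefix(std::string_view fullName) {
  const auto dot = fullName.find('.');
  return dot == std::string_view::npos ? std::string_view{}
                                       : fullName.substr(0, dot);
}

// Returns the dialect prefix of AttrT::name when that member exists, and an
// empty view otherwise, so the caller can skip the load instead of failing
// to compile.
template <typename AttrT>
constexpr std::string_view owningDialectOrEmpty() {
  if constexpr (HasNameMember<AttrT>::value)
    return dialectPrefix(AttrT::name);
  else
    return {};
}

// Hypothetical attribute classes, for illustration only.
struct NamedAttr {
  static constexpr std::string_view name = "polynomial.int_polynomial";
};
struct UnnamedAttr {};
```

Here owningDialectOrEmpty<NamedAttr>() yields "polynomial" and owningDialectOrEmpty<UnnamedAttr>() yields an empty view. The second case is exactly the gap described above: an attribute without the field still gives the parser no dialect to load.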


Changing DefFormat::genParser doesn't seem to work because:

  1. Loading the context object's dialect doesn't work because the example in this PR is a custom parser so it doesn't get a genParser'ed parse method. In this case the test dialect is forced to load, but not polynomial.
  2. Adding it when iterating over fields to parse doesn't work because, as far as I can tell, AttrOrTypeParameter does not have access to the dialect object or the dialect name, even though it logically comes from an AttrOrTypeDef, which does.
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -166,6 +166,7 @@ static const char *const parserErrorStr =
 /// {4}: C++ class of the parameter.
 static const char *const variableParser = R"(
 // Parse variable '{0}'
+odsParser.getContext()->getOrLoadDialect("{5}");
 _result_{0} = {1};
 if (::mlir::failed(_result_{0})) {{
   {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
@@ -291,6 +292,10 @@ void DefFormat::genParser(MethodBody &os) {
   os.indent();
   os << "::mlir::Builder odsBuilder(odsParser.getContext());\n";

+  // Ensure the dialect is loaded
+  os << "odsParser.getContext()->getOrLoadDialect(\""
+     << def.getDialect().getName() << "\");\n"; // loads "test" dialect here
+
   // Store the initial location of the parser.
   ctx.addSubst("_loc", "odsLoc");
   os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
@@ -411,9 +416,13 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
   auto customParser = param.getParser();
   auto parser =
       customParser ? *customParser : StringRef(defaultParameterParser);
+
+  StringRef dialectName = ???;
+
   os << formatv(variableParser, param.getName(),
                 tgfmt(parser, &ctx, param.getCppStorageType()),
-                tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
+                tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType(),
+                dialectName);
 }

@j2kun
Contributor Author

j2kun commented Jun 18, 2024

After more failed attempts, I'm starting to be convinced that fixing this will require some nontrivial changes to mlir-tblgen to give access to attribute parameters' dialect names at the time the parser is generated.

Either that or force the attribute to have a fully-qualified name even when unambiguous, which I think is the simplest approach. In particular, this came up because mlir-opt chose to write the output IR without the namespace, and then that output was piped to mlir-translate.

@joker-eph
Collaborator

I'm starting to be convinced that fixing this will require some nontrivial changes to mlir-tblgen to give access to attribute parameters' dialect names at the time the parser is generated.

Regardless of how non-trivial the changes are, they are actually needed here.

@j2kun
Contributor Author

j2kun commented Jun 18, 2024

I'm wondering if there is a small workaround I could implement in the meantime, as this breaks some of the out of tree stuff I've been doing. Would you be open to me changing the polynomial dialect tablegen so that the nested polynomial attributes are qualified? (They use the struct(params) directive, so I would have to change that to list the parameters explicitly inside struct so as to mark them qualified)

@joker-eph
Collaborator

Can you work around this downstream by preloading the polynomial dialect?

@j2kun
Contributor Author

j2kun commented Jun 18, 2024

Can you work around this downstream by preloading the polynomial dialect?

MLIRTranslateMain doesn't give you access to the context, as far as I can tell, so I can't load it.

@joker-eph
Collaborator

MLIRTranslateMain does not, but the callback should (that is the translate function you're registering)

@j2kun
Contributor Author

j2kun commented Jun 18, 2024

MLIRTranslateMain does not, but the callback should (that is the translate function you're registering)

I don't think the DialectRegistrationFunction has the context, and by the time you get to the TranslateFromMLIRFunction, the parsing has already happened.

Cf.

TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
    StringRef name, StringRef description,
    const TranslateFromMLIRFunction &function,
    const DialectRegistrationFunction &dialectRegistration) {
  registerTranslation(
      name, description, /*inputAlignment=*/std::nullopt,
      [function,
       dialectRegistration](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
                            raw_ostream &output, MLIRContext *context) {
        DialectRegistry registry;
        dialectRegistration(registry);
        context->appendDialectRegistry(registry);
        bool implicitModule =
            (!clOptions.isConstructed() || !clOptions->noImplicitModule);
        OwningOpRef<Operation *> op =
            parseSourceFileForTool(sourceMgr, context, implicitModule);
        if (!op || failed(verify(*op)))
          return failure();
        return function(op.get(), output);
      });

Unless you mean I should be calling registerTranslation with all that boilerplate repeated? (edit: nvm, registerTranslation is not publicly exposed)

Perhaps I could add another overload here that passes the context to the registration function?

@j2kun
Contributor Author

j2kun commented Jun 20, 2024

As an intermediate workaround, what about providing the MLIRContext as an optional parameter to mlirTranslateMain?

diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
index bd9928950ec..92313d679f8 100644
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -47,7 +47,8 @@ public:
 //===----------------------------------------------------------------------===//

 LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
-                                      llvm::StringRef toolName) {
+                                      llvm::StringRef toolName,
+                                      MLIRContext *userProvidedContext) {

   static llvm::cl::opt<std::string> inputFilename(
       llvm::cl::Positional, llvm::cl::desc("<input file>"),
@@ -148,9 +149,12 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
       TimingScope translationTiming =
           timing.nest(translationRequested->getDescription());

-      MLIRContext context;
-      context.allowUnregisteredDialects(allowUnregisteredDialects);
-      context.printOpOnDiagnostic(!verifyDiagnostics);
+      MLIRContext defaultContext;
+      MLIRContext *context = &defaultContext;
+      if (userProvidedContext)
+        context = userProvidedContext;
+      context->allowUnregisteredDialects(allowUnregisteredDialects);
+      context->printOpOnDiagnostic(!verifyDiagnostics);
       auto sourceMgr = std::make_shared<llvm::SourceMgr>();
       sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());

I could also pass in a callable that generates a fresh context, since I noticed a new context is created on each iteration of the split buffer loop.
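
A rough sketch of that callable variant, using toy types rather than MLIR's: ToyContext, ContextFactory, and translateChunks are hypothetical names, and the loop only stands in for the split-buffer iteration. Each chunk gets a fresh context, either default-constructed or produced by the user's factory.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Toy context for illustration; NOT mlir::MLIRContext.
struct ToyContext {
  std::set<std::string> loaded; // names of dialects loaded in this context
};

// A user-supplied callable that builds a fresh, possibly preloaded, context.
using ContextFactory = std::function<std::unique_ptr<ToyContext>()>;

// Processes each chunk with its own context and returns how many chunks saw
// a context that already had "polynomial" loaded. An empty factory falls
// back to a default-constructed context, mirroring the optional parameter.
std::size_t translateChunks(const std::vector<std::string> &chunks,
                            const ContextFactory &makeContext = {}) {
  std::size_t preloaded = 0;
  for (const std::string &chunk : chunks) {
    (void)chunk; // parsing and translation are out of scope for this sketch
    std::unique_ptr<ToyContext> context =
        makeContext ? makeContext() : std::make_unique<ToyContext>();
    if (context->loaded.count("polynomial") != 0)
      ++preloaded;
  }
  return preloaded;
}
```

A factory avoids sharing one mutable context across chunks, which a single pointer parameter would do. It remains a workaround that leaves the parser's missing on-demand load in place.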

@joker-eph
Collaborator

I am just not sure why we would need any workaround instead of a principled fix really. I am not convinced there is anything really difficult here yet.

@j2kun
Contributor Author

j2kun commented Jun 20, 2024

I am just not sure why we would need any workaround instead of a principled fix really. I am not convinced there is anything really difficult here yet.

I don't know what the principled fix is. I think mlir-tblgen needs more information at the time it's generating these parsers, but I don't know how to get it that information. Maybe you could sketch the solution you have in mind?

@joker-eph
Collaborator

I am just not sure why we would need any workaround instead of a principled fix really. I am not convinced there is anything really difficult here yet.

I don't know what the principled fix is. I think mlir-tblgen needs more information at the time it's generating these parsers, but I don't know how to get it that information. Maybe you could sketch the solution you have in mind?

I just gave it a quick try here: #96242
I haven't had time to test it properly yet (though it fixes your example); I'm located in Europe and my day is ending, so I'll pick this up maybe tomorrow.

@j2kun j2kun closed this Jun 21, 2024
Labels
mlir:core MLIR Core Infrastructure mlir:llvm mlir