Skip to content

[mlir] Better Python diagnostics #128581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 10, 2025
Merged

Conversation

nikalra
Copy link
Contributor

@nikalra nikalra commented Feb 24, 2025

Updated the Python diagnostics handler to emit notes (in addition to errors) into the output stream so that users have more context as to where in the IR the error is occurring.

@llvmbot
Copy link
Member

llvmbot commented Feb 24, 2025

@llvm/pr-subscribers-mlir

Author: Nikhil Kalra (nikalra)

Changes

Updated the Python diagnostics handler to emit notes (in addition to errors) into the output stream so that users have more context as to where in the IR the error is occurring.

To test this, I also updated the CAPI with an option to set printStackTraceOnDiagnostic so that notes are available in the diagnostic for the Python test.


Full diff: https://github.com/llvm/llvm-project/pull/128581.diff

5 Files Affected:

  • (modified) mlir/include/mlir-c/IR.h (+5)
  • (modified) mlir/include/mlir/Bindings/Python/Diagnostics.h (+20-11)
  • (modified) mlir/lib/CAPI/IR/IR.cpp (+4)
  • (modified) mlir/test/python/ir/diagnostic_handler.py (+14)
  • (modified) mlir/test/python/lib/PythonTestModuleNanobind.cpp (+12)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 14ccae650606a..f661e90105704 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -162,6 +162,11 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
 MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
                                                  MlirLlvmThreadPool threadPool);
 
+/// Sets the context to attach the stack trace for the source code location at
+/// which a diagnostic is emitted.
+MLIR_CAPI_EXPORTED void
+mlirContextPrintStackTraceOnDiagnostic(MlirContext context, bool enable);
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h
index ea80e14dde0f3..4f9be844dc1ac 100644
--- a/mlir/include/mlir/Bindings/Python/Diagnostics.h
+++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h
@@ -9,12 +9,13 @@
 #ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
 #define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
 
-#include <cassert>
-#include <string>
-
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
-#include "llvm/ADT/StringRef.h"
+
+#include <cassert>
+#include <cstdint>
+#include <sstream>
+#include <string>
 
 namespace mlir {
 namespace python {
@@ -28,29 +29,37 @@ class CollectDiagnosticsToStringScope {
                                                    /*deleteUserData=*/nullptr);
   }
   ~CollectDiagnosticsToStringScope() {
-    assert(errorMessage.empty() && "unchecked error message");
     mlirContextDetachDiagnosticHandler(context, handlerID);
   }
 
-  [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
+  [[nodiscard]] std::string takeMessage() {
+    std::ostringstream stream;
+    std::swap(stream, errorMessage);
+    return stream.str();
+  }
 
 private:
   static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
     auto printer = +[](MlirStringRef message, void *data) {
-      *static_cast<std::string *>(data) +=
-          llvm::StringRef(message.data, message.length);
+      *static_cast<std::ostringstream *>(data)
+          << std::string_view(message.data, message.length);
     };
     MlirLocation loc = mlirDiagnosticGetLocation(diag);
-    *static_cast<std::string *>(data) += "at ";
+    *static_cast<std::ostringstream *>(data) << "at ";
     mlirLocationPrint(loc, printer, data);
-    *static_cast<std::string *>(data) += ": ";
+    *static_cast<std::ostringstream *>(data) << ": ";
     mlirDiagnosticPrint(diag, printer, data);
+    for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) {
+      *static_cast<std::ostringstream *>(data) << "\n";
+      MlirDiagnostic note = mlirDiagnosticGetNote(diag, i);
+      handler(note, data);
+    }
     return mlirLogicalResultSuccess();
   }
 
   MlirContext context;
   MlirDiagnosticHandlerID handlerID;
-  std::string errorMessage = "";
+  std::ostringstream errorMessage;
 };
 
 } // namespace python
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 999e8cbda1295..2249519ad4eef 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -114,6 +114,10 @@ void mlirContextSetThreadPool(MlirContext context,
   unwrap(context)->setThreadPool(*unwrap(threadPool));
 }
 
+void mlirContextPrintStackTraceOnDiagnostic(MlirContext context, bool enable) {
+  unwrap(context)->printStackTraceOnDiagnostic(enable);
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index d516cda819897..5f6696850682a 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -2,6 +2,7 @@
 
 import gc
 from mlir.ir import *
+from mlir._mlir_libs._mlirPythonTestNanobind import test_diagnostics_with_errors_and_notes
 
 
 def run(f):
@@ -222,3 +223,16 @@ def callback2(d):
     # CHECK: CALLBACK2: foobar
     # CHECK: CALLBACK1: foobar
     loc.emit_error("foobar")
+
+# CHECK-LABEL: TEST: testBuiltInDiagnosticsHandler
+@run
+def testBuiltInDiagnosticsHandler():
+    ctx = Context()
+
+    try:
+        test_diagnostics_with_errors_and_notes(ctx)
+    except ValueError as e:
+        # CHECK: created error
+        # CHECK: MLIRPythonCAPI
+        print(e)
+
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 99c81eae97a0c..daf3b4602b367 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -11,9 +11,12 @@
 #include "PythonTestCAPI.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
 
 namespace nb = nanobind;
 using namespace mlir::python::nanobind_adaptors;
@@ -45,6 +48,15 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
       },
       nb::arg("registry"));
 
+  m.def("test_diagnostics_with_errors_and_notes", [](MlirContext ctx) {
+    mlirContextPrintStackTraceOnDiagnostic(ctx, true);
+    mlir::python::CollectDiagnosticsToStringScope handler(ctx);
+
+    auto loc = mlirLocationUnknownGet(ctx);
+    mlirEmitError(loc, "created error");
+    throw nb::value_error(handler.takeMessage().c_str());
+  });
+
   mlir_attribute_subclass(m, "TestAttr",
                           mlirAttributeIsAPythonTestTestAttribute,
                           mlirPythonTestTestAttributeGetTypeID)

Copy link

github-actions bot commented Feb 24, 2025

✅ With the latest revision this PR passed the Python code formatter.

@nikalra
Copy link
Contributor Author

nikalra commented Mar 3, 2025

Pinging for reviews :)

@makslevental
Copy link
Contributor

Sorry I meant to do this.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Only comment is I believe we're able to use llvm::*stream types here (in the bindings) but it's fine - I doubt there's a huge difference std::ostream and llvm::ostream (or whatever it's called).

@nikalra nikalra merged commit b15ccd4 into llvm:main Mar 10, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants