-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Better Python diagnostics #128581
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Better Python diagnostics #128581
Conversation
@llvm/pr-subscribers-mlir Author: Nikhil Kalra (nikalra) ChangesUpdated the Python diagnostics handler to emit notes (in addition to errors) into the output stream so that users have more context as to where in the IR the error is occurring. To test this, I also updated the CAPI with an option to set Full diff: https://github.com/llvm/llvm-project/pull/128581.diff 5 Files Affected:
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 14ccae650606a..f661e90105704 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -162,6 +162,11 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
MlirLlvmThreadPool threadPool);
+/// Sets the context to attach the stack trace for the source code location at
+/// which a diagnostic is emitted.
+MLIR_CAPI_EXPORTED void
+mlirContextPrintStackTraceOnDiagnostic(MlirContext context, bool enable);
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h
index ea80e14dde0f3..4f9be844dc1ac 100644
--- a/mlir/include/mlir/Bindings/Python/Diagnostics.h
+++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h
@@ -9,12 +9,13 @@
#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
-#include <cassert>
-#include <string>
-
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
-#include "llvm/ADT/StringRef.h"
+
+#include <cassert>
+#include <cstdint>
+#include <sstream>
+#include <string>
namespace mlir {
namespace python {
@@ -28,29 +29,37 @@ class CollectDiagnosticsToStringScope {
/*deleteUserData=*/nullptr);
}
~CollectDiagnosticsToStringScope() {
- assert(errorMessage.empty() && "unchecked error message");
mlirContextDetachDiagnosticHandler(context, handlerID);
}
- [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
+ [[nodiscard]] std::string takeMessage() {
+ std::ostringstream stream;
+ std::swap(stream, errorMessage);
+ return stream.str();
+ }
private:
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
auto printer = +[](MlirStringRef message, void *data) {
- *static_cast<std::string *>(data) +=
- llvm::StringRef(message.data, message.length);
+ *static_cast<std::ostringstream *>(data)
+ << std::string_view(message.data, message.length);
};
MlirLocation loc = mlirDiagnosticGetLocation(diag);
- *static_cast<std::string *>(data) += "at ";
+ *static_cast<std::ostringstream *>(data) << "at ";
mlirLocationPrint(loc, printer, data);
- *static_cast<std::string *>(data) += ": ";
+ *static_cast<std::ostringstream *>(data) << ": ";
mlirDiagnosticPrint(diag, printer, data);
+ for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) {
+ *static_cast<std::ostringstream *>(data) << "\n";
+ MlirDiagnostic note = mlirDiagnosticGetNote(diag, i);
+ handler(note, data);
+ }
return mlirLogicalResultSuccess();
}
MlirContext context;
MlirDiagnosticHandlerID handlerID;
- std::string errorMessage = "";
+ std::ostringstream errorMessage;
};
} // namespace python
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 999e8cbda1295..2249519ad4eef 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -114,6 +114,10 @@ void mlirContextSetThreadPool(MlirContext context,
unwrap(context)->setThreadPool(*unwrap(threadPool));
}
+void mlirContextPrintStackTraceOnDiagnostic(MlirContext context, bool enable) {
+ unwrap(context)->printStackTraceOnDiagnostic(enable);
+}
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index d516cda819897..5f6696850682a 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -2,6 +2,7 @@
import gc
from mlir.ir import *
+from mlir._mlir_libs._mlirPythonTestNanobind import test_diagnostics_with_errors_and_notes
def run(f):
@@ -222,3 +223,16 @@ def callback2(d):
# CHECK: CALLBACK2: foobar
# CHECK: CALLBACK1: foobar
loc.emit_error("foobar")
+
+# CHECK-LABEL: TEST: testBuiltInDiagnosticsHandler
+@run
+def testBuiltInDiagnosticsHandler():
+ ctx = Context()
+
+ try:
+ test_diagnostics_with_errors_and_notes(ctx)
+ except ValueError as e:
+ # CHECK: created error
+ # CHECK: MLIRPythonCAPI
+ print(e)
+
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 99c81eae97a0c..daf3b4602b367 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -11,9 +11,12 @@
#include "PythonTestCAPI.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
@@ -45,6 +48,15 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
},
nb::arg("registry"));
+ m.def("test_diagnostics_with_errors_and_notes", [](MlirContext ctx) {
+ mlirContextPrintStackTraceOnDiagnostic(ctx, true);
+ mlir::python::CollectDiagnosticsToStringScope handler(ctx);
+
+ auto loc = mlirLocationUnknownGet(ctx);
+ mlirEmitError(loc, "created error");
+ throw nb::value_error(handler.takeMessage().c_str());
+ });
+
mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute,
mlirPythonTestTestAttributeGetTypeID)
|
✅ With the latest revision this PR passed the Python code formatter. |
Pinging for reviews :) |
Sorry I meant to do this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Only comment is I believe we're able to use llvm::*stream
types here (in the bindings) but it's fine - I doubt there's a huge difference std::ostream
and llvm::ostream
(or whatever it's called).
Updated the Python diagnostics handler to emit notes (in addition to errors) into the output stream so that users have more context as to where in the IR the error is occurring.