Skip to content

Commit b15ccd4

Browse files
authored
[mlir] Better Python diagnostics (#128581)
Updated the Python diagnostics handler to emit notes (in addition to errors) into the output stream so that users have more context as to where in the IR the error is occurring.
1 parent 9b066f0 commit b15ccd4

File tree

5 files changed

+63
-13
lines changed

5 files changed

+63
-13
lines changed

mlir/include/mlir/Bindings/Python/Diagnostics.h

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
1010
#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
1111

12-
#include <cassert>
13-
#include <string>
14-
1512
#include "mlir-c/Diagnostics.h"
1613
#include "mlir-c/IR.h"
17-
#include "llvm/ADT/StringRef.h"
14+
#include "llvm/Support/raw_ostream.h"
15+
16+
#include <cassert>
17+
#include <cstdint>
18+
#include <string>
1819

1920
namespace mlir {
2021
namespace python {
@@ -24,33 +25,45 @@ namespace python {
2425
class CollectDiagnosticsToStringScope {
2526
public:
2627
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
27-
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
28-
/*deleteUserData=*/nullptr);
28+
handlerID =
29+
mlirContextAttachDiagnosticHandler(ctx, &handler, &messageStream,
30+
/*deleteUserData=*/nullptr);
2931
}
3032
~CollectDiagnosticsToStringScope() {
31-
assert(errorMessage.empty() && "unchecked error message");
33+
assert(message.empty() && "unchecked error message");
3234
mlirContextDetachDiagnosticHandler(context, handlerID);
3335
}
3436

35-
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
37+
[[nodiscard]] std::string takeMessage() {
38+
std::string newMessage;
39+
std::swap(message, newMessage);
40+
return newMessage;
41+
}
3642

3743
private:
3844
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
3945
auto printer = +[](MlirStringRef message, void *data) {
40-
*static_cast<std::string *>(data) +=
41-
llvm::StringRef(message.data, message.length);
46+
*static_cast<llvm::raw_string_ostream *>(data)
47+
<< std::string_view(message.data, message.length);
4248
};
4349
MlirLocation loc = mlirDiagnosticGetLocation(diag);
44-
*static_cast<std::string *>(data) += "at ";
50+
*static_cast<llvm::raw_string_ostream *>(data) << "at ";
4551
mlirLocationPrint(loc, printer, data);
46-
*static_cast<std::string *>(data) += ": ";
52+
*static_cast<llvm::raw_string_ostream *>(data) << ": ";
4753
mlirDiagnosticPrint(diag, printer, data);
54+
for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) {
55+
*static_cast<llvm::raw_string_ostream *>(data) << "\n";
56+
MlirDiagnostic note = mlirDiagnosticGetNote(diag, i);
57+
handler(note, data);
58+
}
4859
return mlirLogicalResultSuccess();
4960
}
5061

5162
MlirContext context;
5263
MlirDiagnosticHandlerID handlerID;
53-
std::string errorMessage = "";
64+
65+
std::string message;
66+
llvm::raw_string_ostream messageStream{message};
5467
};
5568

5669
} // namespace python

mlir/test/python/ir/diagnostic_handler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import gc
44
from mlir.ir import *
5+
from mlir._mlir_libs._mlirPythonTestNanobind import (
6+
test_diagnostics_with_errors_and_notes,
7+
)
58

69

710
def run(f):
@@ -222,3 +225,16 @@ def callback2(d):
222225
# CHECK: CALLBACK2: foobar
223226
# CHECK: CALLBACK1: foobar
224227
loc.emit_error("foobar")
228+
229+
230+
# CHECK-LABEL: TEST: testBuiltInDiagnosticsHandler
231+
@run
232+
def testBuiltInDiagnosticsHandler():
233+
ctx = Context()
234+
235+
try:
236+
test_diagnostics_with_errors_and_notes(ctx)
237+
except ValueError as e:
238+
# CHECK: created error
239+
# CHECK: attached note
240+
print(e)

mlir/test/python/lib/PythonTestCAPI.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include "mlir-c/BuiltinTypes.h"
1212
#include "mlir/CAPI/Registration.h"
1313
#include "mlir/CAPI/Wrap.h"
14+
#include "mlir/IR/Diagnostics.h"
15+
#include "mlir/IR/Location.h"
1416

1517
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
1618
python_test::PythonTestDialect)
@@ -42,3 +44,9 @@ MlirTypeID mlirPythonTestTestTypeGetTypeID(void) {
4244
bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
4345
return mlirTypeIsATensor(wrap(unwrap(value).getType()));
4446
}
47+
48+
void mlirPythonTestEmitDiagnosticWithNote(MlirContext ctx) {
49+
auto diag =
50+
mlir::emitError(unwrap(mlirLocationUnknownGet(ctx)), "created error");
51+
diag.attachNote() << "attached note";
52+
}

mlir/test/python/lib/PythonTestCAPI.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H
1111

1212
#include "mlir-c/IR.h"
13+
#include "mlir-c/Support.h"
1314

1415
#ifdef __cplusplus
1516
extern "C" {
@@ -33,6 +34,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestTypeGetTypeID(void);
3334

3435
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
3536

37+
MLIR_CAPI_EXPORTED void mlirPythonTestEmitDiagnosticWithNote(MlirContext ctx);
38+
3639
#ifdef __cplusplus
3740
}
3841
#endif

mlir/test/python/lib/PythonTestModuleNanobind.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
#include "PythonTestCAPI.h"
1212
#include "mlir-c/BuiltinAttributes.h"
1313
#include "mlir-c/BuiltinTypes.h"
14+
#include "mlir-c/Diagnostics.h"
1415
#include "mlir-c/IR.h"
16+
#include "mlir/Bindings/Python/Diagnostics.h"
1517
#include "mlir/Bindings/Python/Nanobind.h"
1618
#include "mlir/Bindings/Python/NanobindAdaptors.h"
19+
#include "nanobind/nanobind.h"
1720

1821
namespace nb = nanobind;
1922
using namespace mlir::python::nanobind_adaptors;
@@ -45,6 +48,13 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
4548
},
4649
nb::arg("registry"));
4750

51+
m.def("test_diagnostics_with_errors_and_notes", [](MlirContext ctx) {
52+
mlir::python::CollectDiagnosticsToStringScope handler(ctx);
53+
54+
mlirPythonTestEmitDiagnosticWithNote(ctx);
55+
throw nb::value_error(handler.takeMessage().c_str());
56+
});
57+
4858
mlir_attribute_subclass(m, "TestAttr",
4959
mlirAttributeIsAPythonTestTestAttribute,
5060
mlirPythonTestTestAttributeGetTypeID)

0 commit comments

Comments
 (0)