Skip to content

Commit 0764d6e

Browse files
gflegarcopybara-github
authored andcommitted
Only print out diagnostic messages if an environment variable is set
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228 It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages. PiperOrigin-RevId: 607293980
1 parent 0642934 commit 0764d6e

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

third_party/triton/cl607293980.patch

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
Remove once b/325453581 is fixed
2+
diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
3+
--- a/include/triton/Tools/Sys/GetEnv.hpp
4+
+++ b/include/triton/Tools/Sys/GetEnv.hpp
5+
@@ -32,7 +32,7 @@ namespace triton {
6+
const std::set<std::string> ENV_VARS = {
7+
"DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION",
8+
"ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP",
9+
- "AMDGCN_ENABLE_DUMP", "DISABLE_LLVM_OPT"};
10+
+ "AMDGCN_ENABLE_DUMP", "DISABLE_LLVM_OPT", "MLIR_ENABLE_DIAGNOSTICS"};
11+
12+
namespace tools {
13+
14+
diff --git a/python/src/ir.cc b/python/src/ir.cc
15+
--- a/python/src/ir.cc
16+
+++ b/python/src/ir.cc
17+
@@ -1551,29 +1551,36 @@ void init_triton_ir(py::module &&m) {
18+
.def("enable_debug",
19+
[](mlir::PassManager &self) {
20+
auto *context = self.getContext();
21+
- context->printOpOnDiagnostic(true);
22+
- context->printStackTraceOnDiagnostic(true);
23+
- context->disableMultithreading();
24+
- context->getDiagEngine().registerHandler(
25+
- [](mlir::Diagnostic &diag) {
26+
- llvm::outs() << diag << "\n";
27+
- return mlir::success();
28+
- });
29+
-
30+
- if (!::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"))
31+
- return;
32+
- auto printingFlags = mlir::OpPrintingFlags();
33+
- printingFlags.elideLargeElementsAttrs(16);
34+
- printingFlags.enableDebugInfo();
35+
- auto print_always = [](mlir::Pass *, mlir::Operation *) {
36+
- return true;
37+
- };
38+
- self.enableIRPrinting(
39+
- /*shouldPrintBeforePass=*/print_always,
40+
- /*shouldPrintAfterPass=*/print_always,
41+
- /*printModuleScope=*/true,
42+
- /*printAfterOnlyOnChange=*/false,
43+
- /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), printingFlags);
44+
+ bool have_diagnostics =
45+
+ ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS");
46+
+ bool have_dump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
47+
+ if (have_diagnostics || have_dump) {
48+
+ context->disableMultithreading();
49+
+ }
50+
+ if (have_diagnostics) {
51+
+ context->printOpOnDiagnostic(true);
52+
+ context->printStackTraceOnDiagnostic(true);
53+
+ context->getDiagEngine().registerHandler(
54+
+ [](mlir::Diagnostic &diag) {
55+
+ llvm::outs() << diag << "\n";
56+
+ return mlir::success();
57+
+ });
58+
+ }
59+
+ if (have_dump) {
60+
+ auto printingFlags = mlir::OpPrintingFlags();
61+
+ printingFlags.elideLargeElementsAttrs(16);
62+
+ printingFlags.enableDebugInfo();
63+
+ auto print_always = [](mlir::Pass *, mlir::Operation *) {
64+
+ return true;
65+
+ };
66+
+ self.enableIRPrinting(
67+
+ /*shouldPrintBeforePass=*/print_always,
68+
+ /*shouldPrintAfterPass=*/print_always,
69+
+ /*printModuleScope=*/true,
70+
+ /*printAfterOnlyOnChange=*/false,
71+
+ /*printAfterOnlyOnFailure*/ true, llvm::dbgs(),
72+
+ printingFlags);
73+
+ }
74+
})
75+
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
76+
// TODO: maybe dump module to file and print error for better

third_party/triton/workspace.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ def repo():
1616
patch_file = [
1717
# Upstream this in the next integrate
1818
#"//third_party/triton:cl602997103.patch"
19+
"//third_party/triton:cl607293980.patch",
1920
],
2021
)

0 commit comments

Comments
 (0)