Skip to content

Commit c08502d

Browse files
authored
[MLIR][Transform] expose transform.debug extension in Python (#145550)
Removes the Debug... prefix on the ops in tablegen, in line with pretty much all other Transform-dialect extension ops. This means that the ops in Python look like `debug.EmitParamAsRemarkOp`/`debug.emit_param_as_remark` instead of `debug.DebugEmitParamAsRemarkOp`/`debug.debug_emit_param_as_remark`.
1 parent 46ee7f1 commit c08502d

File tree

6 files changed

+163
-8
lines changed

6 files changed

+163
-8
lines changed

mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.td"
2020
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
2121
include "mlir/Dialect/Transform/IR/TransformDialect.td"
2222

23-
def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
23+
def EmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
2424
[MatchOpInterface,
2525
DeclareOpInterfaceMethods<TransformOpInterface>,
2626
MemoryEffectsOpInterface, NavigationTransformOpTrait]> {
@@ -39,7 +39,7 @@ def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
3939
let assemblyFormat = "$at `,` $message attr-dict `:` type($at)";
4040
}
4141

42-
def DebugEmitParamAsRemarkOp
42+
def EmitParamAsRemarkOp
4343
: TransformDialectOp<"debug.emit_param_as_remark",
4444
[MatchOpInterface,
4545
DeclareOpInterfaceMethods<TransformOpInterface>,

mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ using namespace mlir;
1919
#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc"
2020

2121
DiagnosedSilenceableFailure
22-
transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
23-
transform::TransformResults &results,
24-
transform::TransformState &state) {
22+
transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
23+
transform::TransformResults &results,
24+
transform::TransformState &state) {
2525
if (isa<TransformHandleTypeInterface>(getAt().getType())) {
2626
auto payload = state.getPayloadOps(getAt());
2727
for (Operation *op : payload)
@@ -52,9 +52,10 @@ transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
5252
return DiagnosedSilenceableFailure::success();
5353
}
5454

55-
DiagnosedSilenceableFailure transform::DebugEmitParamAsRemarkOp::apply(
56-
transform::TransformRewriter &rewriter,
57-
transform::TransformResults &results, transform::TransformState &state) {
55+
DiagnosedSilenceableFailure
56+
transform::EmitParamAsRemarkOp::apply(transform::TransformRewriter &rewriter,
57+
transform::TransformResults &results,
58+
transform::TransformState &state) {
5859
std::string str;
5960
llvm::raw_string_ostream os(str);
6061
if (getMessage())

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
171171
DIALECT_NAME transform
172172
EXTENSION_NAME transform_pdl_extension)
173173

174+
declare_mlir_dialect_extension_python_bindings(
175+
ADD_TO_PARENT MLIRPythonSources.Dialects
176+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
177+
TD_FILE dialects/TransformDebugExtensionOps.td
178+
SOURCES
179+
dialects/transform/debug.py
180+
DIALECT_NAME transform
181+
EXTENSION_NAME transform_debug_extension)
182+
174183
declare_mlir_dialect_python_bindings(
175184
ADD_TO_PARENT MLIRPythonSources.Dialects
176185
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Entry point of the generated Python bindings for the Debug extension of the
10+
// Transform dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
15+
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
16+
17+
include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td"
18+
19+
#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from typing import Optional
6+
7+
from ...ir import Attribute, Operation, Value, StringAttr
8+
from .._transform_debug_extension_ops_gen import *
9+
from .._transform_pdl_extension_ops_gen import _Dialect
10+
11+
try:
12+
from .._ods_common import _cext as _ods_cext
13+
except ImportError as e:
14+
raise RuntimeError("Error loading imports from extension module") from e
15+
16+
from typing import Union
17+
18+
19+
@_ods_cext.register_operation(_Dialect, replace=True)
20+
class EmitParamAsRemarkOp(EmitParamAsRemarkOp):
21+
def __init__(
22+
self,
23+
param: Attribute,
24+
*,
25+
anchor: Optional[Operation] = None,
26+
message: Optional[Union[StringAttr, str]] = None,
27+
loc=None,
28+
ip=None,
29+
):
30+
if isinstance(message, str):
31+
message = StringAttr.get(message)
32+
33+
super().__init__(
34+
param,
35+
anchor=anchor,
36+
message=message,
37+
loc=loc,
38+
ip=ip,
39+
)
40+
41+
42+
def emit_param_as_remark(
43+
param: Attribute,
44+
*,
45+
anchor: Optional[Operation] = None,
46+
message: Optional[Union[StringAttr, str]] = None,
47+
loc=None,
48+
ip=None,
49+
):
50+
return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip)
51+
52+
53+
@_ods_cext.register_operation(_Dialect, replace=True)
54+
class EmitRemarkAtOp(EmitRemarkAtOp):
55+
def __init__(
56+
self,
57+
at: Union[Operation, Value],
58+
message: Optional[Union[StringAttr, str]] = None,
59+
*,
60+
loc=None,
61+
ip=None,
62+
):
63+
if isinstance(message, str):
64+
message = StringAttr.get(message)
65+
66+
super().__init__(
67+
at,
68+
message,
69+
loc=loc,
70+
ip=ip,
71+
)
72+
73+
74+
def emit_remark_at(
75+
at: Union[Operation, Value],
76+
message: Optional[Union[StringAttr, str]] = None,
77+
*,
78+
loc=None,
79+
ip=None,
80+
):
81+
return EmitRemarkAtOp(at, message, loc=loc, ip=ip)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import transform
5+
from mlir.dialects.transform import debug
6+
7+
8+
def run(f):
9+
print("\nTEST:", f.__name__)
10+
with Context(), Location.unknown():
11+
module = Module.create()
12+
with InsertionPoint(module.body):
13+
sequence = transform.SequenceOp(
14+
transform.FailurePropagationMode.Propagate,
15+
[],
16+
transform.AnyOpType.get(),
17+
)
18+
with InsertionPoint(sequence.body):
19+
f(sequence.bodyTarget)
20+
transform.YieldOp()
21+
print(module)
22+
return f
23+
24+
25+
@run
26+
def testDebugEmitParamAsRemark(target):
27+
i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
28+
i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
29+
debug.emit_param_as_remark(i0_param)
30+
debug.emit_param_as_remark(i0_param, anchor=target, message="some text")
31+
# CHECK-LABEL: TEST: testDebugEmitParamAsRemark
32+
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
33+
# CHECK: %[[PARAM:.*]] = transform.param.constant
34+
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
35+
# CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
36+
# CHECK-SAME: "some text"
37+
# CHECK-SAME: at %[[ARG0]]
38+
39+
40+
@run
41+
def testDebugEmitRemarkAtOp(target):
42+
debug.emit_remark_at(target, "some text")
43+
# CHECK-LABEL: TEST: testDebugEmitRemarkAtOp
44+
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
45+
# CHECK: transform.debug.emit_remark_at %[[ARG0]], "some text"

0 commit comments

Comments
 (0)