Skip to content

Commit ee8f9c2

Browse files
committed
add in-depth documentation to insert_transform_script
1 parent 86fb474 commit ee8f9c2

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

mlir/python/mlir/dialects/transform/extras/__init__.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,35 @@ def match_ops(
8282
return handle
8383

8484

85-
ValueT = TypeVar("ValueT", bound=Value)
86-
87-
8885
def insert_transform_script(
8986
module: ir.Module,
90-
script: Callable[[ValueT], None],
87+
script: Callable[[OpHandle], None],
9188
dump_script: bool = False,
9289
) -> None:
93-
"""Inserts the transform script of the schedule into the module."""
90+
"""
91+
Inserts the transform script of the schedule into the module. The script
92+
should accept an instance of OpHandle as argument, which will be called with
93+
the block arg of the newly created sequence op.
94+
95+
Example:
96+
This python code
97+
```
98+
module = ir.Module.create()
99+
def test_match_ops_single(module: OpHandle):
100+
module.match_ops(scf.ForOp)
101+
insert_transform_script(module, script)
102+
```
103+
generates the following IR:
104+
```
105+
module {
106+
transform.sequence failures(propagate) {
107+
^bb0(%arg0: !transform.any_op):
108+
%0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.op<"scf.for">
109+
}
110+
}
111+
```
112+
"""
94113

95-
# Insert the script into the IR
96114
with module.context, ir.Location.unknown(module.context):
97115
with ir.InsertionPoint.at_block_begin(module.body):
98116
sequence_op = transform.SequenceOp(

mlir/test/python/dialects/transform_extras.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

3-
from typing import Callable, TypeVar
3+
from typing import Callable
44
from mlir import ir
55
from mlir.dialects import scf
66
from mlir.dialects.transform import structured
7-
from mlir.dialects.transform.extras import Value, OpHandle, insert_transform_script
7+
from mlir.dialects.transform.extras import OpHandle, insert_transform_script
88

9-
ValueT = TypeVar("ValueT", bound=Value)
109

11-
12-
def build_transform_script(script: Callable[[ValueT], None]):
10+
def build_transform_script(script: Callable[[OpHandle], None]):
1311
print("\nTEST:", script.__name__)
1412
with ir.Context(), ir.Location.unknown():
1513
module = ir.Module.create()

0 commit comments

Comments
 (0)