Skip to content

Commit 691a2fa

Browse files
[mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp.
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157706
1 parent 030e315 commit 691a2fa

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

mlir/python/mlir/dialects/_structured_transform_ops_ext.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,66 @@ def __init__(
187187
)
188188

189189

190+
class MapCopyToThreadsOp:
191+
"""Specialization for MapCopyToThreadsOp class."""
192+
193+
@overload
194+
def __init__(
195+
self,
196+
forall_op_type: Type,
197+
tiled_op_type: Type,
198+
target: Union[Operation, OpView, Value],
199+
*,
200+
total_num_threads: Union[int, IntegerAttr],
201+
desired_bit_alignment: Union[int, IntegerAttr],
202+
loc=None,
203+
ip=None,
204+
):
205+
...
206+
207+
@overload
208+
def __init__(
209+
self,
210+
target: Union[Operation, OpView, Value],
211+
*,
212+
total_num_threads: Union[int, IntegerAttr],
213+
desired_bit_alignment: Union[int, IntegerAttr],
214+
loc=None,
215+
ip=None,
216+
):
217+
...
218+
219+
def __init__(
220+
self,
221+
forall_op_type_or_target: Union[Operation, OpView, Type, Value],
222+
tiled_op_type_or_none: Optional[Type] = None,
223+
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
224+
*,
225+
total_num_threads: Union[int, IntegerAttr],
226+
desired_bit_alignment: Union[int, IntegerAttr],
227+
loc=None,
228+
ip=None,
229+
):
230+
if isinstance(forall_op_type_or_target, Type):
231+
forall_op_type = forall_op_type_or_target
232+
tiled_op_type = tiled_op_type_or_none
233+
target = target_or_none
234+
else:
235+
forall_op_type = transform.AnyOpType.get()
236+
tiled_op_type = transform.AnyOpType.get()
237+
target = forall_op_type_or_target
238+
239+
super().__init__(
240+
forall_op_type,
241+
tiled_op_type,
242+
target,
243+
total_num_threads=total_num_threads,
244+
desired_bit_alignment=desired_bit_alignment,
245+
loc=loc,
246+
ip=ip,
247+
)
248+
249+
190250
class MatchOp:
191251
"""Specialization for MatchOp class."""
192252

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,44 @@ def testInterchange():
9797
# CHECK: iterator_interchange = [1, 0]
9898

9999

100+
@run
101+
def testMapCopyToThreadsOpCompact():
102+
sequence = transform.SequenceOp(
103+
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
104+
)
105+
with InsertionPoint(sequence.body):
106+
structured.MapCopyToThreadsOp(
107+
sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
108+
)
109+
transform.YieldOp()
110+
# CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
111+
# CHECK: = transform.structured.gpu.map_copy_to_threads
112+
# CHECK-SAME: total_num_threads = 32
113+
# CHECK-SAME: desired_bit_alignment = 128
114+
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
115+
116+
117+
@run
118+
def testMapCopyToThreadsOpTypes():
119+
sequence = transform.SequenceOp(
120+
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
121+
)
122+
with InsertionPoint(sequence.body):
123+
structured.MapCopyToThreadsOp(
124+
transform.OperationType.get("test.opA"),
125+
transform.OperationType.get("test.opB"),
126+
sequence.bodyTarget,
127+
total_num_threads=32,
128+
desired_bit_alignment=128,
129+
)
130+
transform.YieldOp()
131+
# CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
132+
# CHECK: = transform.structured.gpu.map_copy_to_threads
133+
# CHECK-SAME: total_num_threads = 32
134+
# CHECK-SAME: desired_bit_alignment = 128
135+
# CHECK-SAME: (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)
136+
137+
100138
@run
101139
def testMatchOpNamesString():
102140
sequence = transform.SequenceOp(

0 commit comments

Comments
 (0)