Commit 1bc5fe6

[mlir][python] implement GenericOp bindings (#124496)

1 parent: 822954b

File tree: 2 files changed (+141, -1)

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 45 additions & 0 deletions
@@ -10,6 +10,7 @@
 # DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
 from .._linalg_ops_gen import *
 from .._linalg_enum_gen import *
+from .._linalg_enum_gen import _iteratortypeenum

 # These are the ground truth functions defined as:
 # ```
@@ -58,6 +59,7 @@

 from ...ir import *
 from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from ...extras.meta import region_op


 def transpose(
@@ -102,3 +104,46 @@ def broadcast(
     )
     fill_builtin_region(op.operation)
     return op
+
+
+@register_attribute_builder("IteratorTypeArrayAttr")
+def _IteratorTypeArrayAttr(x, context):
+    return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
+
+
+# The underscore is needed here so that there's no collision with opdsl generation.
+class GenericOp_(GenericOp):
+    def __init__(
+        self,
+        inputs,
+        outputs,
+        indexing_maps,
+        iterator_types,
+        *,
+        doc=None,
+        library_call=None,
+        loc=None,
+        ip=None,
+    ):
+        result_types = []
+        if isinstance(outputs[0].type, RankedTensorType):
+            result_types = [o.type for o in outputs]
+
+        super().__init__(
+            result_types,
+            inputs,
+            outputs,
+            indexing_maps,
+            iterator_types,
+            doc=doc,
+            library_call=library_call,
+            loc=loc,
+            ip=ip,
+        )
+        element_types = [i.type.element_type for i in inputs] + [
+            o.type.element_type for o in outputs
+        ]
+        self.regions[0].blocks.append(*element_types)
+
+
+generic = region_op(GenericOp_, terminator=YieldOp)
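
For orientation, here is a minimal usage sketch of the decorator form this file now exports: region_op turns GenericOp_ into a decorator whose function body builds the op's region. It mirrors the first tensor case in the test file below; the shapes, maps, and the name `copy` are illustrative only, not part of the commit.

    from mlir.dialects import linalg, tensor
    from mlir.ir import *

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f32 = F32Type.get()
            id_map = AffineMap.get_identity(2)
            # One input tensor and one init/output tensor.
            x = tensor.empty((16, 16), f32)
            y = tensor.empty((16, 16), f32)

            # The decorated function is the region body: its arguments are the
            # block arguments (the operands' element types), and its return
            # values feed the implicit linalg.yield terminator.
            @linalg.generic(
                [x],
                [y],
                [id_map, id_map],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def copy(a, b):
                return a

            # With tensor outputs, the name is rebound to the op's result.
            assert isinstance(copy, Value)
        print(module)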

mlir/test/python/dialects/linalg/ops.py

Lines changed: 96 additions & 1 deletion
@@ -1,6 +1,6 @@
 # RUN: %PYTHON %s | FileCheck %s

-from mlir.dialects import arith, builtin, func, linalg, tensor
+from mlir.dialects import arith, func, linalg, tensor, memref
 from mlir.dialects.linalg.opdsl.lang import *
 from mlir.ir import *

@@ -84,6 +84,7 @@ def named_form(lhs, rhs):

     print(module)

+
 # CHECK-LABEL: TEST: testIdentityRegionOps
 @run
 def testIdentityRegionOps():
@@ -161,3 +162,97 @@ def broadcast_op(op1, op2, op3):
             op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])

     print(module)
+
+
+# CHECK-LABEL: TEST: testGenericOp
+@run
+def testGenericOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        memref_t = MemRefType.get([10, 10], f32)
+        with InsertionPoint(module.body):
+            id_map_1 = AffineMap.get_identity(2)
+            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
+            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
+            x = tensor.empty((16, 16), f32)
+            y = tensor.empty((16, 16), f32)
+
+            # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32):
+            # CHECK: linalg.yield %in : f32
+            # CHECK: } -> tensor<16x16xf32>
+            @linalg.generic(
+                [x],
+                [y],
+                [id_map_1, id_map_1],
+                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+            )
+            def f(a, b):
+                assert isinstance(a, Value)
+                assert isinstance(a.type, F32Type)
+                assert isinstance(b, Value)
+                assert isinstance(b.type, F32Type)
+                return a
+
+            assert isinstance(f, Value)
+            assert isinstance(f.type, RankedTensorType)
+
+            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
+            z = tensor.empty((16, 16, 16), f32)
+
+            minor_id = AffineMap.get_minor_identity(3, 2)
+            id_map_2 = AffineMap.get_identity(3)
+
+            # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32):
+            # CHECK: linalg.yield %in, %out : f32, f32
+            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
+            @linalg.generic(
+                [x],
+                [z, z],
+                [minor_id, id_map_2, id_map_2],
+                [
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                ],
+            )
+            def g(a, b, c):
+                assert isinstance(a, Value)
+                assert isinstance(a.type, F32Type)
+                assert isinstance(b, Value)
+                assert isinstance(b.type, F32Type)
+                assert isinstance(c, Value)
+                assert isinstance(c.type, F32Type)
+                return a, b
+
+            assert isinstance(g, OpResultList)
+            assert len(g) == 2
+            assert isinstance(g[0].type, RankedTensorType)
+            assert isinstance(g[1].type, RankedTensorType)
+
+            # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32>
+            # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32>
+            xx = memref.alloc(memref_t, [], [])
+            yy = memref.alloc(memref_t, [], [])
+
+            # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32):
+            # CHECK: linalg.yield %in : f32
+            # CHECK: }
+            @linalg.generic(
+                [xx],
+                [yy],
+                [id_map_1, id_map_1],
+                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+            )
+            def f(a, b):
+                assert isinstance(a, Value)
+                assert isinstance(a.type, F32Type)
+                assert isinstance(b, Value)
+                assert isinstance(b.type, F32Type)
+                return a
+
+    module.operation.verify()
+    print(module)
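
A note on the design choice the builder and test both exercise: GenericOp_ infers the op's result types from the outputs only when they are RankedTensorType values, so in the tensor cases the decorator evaluates to the op's results (a Value for a single result, an OpResultList for several), while in the memref case the op produces no results, matching linalg.generic's buffer semantics.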
