Skip to content

Commit db3bc49

Browse files
authored
[mlir][python] fix up affine for (#74495)
1 parent d8cd7fc commit db3bc49

File tree

3 files changed

+202
-97
lines changed

3 files changed

+202
-97
lines changed

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,7 @@ def get_op_result_or_op_results(
134134
# see the typing.Type doc string.
135135
_U = _TypeVar("_U", bound=_cext.ir.Value)
136136
SubClassValueT = _Type[_U]
137+
138+
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
139+
ResultValueT = _Union[ResultValueTypeTuple]
140+
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

mlir/python/mlir/dialects/affine.py

Lines changed: 60 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._affine_ops_gen import *
6-
from ._affine_ops_gen import _Dialect, AffineForOp
7-
from .arith import constant
6+
from ._affine_ops_gen import _Dialect
87

98
try:
109
from ..ir import *
1110
from ._ods_common import (
1211
get_op_result_or_value as _get_op_result_or_value,
1312
get_op_results_or_values as _get_op_results_or_values,
1413
_cext as _ods_cext,
14+
ResultValueTypeTuple as _ResultValueTypeTuple,
15+
ResultValueT as _ResultValueT,
16+
VariadicResultValueT as _VariadicResultValueT,
1517
)
1618
except ImportError as e:
1719
raise RuntimeError("Error loading imports from extension module") from e
@@ -21,17 +23,17 @@
2123

2224
@_ods_cext.register_operation(_Dialect, replace=True)
2325
class AffineForOp(AffineForOp):
24-
"""Specialization for the Affine for op class"""
26+
"""Specialization for the Affine for op class."""
2527

2628
def __init__(
2729
self,
28-
lower_bound,
29-
upper_bound,
30-
step,
31-
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
30+
lower_bound: Union[int, _ResultValueT, AffineMap],
31+
upper_bound: Optional[Union[int, _ResultValueT, AffineMap]],
32+
step: Optional[Union[int, Attribute]] = None,
33+
iter_args: Optional[_ResultValueT] = None,
3234
*,
33-
lower_bound_operands=[],
34-
upper_bound_operands=[],
35+
lower_bound_operands: Optional[_VariadicResultValueT] = None,
36+
upper_bound_operands: Optional[_VariadicResultValueT] = None,
3537
loc=None,
3638
ip=None,
3739
):
@@ -43,25 +45,57 @@ def __init__(
4345
- `iter_args` is a list of additional loop-carried arguments or an operation
4446
producing them as results.
4547
- `lower_bound_operands` is the list of arguments to substitute the dimensions,
46-
then symbols in the `lower_bound` affine map, in an increasing order
48+
then symbols in the `lower_bound` affine map, in an increasing order.
4749
- `upper_bound_operands` is the list of arguments to substitute the dimensions,
48-
then symbols in the `upper_bound` affine map, in an increasing order
50+
then symbols in the `upper_bound` affine map, in an increasing order.
4951
"""
5052

53+
if lower_bound_operands is None:
54+
lower_bound_operands = []
55+
if upper_bound_operands is None:
56+
upper_bound_operands = []
57+
58+
if step is None:
59+
step = 1
60+
61+
bounds_operands = [lower_bound_operands, upper_bound_operands]
62+
bounds = [lower_bound, upper_bound]
63+
bounds_names = ["lower", "upper"]
64+
for i, name in enumerate(bounds_names):
65+
if isinstance(bounds[i], int):
66+
bounds[i] = AffineMap.get_constant(bounds[i])
67+
elif isinstance(bounds[i], _ResultValueTypeTuple):
68+
if len(bounds_operands[i]):
69+
raise ValueError(
70+
f"Either a concrete {name} bound or an AffineMap in combination "
71+
f"with {name} bound operands, but not both, is supported."
72+
)
73+
if (
74+
isinstance(bounds[i], (OpView, Operation))
75+
and len(bounds[i].results) > 1
76+
):
77+
raise ValueError(
78+
f"Only a single concrete value is supported for {name} bound."
79+
)
80+
81+
bounds_operands[i].append(_get_op_result_or_value(bounds[i]))
82+
bounds[i] = AffineMap.get_identity(1)
83+
84+
if not isinstance(bounds[i], AffineMap):
85+
raise ValueError(
86+
f"{name} bound must be int | ResultValueT | AffineMap."
87+
)
88+
if len(bounds_operands[i]) != bounds[i].n_inputs:
89+
raise ValueError(
90+
f"Wrong number of {name} bound operands passed to AffineForOp; "
91+
+ f"Expected {bounds[i].n_inputs}, got {len(bounds_operands[i])}."
92+
)
93+
94+
lower_bound, upper_bound = bounds
95+
5196
if iter_args is None:
5297
iter_args = []
5398
iter_args = _get_op_results_or_values(iter_args)
54-
if len(lower_bound_operands) != lower_bound.n_inputs:
55-
raise ValueError(
56-
f"Wrong number of lower bound operands passed to AffineForOp. "
57-
+ "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}."
58-
)
59-
60-
if len(upper_bound_operands) != upper_bound.n_inputs:
61-
raise ValueError(
62-
f"Wrong number of upper bound operands passed to AffineForOp. "
63-
+ "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}."
64-
)
6599

66100
results = [arg.type for arg in iter_args]
67101
super().__init__(
@@ -71,7 +105,7 @@ def __init__(
71105
inits=list(iter_args),
72106
lowerBoundMap=AffineMapAttr.get(lower_bound),
73107
upperBoundMap=AffineMapAttr.get(upper_bound),
74-
step=IntegerAttr.get(IndexType.get(), step),
108+
step=step,
75109
loc=loc,
76110
ip=ip,
77111
)
@@ -98,37 +132,18 @@ def inner_iter_args(self):
98132

99133
def for_(
100134
start,
101-
stop=None,
135+
stop,
102136
step=None,
103137
iter_args: Optional[Sequence[Value]] = None,
104138
*,
105139
loc=None,
106140
ip=None,
107141
):
108-
if step is None:
109-
step = 1
110-
if stop is None:
111-
stop = start
112-
start = 0
113-
params = [start, stop]
114-
for i, p in enumerate(params):
115-
if isinstance(p, int):
116-
p = constant(IntegerAttr.get(IndexType.get(), p))
117-
elif isinstance(p, float):
118-
raise ValueError(f"{p=} must be int.")
119-
params[i] = p
120-
121-
start, stop = params
122-
s0 = AffineSymbolExpr.get(0)
123-
lbmap = AffineMap.get(0, 1, [s0])
124-
ubmap = AffineMap.get(0, 1, [s0])
125142
for_op = AffineForOp(
126-
lbmap,
127-
ubmap,
143+
start,
144+
stop,
128145
step,
129146
iter_args=iter_args,
130-
lower_bound_operands=[start],
131-
upper_bound_operands=[stop],
132147
loc=loc,
133148
ip=ip,
134149
)

mlir/test/python/dialects/affine.py

Lines changed: 138 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mlir.dialects import arith
66
from mlir.dialects import memref
77
from mlir.dialects import affine
8+
import mlir.extras.types as T
89

910

1011
def constructAndPrintInModule(f):
@@ -107,66 +108,151 @@ def affine_for_op_test(buffer):
107108
# CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
108109
tmp = memref.LoadOp(buffer, [sum.induction_variable])
109110
sum_next = arith.AddFOp(sum.inner_iter_args[0], tmp)
110-
111111
affine.AffineYieldOp([sum_next])
112112

113-
return
113+
114+
# CHECK-LABEL: TEST: testAffineForOpErrors
115+
@constructAndPrintInModule
116+
def testAffineForOpErrors():
117+
c1 = arith.ConstantOp(T.index(), 1)
118+
c2 = arith.ConstantOp(T.index(), 2)
119+
c3 = arith.ConstantOp(T.index(), 3)
120+
d0 = AffineDimExpr.get(0)
121+
122+
try:
123+
affine.AffineForOp(
124+
c1,
125+
c2,
126+
1,
127+
lower_bound_operands=[c3],
128+
upper_bound_operands=[],
129+
)
130+
except ValueError as e:
131+
assert (
132+
e.args[0]
133+
== "Either a concrete lower bound or an AffineMap in combination with lower bound operands, but not both, is supported."
134+
)
135+
136+
try:
137+
affine.AffineForOp(
138+
AffineMap.get_constant(1),
139+
c2,
140+
1,
141+
lower_bound_operands=[c3, c3],
142+
upper_bound_operands=[],
143+
)
144+
except ValueError as e:
145+
assert (
146+
e.args[0]
147+
== "Wrong number of lower bound operands passed to AffineForOp; Expected 0, got 2."
148+
)
149+
150+
try:
151+
two_indices = affine.AffineDelinearizeIndexOp(
152+
[T.index(), T.index()], c1, [c1, c1]
153+
)
154+
affine.AffineForOp(
155+
two_indices,
156+
c2,
157+
1,
158+
lower_bound_operands=[],
159+
upper_bound_operands=[],
160+
)
161+
except ValueError as e:
162+
assert e.args[0] == "Only a single concrete value is supported for lower bound."
163+
164+
try:
165+
affine.AffineForOp(
166+
1.0,
167+
c2,
168+
1,
169+
lower_bound_operands=[],
170+
upper_bound_operands=[],
171+
)
172+
except ValueError as e:
173+
assert e.args[0] == "lower bound must be int | ResultValueT | AffineMap."
114174

115175

116176
@constructAndPrintInModule
117177
def testForSugar():
118-
index_type = IndexType.get()
119-
memref_t = MemRefType.get([10], index_type)
178+
memref_t = T.memref(10, T.index())
120179
range = affine.for_
121180

122-
# CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
123-
# CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
124-
# CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step 2 {
125-
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
126-
# CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
127-
# CHECK: }
128-
# CHECK: return
129-
# CHECK: }
130-
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
131-
def range_loop_1(lb, ub, step, memref_v):
132-
for i in range(lb, 10, 2):
181+
# CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> (d0)>
182+
183+
# CHECK-LABEL: func.func @range_loop_1(
184+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
185+
# CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to #[[$ATTR_2]](%[[VAL_1]]) {
186+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
187+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
188+
# CHECK: }
189+
# CHECK: return
190+
# CHECK: }
191+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
192+
def range_loop_1(lb, ub, memref_v):
193+
for i in range(lb, ub, step=1):
133194
add = arith.addi(i, i)
134-
s0 = AffineSymbolExpr.get(0)
135-
map = AffineMap.get(0, 1, [s0])
136-
affine.store(add, memref_v, [i], map=map)
137-
affine.AffineYieldOp([])
138-
139-
# CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
140-
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
141-
# CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
142-
# CHECK: affine.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] {
143-
# CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
144-
# CHECK: affine.store %[[VAL_8]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_7]]{{\)\]}} : memref<10xindex>
145-
# CHECK: }
146-
# CHECK: return
147-
# CHECK: }
148-
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
149-
def range_loop_2(lb, ub, step, memref_v):
150-
for i in range(0, 10, 1):
195+
memref.store(add, memref_v, [i])
196+
197+
affine.yield_([])
198+
199+
# CHECK-LABEL: func.func @range_loop_2(
200+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
201+
# CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to 10 {
202+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
203+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
204+
# CHECK: }
205+
# CHECK: return
206+
# CHECK: }
207+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
208+
def range_loop_2(lb, ub, memref_v):
209+
for i in range(lb, 10, step=1):
151210
add = arith.addi(i, i)
152-
s0 = AffineSymbolExpr.get(0)
153-
map = AffineMap.get(0, 1, [s0])
154-
affine.store(add, memref_v, [i], map=map)
155-
affine.AffineYieldOp([])
156-
157-
# CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
158-
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
159-
# CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] {
160-
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
161-
# CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
162-
# CHECK: }
163-
# CHECK: return
164-
# CHECK: }
165-
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
166-
def range_loop_3(lb, ub, step, memref_v):
167-
for i in range(0, ub, 1):
211+
memref.store(add, memref_v, [i])
212+
affine.yield_([])
213+
214+
# CHECK-LABEL: func.func @range_loop_3(
215+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
216+
# CHECK: affine.for %[[VAL_3:.*]] = 0 to #[[$ATTR_2]](%[[VAL_1]]) {
217+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
218+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
219+
# CHECK: }
220+
# CHECK: return
221+
# CHECK: }
222+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
223+
def range_loop_3(lb, ub, memref_v):
224+
for i in range(0, ub, step=1):
225+
add = arith.addi(i, i)
226+
memref.store(add, memref_v, [i])
227+
affine.yield_([])
228+
229+
# CHECK-LABEL: func.func @range_loop_4(
230+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
231+
# CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
232+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
233+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
234+
# CHECK: }
235+
# CHECK: return
236+
# CHECK: }
237+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
238+
def range_loop_4(lb, ub, memref_v):
239+
for i in range(0, 10, step=1):
240+
add = arith.addi(i, i)
241+
memref.store(add, memref_v, [i])
242+
affine.yield_([])
243+
244+
# CHECK-LABEL: func.func @range_loop_8(
245+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
246+
# CHECK: %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
247+
# CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
248+
# CHECK: memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
249+
# CHECK: affine.yield %[[VAL_5]] : memref<10xindex>
250+
# CHECK: }
251+
# CHECK: return
252+
# CHECK: }
253+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
254+
def range_loop_8(lb, ub, memref_v):
255+
for i, it in range(0, 10, iter_args=[memref_v]):
168256
add = arith.addi(i, i)
169-
s0 = AffineSymbolExpr.get(0)
170-
map = AffineMap.get(0, 1, [s0])
171-
affine.store(add, memref_v, [i], map=map)
172-
affine.AffineYieldOp([])
257+
memref.store(add, it, [i])
258+
affine.yield_([it])

0 commit comments

Comments
 (0)