Skip to content

Commit c333f86

Browse files
committed
[mlir][python] fix up affine for
1 parent ff0e4fb commit c333f86

File tree

3 files changed

+168
-93
lines changed

3 files changed

+168
-93
lines changed

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,6 @@ def get_op_result_or_op_results(
134134
# see the typing.Type doc string.
135135
_U = _TypeVar("_U", bound=_cext.ir.Value)
136136
SubClassValueT = _Type[_U]
137+
138+
ResultValueT = _Union[_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value]
139+
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

mlir/python/mlir/dialects/affine.py

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._affine_ops_gen import *
6-
from ._affine_ops_gen import _Dialect, AffineForOp
7-
from .arith import constant
6+
from ._affine_ops_gen import _Dialect
87

98
try:
109
from ..ir import *
1110
from ._ods_common import (
1211
get_op_result_or_value as _get_op_result_or_value,
1312
get_op_results_or_values as _get_op_results_or_values,
1413
_cext as _ods_cext,
14+
ResultValueT as _ResultValueT,
15+
VariadicResultValueT as _VariadicResultValueT,
1516
)
1617
except ImportError as e:
1718
raise RuntimeError("Error loading imports from extension module") from e
@@ -21,17 +22,17 @@
2122

2223
@_ods_cext.register_operation(_Dialect, replace=True)
2324
class AffineForOp(AffineForOp):
24-
"""Specialization for the Affine for op class"""
25+
"""Specialization for the Affine for op class."""
2526

2627
def __init__(
2728
self,
28-
lower_bound,
29-
upper_bound,
30-
step,
31-
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
29+
lower_bound: Union[int, _ResultValueT, AffineMap],
30+
upper_bound: Optional[Union[int, _ResultValueT, AffineMap]] = None,
31+
step: Optional[Union[int, _ResultValueT]] = None,
32+
iter_args: Optional[_ResultValueT] = None,
3233
*,
33-
lower_bound_operands=[],
34-
upper_bound_operands=[],
34+
lower_bound_operands: Optional[_VariadicResultValueT] = None,
35+
upper_bound_operands: Optional[_VariadicResultValueT] = None,
3536
loc=None,
3637
ip=None,
3738
):
@@ -45,23 +46,40 @@ def __init__(
4546
- `lower_bound_operands` is the list of arguments to substitute the dimensions,
4647
then symbols in the `lower_bound` affine map, in an increasing order
4748
- `upper_bound_operands` is the list of arguments to substitute the dimensions,
48-
then symbols in the `upper_bound` affine map, in an increasing order
49+
then symbols in the `upper_bound` affine map, in an increasing order.
4950
"""
5051

52+
if lower_bound_operands is None:
53+
lower_bound_operands = []
54+
if upper_bound_operands is None:
55+
upper_bound_operands = []
56+
57+
if step is None:
58+
step = 1
59+
if upper_bound is None:
60+
upper_bound, lower_bound = lower_bound, 0
61+
62+
if isinstance(lower_bound, int):
63+
lower_bound = AffineMap.get_constant(lower_bound)
64+
elif isinstance(lower_bound, _ResultValueT):
65+
lower_bound_operands.append(lower_bound)
66+
lower_bound = AffineMap.get_constant(1)
67+
68+
if not isinstance(lower_bound, AffineMap):
69+
raise ValueError(f"{lower_bound=} must be int | ResultValueT | AffineMap")
70+
71+
if isinstance(upper_bound, int):
72+
upper_bound = AffineMap.get_constant(upper_bound)
73+
elif isinstance(upper_bound, _ResultValueT):
74+
upper_bound_operands.append(upper_bound)
75+
upper_bound = AffineMap.get_constant(1)
76+
77+
if not isinstance(upper_bound, AffineMap):
78+
raise ValueError(f"{upper_bound=} must be int | ResultValueT | AffineMap")
79+
5180
if iter_args is None:
5281
iter_args = []
5382
iter_args = _get_op_results_or_values(iter_args)
54-
if len(lower_bound_operands) != lower_bound.n_inputs:
55-
raise ValueError(
56-
f"Wrong number of lower bound operands passed to AffineForOp. "
57-
+ "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}."
58-
)
59-
60-
if len(upper_bound_operands) != upper_bound.n_inputs:
61-
raise ValueError(
62-
f"Wrong number of upper bound operands passed to AffineForOp. "
63-
+ "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}."
64-
)
6583

6684
results = [arg.type for arg in iter_args]
6785
super().__init__(
@@ -71,7 +89,7 @@ def __init__(
7189
inits=list(iter_args),
7290
lowerBoundMap=AffineMapAttr.get(lower_bound),
7391
upperBoundMap=AffineMapAttr.get(upper_bound),
74-
step=IntegerAttr.get(IndexType.get(), step),
92+
step=step,
7593
loc=loc,
7694
ip=ip,
7795
)
@@ -105,30 +123,11 @@ def for_(
105123
loc=None,
106124
ip=None,
107125
):
108-
if step is None:
109-
step = 1
110-
if stop is None:
111-
stop = start
112-
start = 0
113-
params = [start, stop]
114-
for i, p in enumerate(params):
115-
if isinstance(p, int):
116-
p = constant(IntegerAttr.get(IndexType.get(), p))
117-
elif isinstance(p, float):
118-
raise ValueError(f"{p=} must be int.")
119-
params[i] = p
120-
121-
start, stop = params
122-
s0 = AffineSymbolExpr.get(0)
123-
lbmap = AffineMap.get(0, 1, [s0])
124-
ubmap = AffineMap.get(0, 1, [s0])
125126
for_op = AffineForOp(
126-
lbmap,
127-
ubmap,
127+
start,
128+
stop,
128129
step,
129130
iter_args=iter_args,
130-
lower_bound_operands=[start],
131-
upper_bound_operands=[stop],
132131
loc=loc,
133132
ip=ip,
134133
)

mlir/test/python/dialects/affine.py

Lines changed: 123 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mlir.dialects import arith
66
from mlir.dialects import memref
77
from mlir.dialects import affine
8+
import mlir.extras.types as T
89

910

1011
def constructAndPrintInModule(f):
@@ -115,58 +116,130 @@ def affine_for_op_test(buffer):
115116

116117
@constructAndPrintInModule
117118
def testForSugar():
118-
index_type = IndexType.get()
119-
memref_t = MemRefType.get([10], index_type)
119+
memref_t = T.memref(10, T.index())
120120
range = affine.for_
121121

122-
# CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
123-
# CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
124-
# CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step 2 {
125-
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
126-
# CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
127-
# CHECK: }
128-
# CHECK: return
129-
# CHECK: }
130-
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
131-
def range_loop_1(lb, ub, step, memref_v):
132-
for i in range(lb, 10, 2):
122+
# CHECK-LABEL: func.func @range_loop_1(
123+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
124+
# CHECK: affine.for %[[VAL_3:.*]] = 1 to 1 iter_args() -> () {
125+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
126+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
127+
# CHECK: affine.yield
128+
# CHECK: }
129+
# CHECK: return
130+
# CHECK: }
131+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
132+
def range_loop_1(lb, ub, memref_v):
133+
for i in range(lb, ub, step=1):
134+
add = arith.addi(i, i)
135+
memref.store(add, memref_v, [i])
136+
137+
affine.yield_([])
138+
139+
# CHECK-LABEL: func.func @range_loop_2(
140+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
141+
# CHECK: affine.for %[[VAL_3:.*]] = 1 to 10 iter_args() -> () {
142+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
143+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
144+
# CHECK: affine.yield
145+
# CHECK: }
146+
# CHECK: return
147+
# CHECK: }
148+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
149+
def range_loop_2(lb, ub, memref_v):
150+
for i in range(lb, 10, step=1):
151+
add = arith.addi(i, i)
152+
memref.store(add, memref_v, [i])
153+
affine.yield_([])
154+
155+
# CHECK-LABEL: func.func @range_loop_3(
156+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
157+
# CHECK: affine.for %[[VAL_3:.*]] = 0 to 1 iter_args() -> () {
158+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
159+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
160+
# CHECK: affine.yield
161+
# CHECK: }
162+
# CHECK: return
163+
# CHECK: }
164+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
165+
def range_loop_3(lb, ub, memref_v):
166+
for i in range(0, ub, step=1):
167+
add = arith.addi(i, i)
168+
memref.store(add, memref_v, [i])
169+
affine.yield_([])
170+
171+
# CHECK-LABEL: func.func @range_loop_4(
172+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
173+
# CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
174+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
175+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
176+
# CHECK: }
177+
# CHECK: return
178+
# CHECK: }
179+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
180+
def range_loop_4(lb, ub, memref_v):
181+
for i in range(0, 10, step=1):
182+
add = arith.addi(i, i)
183+
memref.store(add, memref_v, [i])
184+
affine.yield_([])
185+
186+
# CHECK-LABEL: func.func @range_loop_5(
187+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
188+
# CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
189+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
190+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
191+
# CHECK: }
192+
# CHECK: return
193+
# CHECK: }
194+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
195+
def range_loop_5(lb, ub, memref_v):
196+
for i in range(0, 10, step=1):
197+
add = arith.addi(i, i)
198+
memref.store(add, memref_v, [i])
199+
affine.yield_([])
200+
201+
# CHECK-LABEL: func.func @range_loop_6(
202+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
203+
# CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
204+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
205+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
206+
# CHECK: }
207+
# CHECK: return
208+
# CHECK: }
209+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
210+
def range_loop_6(lb, ub, memref_v):
211+
for i in range(0, 10):
133212
add = arith.addi(i, i)
134-
s0 = AffineSymbolExpr.get(0)
135-
map = AffineMap.get(0, 1, [s0])
136-
affine.store(add, memref_v, [i], map=map)
137-
affine.AffineYieldOp([])
138-
139-
# CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
140-
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
141-
# CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
142-
# CHECK: affine.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] {
143-
# CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
144-
# CHECK: affine.store %[[VAL_8]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_7]]{{\)\]}} : memref<10xindex>
145-
# CHECK: }
146-
# CHECK: return
147-
# CHECK: }
148-
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
149-
def range_loop_2(lb, ub, step, memref_v):
150-
for i in range(0, 10, 1):
213+
memref.store(add, memref_v, [i])
214+
affine.yield_([])
215+
216+
# CHECK-LABEL: func.func @range_loop_7(
217+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
218+
# CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
219+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
220+
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
221+
# CHECK: }
222+
# CHECK: return
223+
# CHECK: }
224+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
225+
def range_loop_7(lb, ub, memref_v):
226+
for i in range(10):
151227
add = arith.addi(i, i)
152-
s0 = AffineSymbolExpr.get(0)
153-
map = AffineMap.get(0, 1, [s0])
154-
affine.store(add, memref_v, [i], map=map)
155-
affine.AffineYieldOp([])
156-
157-
# CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
158-
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
159-
# CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] {
160-
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
161-
# CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
162-
# CHECK: }
163-
# CHECK: return
164-
# CHECK: }
165-
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
166-
def range_loop_3(lb, ub, step, memref_v):
167-
for i in range(0, ub, 1):
228+
memref.store(add, memref_v, [i])
229+
affine.yield_([])
230+
231+
# CHECK-LABEL: func.func @range_loop_8(
232+
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
233+
# CHECK: %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
234+
# CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
235+
# CHECK: memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
236+
# CHECK: affine.yield %[[VAL_5]] : memref<10xindex>
237+
# CHECK: }
238+
# CHECK: return
239+
# CHECK: }
240+
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
241+
def range_loop_8(lb, ub, memref_v):
242+
for i, it in range(10, iter_args=[memref_v]):
168243
add = arith.addi(i, i)
169-
s0 = AffineSymbolExpr.get(0)
170-
map = AffineMap.get(0, 1, [s0])
171-
affine.store(add, memref_v, [i], map=map)
172-
affine.AffineYieldOp([])
244+
memref.store(add, it, [i])
245+
affine.yield_([it])

0 commit comments

Comments
 (0)