Skip to content

Commit 2f12f48

Browse files
committed
incorporate comments
1 parent 829df04 commit 2f12f48

File tree

3 files changed

+98
-77
lines changed

3 files changed

+98
-77
lines changed

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,5 +135,6 @@ def get_op_result_or_op_results(
135135
_U = _TypeVar("_U", bound=_cext.ir.Value)
136136
SubClassValueT = _Type[_U]
137137

138-
ResultValueT = _Union[_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value]
138+
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
139+
ResultValueT = _Union[ResultValueTypeTuple]
139140
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

mlir/python/mlir/dialects/affine.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
get_op_result_or_value as _get_op_result_or_value,
1212
get_op_results_or_values as _get_op_results_or_values,
1313
_cext as _ods_cext,
14+
ResultValueTypeTuple as _ResultValueTypeTuple,
1415
ResultValueT as _ResultValueT,
1516
VariadicResultValueT as _VariadicResultValueT,
1617
)
@@ -27,8 +28,8 @@ class AffineForOp(AffineForOp):
2728
def __init__(
2829
self,
2930
lower_bound: Union[int, _ResultValueT, AffineMap],
30-
upper_bound: Optional[Union[int, _ResultValueT, AffineMap]] = None,
31-
step: Optional[Union[int, _ResultValueT]] = None,
31+
upper_bound: Optional[Union[int, _ResultValueT, AffineMap]],
32+
step: Optional[Union[int, Attribute]] = None,
3233
iter_args: Optional[_ResultValueT] = None,
3334
*,
3435
lower_bound_operands: Optional[_VariadicResultValueT] = None,
@@ -44,7 +45,7 @@ def __init__(
4445
- `iter_args` is a list of additional loop-carried arguments or an operation
4546
producing them as results.
4647
- `lower_bound_operands` is the list of arguments to substitute the dimensions,
47-
then symbols in the `lower_bound` affine map, in an increasing order
48+
then symbols in the `lower_bound` affine map, in an increasing order.
4849
- `upper_bound_operands` is the list of arguments to substitute the dimensions,
4950
then symbols in the `upper_bound` affine map, in an increasing order.
5051
"""
@@ -56,36 +57,41 @@ def __init__(
5657

5758
if step is None:
5859
step = 1
59-
if upper_bound is None:
60-
upper_bound, lower_bound = lower_bound, 0
6160

62-
if isinstance(lower_bound, int):
63-
lower_bound = AffineMap.get_constant(lower_bound)
64-
elif isinstance(lower_bound, (Operation, OpView, Value)):
65-
if len(lower_bound_operands):
61+
bounds_operands = [lower_bound_operands, upper_bound_operands]
62+
bounds = [lower_bound, upper_bound]
63+
bounds_names = ["lower", "upper"]
64+
for i, name in enumerate(bounds_names):
65+
if isinstance(bounds[i], int):
66+
bounds[i] = AffineMap.get_constant(bounds[i])
67+
elif isinstance(bounds[i], _ResultValueTypeTuple):
68+
if len(bounds_operands[i]):
69+
raise ValueError(
70+
f"Either a concrete {name} bound or an AffineMap in combination "
71+
f"with {name} bound operands, but not both, is supported."
72+
)
73+
if (
74+
isinstance(bounds[i], (OpView, Operation))
75+
and len(bounds[i].results) > 1
76+
):
77+
raise ValueError(
78+
f"Only a single concrete value is supported for {name} bound."
79+
)
80+
81+
bounds_operands[i].append(_get_op_result_or_value(bounds[i]))
82+
bounds[i] = AffineMap.get_identity(1)
83+
84+
if not isinstance(bounds[i], AffineMap):
6685
raise ValueError(
67-
f"Either a concrete lower bound or an AffineMap in combination "
68-
f"with lower bound operands, but not both, is supported."
86+
f"{name} bound must be int | ResultValueT | AffineMap."
6987
)
70-
lower_bound_operands.append(lower_bound)
71-
lower_bound = AffineMap.get_identity(1)
72-
73-
if not isinstance(lower_bound, AffineMap):
74-
raise ValueError(f"{lower_bound=} must be int | ResultValueT | AffineMap")
75-
76-
if isinstance(upper_bound, int):
77-
upper_bound = AffineMap.get_constant(upper_bound)
78-
elif isinstance(upper_bound, (Operation, OpView, Value)):
79-
if len(upper_bound_operands):
88+
if len(bounds_operands[i]) != bounds[i].n_inputs:
8089
raise ValueError(
81-
f"Either a concrete upper bound or an AffineMap in combination "
82-
f"with upper bound operands, but not both, is supported."
90+
f"Wrong number of {name} bound operands passed to AffineForOp; "
91+
+ f"Expected {bounds[i].n_inputs}, got {len(bounds_operands[i])}."
8392
)
84-
upper_bound_operands.append(upper_bound)
85-
upper_bound = AffineMap.get_identity(1)
8693

87-
if not isinstance(upper_bound, AffineMap):
88-
raise ValueError(f"{upper_bound=} must be int | ResultValueT | AffineMap")
94+
lower_bound, upper_bound = bounds
8995

9096
if iter_args is None:
9197
iter_args = []
@@ -126,7 +132,7 @@ def inner_iter_args(self):
126132

127133
def for_(
128134
start,
129-
stop=None,
135+
stop,
130136
step=None,
131137
iter_args: Optional[Sequence[Value]] = None,
132138
*,

mlir/test/python/dialects/affine.py

Lines changed: 62 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,69 @@ def affine_for_op_test(buffer):
108108
# CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
109109
tmp = memref.LoadOp(buffer, [sum.induction_variable])
110110
sum_next = arith.AddFOp(sum.inner_iter_args[0], tmp)
111-
112111
affine.AffineYieldOp([sum_next])
113112

114-
return
113+
114+
# CHECK-LABEL: TEST: testAffineForOpErrors
115+
@constructAndPrintInModule
116+
def testAffineForOpErrors():
117+
c1 = arith.ConstantOp(T.index(), 1)
118+
c2 = arith.ConstantOp(T.index(), 2)
119+
c3 = arith.ConstantOp(T.index(), 3)
120+
d0 = AffineDimExpr.get(0)
121+
122+
try:
123+
affine.AffineForOp(
124+
c1,
125+
c2,
126+
1,
127+
lower_bound_operands=[c3],
128+
upper_bound_operands=[],
129+
)
130+
except ValueError as e:
131+
assert (
132+
e.args[0]
133+
== "Either a concrete lower bound or an AffineMap in combination with lower bound operands, but not both, is supported."
134+
)
135+
136+
try:
137+
affine.AffineForOp(
138+
AffineMap.get_constant(1),
139+
c2,
140+
1,
141+
lower_bound_operands=[c3, c3],
142+
upper_bound_operands=[],
143+
)
144+
except ValueError as e:
145+
assert (
146+
e.args[0]
147+
== "Wrong number of lower bound operands passed to AffineForOp; Expected 0, got 2."
148+
)
149+
150+
try:
151+
two_indices = affine.AffineDelinearizeIndexOp(
152+
[T.index(), T.index()], c1, [c1, c1]
153+
)
154+
affine.AffineForOp(
155+
two_indices,
156+
c2,
157+
1,
158+
lower_bound_operands=[],
159+
upper_bound_operands=[],
160+
)
161+
except ValueError as e:
162+
assert e.args[0] == "Only a single concrete value is supported for lower bound."
163+
164+
try:
165+
affine.AffineForOp(
166+
1.0,
167+
c2,
168+
1,
169+
lower_bound_operands=[],
170+
upper_bound_operands=[],
171+
)
172+
except ValueError as e:
173+
assert e.args[0] == "lower bound must be int | ResultValueT | AffineMap."
115174

116175

117176
@constructAndPrintInModule
@@ -182,51 +241,6 @@ def range_loop_4(lb, ub, memref_v):
182241
memref.store(add, memref_v, [i])
183242
affine.yield_([])
184243

185-
# CHECK-LABEL: func.func @range_loop_5(
186-
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
187-
# CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
188-
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
189-
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
190-
# CHECK: }
191-
# CHECK: return
192-
# CHECK: }
193-
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
194-
def range_loop_5(lb, ub, memref_v):
195-
for i in range(0, 10, step=1):
196-
add = arith.addi(i, i)
197-
memref.store(add, memref_v, [i])
198-
affine.yield_([])
199-
200-
# CHECK-LABEL: func.func @range_loop_6(
201-
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
202-
# CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
203-
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
204-
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
205-
# CHECK: }
206-
# CHECK: return
207-
# CHECK: }
208-
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
209-
def range_loop_6(lb, ub, memref_v):
210-
for i in range(0, 10):
211-
add = arith.addi(i, i)
212-
memref.store(add, memref_v, [i])
213-
affine.yield_([])
214-
215-
# CHECK-LABEL: func.func @range_loop_7(
216-
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
217-
# CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
218-
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
219-
# CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
220-
# CHECK: }
221-
# CHECK: return
222-
# CHECK: }
223-
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
224-
def range_loop_7(lb, ub, memref_v):
225-
for i in range(10):
226-
add = arith.addi(i, i)
227-
memref.store(add, memref_v, [i])
228-
affine.yield_([])
229-
230244
# CHECK-LABEL: func.func @range_loop_8(
231245
# CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
232246
# CHECK: %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
@@ -238,7 +252,7 @@ def range_loop_7(lb, ub, memref_v):
238252
# CHECK: }
239253
@func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
240254
def range_loop_8(lb, ub, memref_v):
241-
for i, it in range(10, iter_args=[memref_v]):
255+
for i, it in range(0, 10, iter_args=[memref_v]):
242256
add = arith.addi(i, i)
243257
memref.store(add, it, [i])
244258
affine.yield_([it])

0 commit comments

Comments
 (0)