Skip to content

Commit 543589a

Browse files
authored
[mlir][python] python binding wrapper for the affine.AffineForOp (#74408)
This PR creates the wrapper class AffineForOp and adds a testcase for it. A testcase for the AffineLoadOp is also added.
1 parent d504824 commit 543589a

File tree

2 files changed

+293
-27
lines changed

2 files changed

+293
-27
lines changed

mlir/python/mlir/dialects/affine.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,141 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._affine_ops_gen import *
6+
from ._affine_ops_gen import _Dialect, AffineForOp
7+
from .arith import constant
8+
9+
try:
10+
from ..ir import *
11+
from ._ods_common import (
12+
get_op_result_or_value as _get_op_result_or_value,
13+
get_op_results_or_values as _get_op_results_or_values,
14+
_cext as _ods_cext,
15+
)
16+
except ImportError as e:
17+
raise RuntimeError("Error loading imports from extension module") from e
18+
19+
from typing import Optional, Sequence, Union
20+
21+
22+
@_ods_cext.register_operation(_Dialect, replace=True)
23+
class AffineForOp(AffineForOp):
24+
"""Specialization for the Affine for op class"""
25+
26+
def __init__(
27+
self,
28+
lower_bound,
29+
upper_bound,
30+
step,
31+
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
32+
*,
33+
lower_bound_operands=[],
34+
upper_bound_operands=[],
35+
loc=None,
36+
ip=None,
37+
):
38+
"""Creates an Affine `for` operation.
39+
40+
- `lower_bound` is the affine map to use as lower bound of the loop.
41+
- `upper_bound` is the affine map to use as upper bound of the loop.
42+
- `step` is the value to use as loop step.
43+
- `iter_args` is a list of additional loop-carried arguments or an operation
44+
producing them as results.
45+
- `lower_bound_operands` is the list of arguments to substitute the dimensions,
46+
then symbols in the `lower_bound` affine map, in an increasing order
47+
- `upper_bound_operands` is the list of arguments to substitute the dimensions,
48+
then symbols in the `upper_bound` affine map, in an increasing order
49+
"""
50+
51+
if iter_args is None:
52+
iter_args = []
53+
iter_args = _get_op_results_or_values(iter_args)
54+
if len(lower_bound_operands) != lower_bound.n_inputs:
55+
raise ValueError(
56+
f"Wrong number of lower bound operands passed to AffineForOp. "
57+
+ "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}."
58+
)
59+
60+
if len(upper_bound_operands) != upper_bound.n_inputs:
61+
raise ValueError(
62+
f"Wrong number of upper bound operands passed to AffineForOp. "
63+
+ "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}."
64+
)
65+
66+
results = [arg.type for arg in iter_args]
67+
super().__init__(
68+
results_=results,
69+
lowerBoundOperands=_get_op_results_or_values(lower_bound_operands),
70+
upperBoundOperands=_get_op_results_or_values(upper_bound_operands),
71+
inits=list(iter_args),
72+
lowerBoundMap=AffineMapAttr.get(lower_bound),
73+
upperBoundMap=AffineMapAttr.get(upper_bound),
74+
step=IntegerAttr.get(IndexType.get(), step),
75+
loc=loc,
76+
ip=ip,
77+
)
78+
self.regions[0].blocks.append(IndexType.get(), *results)
79+
80+
@property
81+
def body(self):
82+
"""Returns the body (block) of the loop."""
83+
return self.regions[0].blocks[0]
84+
85+
@property
86+
def induction_variable(self):
87+
"""Returns the induction variable of the loop."""
88+
return self.body.arguments[0]
89+
90+
@property
91+
def inner_iter_args(self):
92+
"""Returns the loop-carried arguments usable within the loop.
93+
94+
To obtain the loop-carried operands, use `iter_args`.
95+
"""
96+
return self.body.arguments[1:]
97+
98+
99+
def for_(
100+
start,
101+
stop=None,
102+
step=None,
103+
iter_args: Optional[Sequence[Value]] = None,
104+
*,
105+
loc=None,
106+
ip=None,
107+
):
108+
if step is None:
109+
step = 1
110+
if stop is None:
111+
stop = start
112+
start = 0
113+
params = [start, stop]
114+
for i, p in enumerate(params):
115+
if isinstance(p, int):
116+
p = constant(IntegerAttr.get(IndexType.get(), p))
117+
elif isinstance(p, float):
118+
raise ValueError(f"{p=} must be int.")
119+
params[i] = p
120+
121+
start, stop = params
122+
s0 = AffineSymbolExpr.get(0)
123+
lbmap = AffineMap.get(0, 1, [s0])
124+
ubmap = AffineMap.get(0, 1, [s0])
125+
for_op = AffineForOp(
126+
lbmap,
127+
ubmap,
128+
step,
129+
iter_args=iter_args,
130+
lower_bound_operands=[start],
131+
upper_bound_operands=[stop],
132+
loc=loc,
133+
ip=ip,
134+
)
135+
iv = for_op.induction_variable
136+
iter_args = tuple(for_op.inner_iter_args)
137+
with InsertionPoint(for_op.body):
138+
if len(iter_args) > 1:
139+
yield iv, iter_args
140+
elif len(iter_args) == 1:
141+
yield iv, iter_args[0]
142+
else:
143+
yield iv

mlir/test/python/dialects/affine.py

Lines changed: 155 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,172 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
from mlir.ir import *
4-
import mlir.dialects.func as func
5-
import mlir.dialects.arith as arith
6-
import mlir.dialects.affine as affine
7-
import mlir.dialects.memref as memref
4+
from mlir.dialects import func
5+
from mlir.dialects import arith
6+
from mlir.dialects import memref
7+
from mlir.dialects import affine
88

99

10-
def run(f):
10+
def constructAndPrintInModule(f):
1111
print("\nTEST:", f.__name__)
12-
f()
12+
with Context(), Location.unknown():
13+
module = Module.create()
14+
with InsertionPoint(module.body):
15+
f()
16+
print(module)
1317
return f
1418

1519

1620
# CHECK-LABEL: TEST: testAffineStoreOp
17-
@run
21+
@constructAndPrintInModule
1822
def testAffineStoreOp():
19-
with Context() as ctx, Location.unknown():
20-
module = Module.create()
21-
with InsertionPoint(module.body):
22-
f32 = F32Type.get()
23-
index_type = IndexType.get()
24-
memref_type_out = MemRefType.get([12, 12], f32)
23+
f32 = F32Type.get()
24+
index_type = IndexType.get()
25+
memref_type_out = MemRefType.get([12, 12], f32)
2526

26-
# CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
27-
@func.FuncOp.from_py_func(index_type)
28-
def affine_store_test(arg0):
29-
# CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
30-
mem = memref.AllocOp(memref_type_out, [], []).result
27+
# CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
28+
@func.FuncOp.from_py_func(index_type)
29+
def affine_store_test(arg0):
30+
# CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
31+
mem = memref.AllocOp(memref_type_out, [], []).result
3132

32-
d0 = AffineDimExpr.get(0)
33-
s0 = AffineSymbolExpr.get(0)
34-
map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
33+
d0 = AffineDimExpr.get(0)
34+
s0 = AffineSymbolExpr.get(0)
35+
map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
3536

36-
# CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
37-
a1 = arith.ConstantOp(f32, 2.1)
37+
# CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
38+
a1 = arith.ConstantOp(f32, 2.1)
3839

39-
# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
40-
affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
40+
# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
41+
affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
4142

42-
return mem
43+
return mem
4344

44-
print(module)
45+
46+
# CHECK-LABEL: TEST: testAffineLoadOp
47+
@constructAndPrintInModule
48+
def testAffineLoadOp():
49+
f32 = F32Type.get()
50+
index_type = IndexType.get()
51+
memref_type_in = MemRefType.get([10, 10], f32)
52+
53+
# CHECK: func.func @affine_load_test(%[[I_VAR:.*]]: memref<10x10xf32>, %[[ARG0:.*]]: index) -> f32 {
54+
@func.FuncOp.from_py_func(memref_type_in, index_type)
55+
def affine_load_test(I, arg0):
56+
d0 = AffineDimExpr.get(0)
57+
s0 = AffineSymbolExpr.get(0)
58+
map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
59+
60+
# CHECK: {{.*}} = affine.load %[[I_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<10x10xf32>
61+
a1 = affine.AffineLoadOp(f32, I, indices=[arg0, arg0], map=map)
62+
63+
return a1
64+
65+
66+
# CHECK-LABEL: TEST: testAffineForOp
67+
@constructAndPrintInModule
68+
def testAffineForOp():
69+
f32 = F32Type.get()
70+
index_type = IndexType.get()
71+
memref_type = MemRefType.get([1024], f32)
72+
73+
# CHECK: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (0, d0 + s0)>
74+
# CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 - 2, d1 * 32)>
75+
# CHECK: func.func @affine_for_op_test(%[[BUFFER:.*]]: memref<1024xf32>) {
76+
@func.FuncOp.from_py_func(memref_type)
77+
def affine_for_op_test(buffer):
78+
# CHECK: %[[C1:.*]] = arith.constant 1 : index
79+
c1 = arith.ConstantOp(index_type, 1)
80+
# CHECK: %[[C2:.*]] = arith.constant 2 : index
81+
c2 = arith.ConstantOp(index_type, 2)
82+
# CHECK: %[[C3:.*]] = arith.constant 3 : index
83+
c3 = arith.ConstantOp(index_type, 3)
84+
# CHECK: %[[C9:.*]] = arith.constant 9 : index
85+
c9 = arith.ConstantOp(index_type, 9)
86+
# CHECK: %[[AC0:.*]] = arith.constant 0.000000e+00 : f32
87+
ac0 = AffineConstantExpr.get(0)
88+
89+
d0 = AffineDimExpr.get(0)
90+
d1 = AffineDimExpr.get(1)
91+
s0 = AffineSymbolExpr.get(0)
92+
lb = AffineMap.get(1, 1, [ac0, d0 + s0])
93+
ub = AffineMap.get(2, 0, [d0 - 2, 32 * d1])
94+
sum_0 = arith.ConstantOp(f32, 0.0)
95+
96+
# CHECK: %0 = affine.for %[[INDVAR:.*]] = max #[[MAP0]](%[[C2]])[%[[C3]]] to min #[[MAP1]](%[[C9]], %[[C1]]) step 2 iter_args(%[[SUM0:.*]] = %[[AC0]]) -> (f32) {
97+
sum = affine.AffineForOp(
98+
lb,
99+
ub,
100+
2,
101+
iter_args=[sum_0],
102+
lower_bound_operands=[c2, c3],
103+
upper_bound_operands=[c9, c1],
104+
)
105+
106+
with InsertionPoint(sum.body):
107+
# CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
108+
tmp = memref.LoadOp(buffer, [sum.induction_variable])
109+
sum_next = arith.AddFOp(sum.inner_iter_args[0], tmp)
110+
111+
affine.AffineYieldOp([sum_next])
112+
113+
return
114+
115+
116+
@constructAndPrintInModule
117+
def testForSugar():
118+
index_type = IndexType.get()
119+
memref_t = MemRefType.get([10], index_type)
120+
range = affine.for_
121+
122+
# CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
123+
# CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
124+
# CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step 2 {
125+
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
126+
# CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
127+
# CHECK: }
128+
# CHECK: return
129+
# CHECK: }
130+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
131+
def range_loop_1(lb, ub, step, memref_v):
132+
for i in range(lb, 10, 2):
133+
add = arith.addi(i, i)
134+
s0 = AffineSymbolExpr.get(0)
135+
map = AffineMap.get(0, 1, [s0])
136+
affine.store(add, memref_v, [i], map=map)
137+
affine.AffineYieldOp([])
138+
139+
# CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
140+
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
141+
# CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
142+
# CHECK: affine.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] {
143+
# CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
144+
# CHECK: affine.store %[[VAL_8]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_7]]{{\)\]}} : memref<10xindex>
145+
# CHECK: }
146+
# CHECK: return
147+
# CHECK: }
148+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
149+
def range_loop_2(lb, ub, step, memref_v):
150+
for i in range(0, 10, 1):
151+
add = arith.addi(i, i)
152+
s0 = AffineSymbolExpr.get(0)
153+
map = AffineMap.get(0, 1, [s0])
154+
affine.store(add, memref_v, [i], map=map)
155+
affine.AffineYieldOp([])
156+
157+
# CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
158+
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
159+
# CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] {
160+
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
161+
# CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
162+
# CHECK: }
163+
# CHECK: return
164+
# CHECK: }
165+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
166+
def range_loop_3(lb, ub, step, memref_v):
167+
for i in range(0, ub, 1):
168+
add = arith.addi(i, i)
169+
s0 = AffineSymbolExpr.get(0)
170+
map = AffineMap.get(0, 1, [s0])
171+
affine.store(add, memref_v, [i], map=map)
172+
affine.AffineYieldOp([])

0 commit comments

Comments
 (0)