Skip to content

Commit e9453f3

Browse files
authored
[mlir][python] fix scf.for_ convenience builder (#72170)
1 parent ed96430 commit e9453f3

File tree

2 files changed

+114
-10
lines changed

2 files changed

+114
-10
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,13 @@ def for_(
120120
params = [start, stop, step]
121121
for i, p in enumerate(params):
122122
if isinstance(p, int):
123-
p = constant(p)
123+
p = constant(IntegerAttr.get(IndexType.get(), p))
124124
elif isinstance(p, float):
125125
raise ValueError(f"{p=} must be int.")
126126
params[i] = p
127127

128+
start, stop, step = params
129+
128130
for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
129131
iv = for_op.induction_variable
130132
iter_args = tuple(for_op.inner_iter_args)

mlir/test/python/dialects/scf.py

Lines changed: 111 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from mlir.ir import *
44
from mlir.dialects import arith
55
from mlir.dialects import func
6+
from mlir.dialects import memref
67
from mlir.dialects import scf
8+
from mlir.passmanager import PassManager
79

810

911
def constructAndPrintInModule(f):
@@ -57,22 +59,122 @@ def induction_var(lb, ub, step):
5759
@constructAndPrintInModule
5860
def testForSugar():
5961
index_type = IndexType.get()
62+
memref_t = MemRefType.get([10], index_type)
6063
range = scf.for_
6164

62-
@func.FuncOp.from_py_func(index_type, index_type, index_type)
63-
def range_loop(lb, ub, step):
65+
# CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
66+
# CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
67+
# CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
68+
# CHECK: memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
69+
# CHECK: }
70+
# CHECK: return
71+
# CHECK: }
72+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
73+
def range_loop_1(lb, ub, step, memref_v):
6474
for i in range(lb, ub, step):
6575
add = arith.addi(i, i)
76+
memref.store(add, memref_v, [i])
77+
78+
scf.yield_([])
79+
80+
# CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
81+
# CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
82+
# CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
83+
# CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
84+
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
85+
# CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
86+
# CHECK: }
87+
# CHECK: return
88+
# CHECK: }
89+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
90+
def range_loop_2(lb, ub, step, memref_v):
91+
for i in range(lb, 10, 1):
92+
add = arith.addi(i, i)
93+
memref.store(add, memref_v, [i])
6694
scf.yield_([])
67-
return
6895

96+
# CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
97+
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
98+
# CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
99+
# CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
100+
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
101+
# CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
102+
# CHECK: }
103+
# CHECK: return
104+
# CHECK: }
105+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
106+
def range_loop_3(lb, ub, step, memref_v):
107+
for i in range(0, ub, 1):
108+
add = arith.addi(i, i)
109+
memref.store(add, memref_v, [i])
110+
scf.yield_([])
69111

70-
# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
71-
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
72-
# CHECK: %0 = arith.addi %[[IV]], %[[IV]] : index
73-
# CHECK: }
74-
# CHECK: return
75-
# CHECK: }
112+
# CHECK: func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
113+
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
114+
# CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
115+
# CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
116+
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
117+
# CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
118+
# CHECK: }
119+
# CHECK: return
120+
# CHECK: }
121+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
122+
def range_loop_4(lb, ub, step, memref_v):
123+
for i in range(0, 10, step):
124+
add = arith.addi(i, i)
125+
memref.store(add, memref_v, [i])
126+
scf.yield_([])
127+
128+
# CHECK: func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
129+
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
130+
# CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
131+
# CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
132+
# CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
133+
# CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
134+
# CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
135+
# CHECK: }
136+
# CHECK: return
137+
# CHECK: }
138+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
139+
def range_loop_5(lb, ub, step, memref_v):
140+
for i in range(0, 10, 1):
141+
add = arith.addi(i, i)
142+
memref.store(add, memref_v, [i])
143+
scf.yield_([])
144+
145+
# CHECK: func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
146+
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
147+
# CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
148+
# CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
149+
# CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
150+
# CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
151+
# CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
152+
# CHECK: }
153+
# CHECK: return
154+
# CHECK: }
155+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
156+
def range_loop_6(lb, ub, step, memref_v):
157+
for i in range(0, 10):
158+
add = arith.addi(i, i)
159+
memref.store(add, memref_v, [i])
160+
scf.yield_([])
161+
162+
# CHECK: func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
163+
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
164+
# CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
165+
# CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
166+
# CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
167+
# CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
168+
# CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
169+
# CHECK: }
170+
# CHECK: return
171+
# CHECK: }
172+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
173+
def range_loop_7(lb, ub, step, memref_v):
174+
for i in range(10):
175+
add = arith.addi(i, i)
176+
memref.store(add, memref_v, [i])
177+
scf.yield_([])
76178

77179

78180
@constructAndPrintInModule

0 commit comments

Comments
 (0)