Skip to content

Commit d7fc779

Browse files
authored
[mlir][SCF]-Fix loop coalescing with iteration arguements (#105488)
Fix a bug found when coalescing loops which have iteration arguments, such that the inner loop's terminator may have operands of the inner loop iteration arguments which are about to be replaced by the outer loop's iteration arguments. The current flow leads to crush within the IR code.
1 parent 42d06b8 commit d7fc779

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,18 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
864864

865865
Operation *innerTerminator = innerLoop.getBody()->getTerminator();
866866
auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
867+
assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
868+
for (Value &yieldedVal : yieldedVals) {
869+
// The yielded value may be an iteration argument of the inner loop
870+
// which is about to be inlined.
871+
auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
872+
if (iter != innerLoop.getRegionIterArgs().end()) {
873+
unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
874+
// `outerLoop` iter args identical to the `innerLoop` init args.
875+
assert(iterArgIndex < innerLoop.getInitArgs().size());
876+
yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
877+
}
878+
}
867879
rewriter.eraseOp(innerTerminator);
868880

869881
SmallVector<Value> innerBlockArgs;

mlir/test/Dialect/Affine/loop-coalescing.mlir

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,126 @@ func.func @unnormalized_loops() {
114114
return
115115
}
116116

117+
func.func @noramalized_loops_with_yielded_iter_args() {
118+
// CHECK: %[[orig_lb:.*]] = arith.constant 0
119+
// CHECK: %[[orig_step:.*]] = arith.constant 1
120+
// CHECK: %[[orig_ub_k:.*]] = arith.constant 3
121+
// CHECK: %[[orig_ub_i:.*]] = arith.constant 42
122+
// CHECK: %[[orig_ub_j:.*]] = arith.constant 56
123+
%c0 = arith.constant 0 : index
124+
%c1 = arith.constant 1 : index
125+
%c3 = arith.constant 3 : index
126+
%c42 = arith.constant 42 : index
127+
%c56 = arith.constant 56 : index
128+
// The range of the new scf.
129+
// CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
130+
// CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
131+
132+
// Updated loop bounds.
133+
// CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
134+
%2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
135+
// Inner loops must have been removed.
136+
// CHECK-NOT: scf.for
137+
138+
// Reconstruct original IVs from the linearized one.
139+
// CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
140+
// CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
141+
// CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
142+
// CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
143+
%1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
144+
%0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
145+
// CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
146+
"use"(%i, %j, %k) : (index, index, index) -> ()
147+
// CHECK: scf.yield %[[VAL_1]] : index
148+
scf.yield %arg2 : index
149+
}
150+
scf.yield %0#0 : index
151+
}
152+
scf.yield %1#0 : index
153+
}
154+
return
155+
}
156+
157+
func.func @noramalized_loops_with_shuffled_yielded_iter_args() {
158+
// CHECK: %[[orig_lb:.*]] = arith.constant 0
159+
// CHECK: %[[orig_step:.*]] = arith.constant 1
160+
// CHECK: %[[orig_ub_k:.*]] = arith.constant 3
161+
// CHECK: %[[orig_ub_i:.*]] = arith.constant 42
162+
// CHECK: %[[orig_ub_j:.*]] = arith.constant 56
163+
%c0 = arith.constant 0 : index
164+
%c1 = arith.constant 1 : index
165+
%c3 = arith.constant 3 : index
166+
%c42 = arith.constant 42 : index
167+
%c56 = arith.constant 56 : index
168+
// The range of the new scf.
169+
// CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
170+
// CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
171+
172+
// Updated loop bounds.
173+
// CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]], %[[VAL_2:.*]] = %[[orig_lb]]) -> (index, index) {
174+
%2:2 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0, %arg1 = %c0) -> (index, index) {
175+
// Inner loops must have been removed.
176+
// CHECK-NOT: scf.for
177+
178+
// Reconstruct original IVs from the linearized one.
179+
// CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
180+
// CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
181+
// CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
182+
// CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
183+
%1:2 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (index, index){
184+
%0:2 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (index, index) {
185+
// CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
186+
"use"(%i, %j, %k) : (index, index, index) -> ()
187+
// CHECK: scf.yield %[[VAL_2]], %[[VAL_1]] : index, index
188+
scf.yield %arg5, %arg4 : index, index
189+
}
190+
scf.yield %0#0, %0#1 : index, index
191+
}
192+
scf.yield %1#0, %1#1 : index, index
193+
}
194+
return
195+
}
196+
197+
func.func @noramalized_loops_with_yielded_non_iter_args() {
198+
// CHECK: %[[orig_lb:.*]] = arith.constant 0
199+
// CHECK: %[[orig_step:.*]] = arith.constant 1
200+
// CHECK: %[[orig_ub_k:.*]] = arith.constant 3
201+
// CHECK: %[[orig_ub_i:.*]] = arith.constant 42
202+
// CHECK: %[[orig_ub_j:.*]] = arith.constant 56
203+
%c0 = arith.constant 0 : index
204+
%c1 = arith.constant 1 : index
205+
%c3 = arith.constant 3 : index
206+
%c42 = arith.constant 42 : index
207+
%c56 = arith.constant 56 : index
208+
// The range of the new scf.
209+
// CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
210+
// CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
211+
212+
// Updated loop bounds.
213+
// CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
214+
%2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
215+
// Inner loops must have been removed.
216+
// CHECK-NOT: scf.for
217+
218+
// Reconstruct original IVs from the linearized one.
219+
// CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
220+
// CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
221+
// CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
222+
// CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
223+
%1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
224+
%0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
225+
// CHECK: %[[res:.*]] = "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
226+
%res = "use"(%i, %j, %k) : (index, index, index) -> (index)
227+
// CHECK: scf.yield %[[res]] : index
228+
scf.yield %res : index
229+
}
230+
scf.yield %0#0 : index
231+
}
232+
scf.yield %1#0 : index
233+
}
234+
return
235+
}
236+
117237
// Check with parametric loop bounds and steps, capture the bounds here.
118238
// CHECK-LABEL: @parametric
119239
// CHECK-SAME: %[[orig_lb1:[A-Za-z0-9]+]]:

0 commit comments

Comments
 (0)