Skip to content

Commit fd673e8

Browse files
[MLIR][SCF] Removes incorrect assertion in loop unroller (#69028)
In particular, `upperBoundUnrolledCst` may be larger than `ubCst` when: 1. the step size is greater than 1; 2. `ub - lb` is not evenly divisible by the step size; and 3. the loop's trip count is evenly divisible by the unroll factor. This is okay since the non-unit step size ensures that the unrolled loop maintains the same trip count as the original loop. Added a test case for this. Fixes #61832. Co-authored-by: Stephen Chou <[email protected]>
1 parent 019d67f commit fd673e8

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,6 @@ LogicalResult mlir::loopUnrollByFactor(
391391

392392
int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
393393
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
394-
assert(upperBoundUnrolledCst <= ubCst);
395394
int64_t stepUnrolledCst = stepCst * unrollFactor;
396395

397396
// Create constant for 'upperBoundUnrolled' and set epilogue loop flag.

mlir/test/Dialect/SCF/loop-unroll.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,36 @@ func.func @static_loop_unroll_by_2(%arg0 : memref<?xf32>) {
186186
// UNROLL-BY-2-ANNOTATE: memref.store %{{.*}}, %[[MEM:.*0]][%{{.*}}] {unrolled_iteration = 0 : ui32} : memref<?xf32>
187187
// UNROLL-BY-2-ANNOTATE: memref.store %{{.*}}, %[[MEM]][%{{.*}}] {unrolled_iteration = 1 : ui32} : memref<?xf32>
188188

189+
// Test that no epilogue clean-up loop is generated because the trip count
190+
// (taking into account the non-unit step size) is a multiple of the unroll
191+
// factor.
192+
func.func @static_loop_step_2_unroll_by_2(%arg0 : memref<?xf32>) {
193+
%0 = arith.constant 7.0 : f32
194+
%lb = arith.constant 0 : index
195+
%ub = arith.constant 19 : index
196+
%step = arith.constant 2 : index
197+
scf.for %i0 = %lb to %ub step %step {
198+
memref.store %0, %arg0[%i0] : memref<?xf32>
199+
}
200+
return
201+
}
202+
203+
// UNROLL-BY-2-LABEL: func @static_loop_step_2_unroll_by_2
204+
// UNROLL-BY-2-SAME: %[[MEM:.*0]]: memref<?xf32>
205+
//
206+
// UNROLL-BY-2-DAG: %[[C0:.*]] = arith.constant 0 : index
207+
// UNROLL-BY-2-DAG: %[[C2:.*]] = arith.constant 2 : index
208+
// UNROLL-BY-2-DAG: %[[C19:.*]] = arith.constant 19 : index
209+
// UNROLL-BY-2-DAG: %[[C4:.*]] = arith.constant 4 : index
210+
// UNROLL-BY-2: scf.for %[[IV:.*]] = %[[C0]] to %[[C19]] step %[[C4]] {
211+
// UNROLL-BY-2-NEXT: memref.store %{{.*}}, %[[MEM]][%[[IV]]] : memref<?xf32>
212+
// UNROLL-BY-2-NEXT: %[[C1_IV:.*]] = arith.constant 1 : index
213+
// UNROLL-BY-2-NEXT: %[[V0:.*]] = arith.muli %[[C2]], %[[C1_IV]] : index
214+
// UNROLL-BY-2-NEXT: %[[V1:.*]] = arith.addi %[[IV]], %[[V0]] : index
215+
// UNROLL-BY-2-NEXT: memref.store %{{.*}}, %[[MEM]][%[[V1]]] : memref<?xf32>
216+
// UNROLL-BY-2-NEXT: }
217+
// UNROLL-BY-2-NEXT: return
218+
189219
// Test that epilogue clean up loop is generated (trip count is not
190220
// a multiple of unroll factor).
191221
func.func @static_loop_unroll_by_3(%arg0 : memref<?xf32>) {

0 commit comments

Comments
 (0)