Skip to content

Commit c08c6a7

Browse files
authored
[mlir][scf] Allow unrolling loops with integer-typed IV. (#106164)
SCF loops can now operate on integer-typed induction variables (IVs), so this change updates the loop unroller accordingly.
1 parent e05c224 commit c08c6a7

File tree

2 files changed

+62
-14
lines changed

2 files changed

+62
-14
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,13 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
270270
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
271271
int64_t divisor) {
272272
assert(divisor > 0 && "expected positive divisor");
273-
assert(dividend.getType().isIndex() && "expected index-typed value");
273+
assert(dividend.getType().isIntOrIndex() &&
274+
"expected integer or index-typed value");
274275

275-
Value divisorMinusOneCst =
276-
builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
277-
Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
276+
Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
277+
loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
278+
Value divisorCst = builder.create<arith::ConstantOp>(
279+
loc, builder.getIntegerAttr(dividend.getType(), divisor));
278280
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
279281
return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
280282
}
@@ -285,9 +287,10 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
285287
// where divis is rounding-to-zero division.
286288
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
287289
Value divisor) {
288-
assert(dividend.getType().isIndex() && "expected index-typed value");
289-
290-
Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
290+
assert(dividend.getType().isIntOrIndex() &&
291+
"expected integer or index-typed value");
292+
Value cstOne = builder.create<arith::ConstantOp>(
293+
loc, builder.getOneAttr(dividend.getType()));
291294
Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
292295
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
293296
return builder.create<arith::DivUIOp>(loc, sum, divisor);
@@ -409,16 +412,18 @@ LogicalResult mlir::loopUnrollByFactor(
409412
// Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
410413
generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
411414
if (generateEpilogueLoop)
412-
upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
413-
loc, upperBoundUnrolledCst);
415+
upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
416+
loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
417+
upperBoundUnrolledCst));
414418
else
415419
upperBoundUnrolled = forOp.getUpperBound();
416420

417421
// Create constant for 'stepUnrolled'.
418422
stepUnrolled = stepCst == stepUnrolledCst
419423
? step
420-
: boundsBuilder.create<arith::ConstantIndexOp>(
421-
loc, stepUnrolledCst);
424+
: boundsBuilder.create<arith::ConstantOp>(
425+
loc, boundsBuilder.getIntegerAttr(
426+
step.getType(), stepUnrolledCst));
422427
} else {
423428
// Dynamic loop bounds computation.
424429
// TODO: Add dynamic asserts for negative lb/ub/step, or
@@ -428,8 +433,8 @@ LogicalResult mlir::loopUnrollByFactor(
428433
Value diff =
429434
boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
430435
Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
431-
Value unrollFactorCst =
432-
boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
436+
Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
437+
loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
433438
Value tripCountRem =
434439
boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
435440
// Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
@@ -476,7 +481,9 @@ LogicalResult mlir::loopUnrollByFactor(
476481
[&](unsigned i, Value iv, OpBuilder b) {
477482
// iv' = iv + step * i;
478483
auto stride = b.create<arith::MulIOp>(
479-
loc, step, b.create<arith::ConstantIndexOp>(loc, i));
484+
loc, step,
485+
b.create<arith::ConstantOp>(loc,
486+
b.getIntegerAttr(iv.getType(), i)));
480487
return b.create<arith::AddIOp>(loc, iv, stride);
481488
},
482489
annotateFn, iterArgs, yieldedValues);

mlir/test/Dialect/SCF/loop-unroll.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,44 @@ func.func @loop_unroll_yield_iter_arg() {
448448
// CHECK-NEXT: affine.yield %[[ITER_ARG]] : index
449449
// CHECK-NEXT: }
450450
// CHECK-NEXT: return
451+
452+
// -----
453+
454+
// Test the loop unroller works with integer IV type.
455+
func.func @static_loop_unroll_with_integer_iv() -> (f32, f32) {
456+
%0 = arith.constant 7.0 : f32
457+
%lb = arith.constant 0 : i32
458+
%ub = arith.constant 20 : i32
459+
%step = arith.constant 1 : i32
460+
%result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{
461+
%add = arith.addf %arg0, %arg1 : f32
462+
%mul = arith.mulf %arg0, %arg1 : f32
463+
scf.yield %add, %mul : f32, f32
464+
}
465+
return %result#0, %result#1 : f32, f32
466+
}
467+
// UNROLL-BY-3-LABEL: func @static_loop_unroll_with_integer_iv
468+
//
469+
// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32
470+
// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32
471+
// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32
472+
// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32
473+
// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32
474+
// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32
475+
// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
476+
// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 {
477+
// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
478+
// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
479+
// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
480+
// UNROLL-BY-3-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD0]], %[[MUL0]] : f32
481+
// UNROLL-BY-3-NEXT: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[MUL1]] : f32
482+
// UNROLL-BY-3-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD1]], %[[MUL1]] : f32
483+
// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
484+
// UNROLL-BY-3-NEXT: }
485+
// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
486+
// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 {
487+
// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
488+
// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
489+
// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32
490+
// UNROLL-BY-3-NEXT: }
491+
// UNROLL-BY-3-NEXT: return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32

0 commit comments

Comments
 (0)