Skip to content

Commit 488b3b6

Browse files
committed
[mlir][scf] Allow unrolling loops with integer-typed IV.
1 parent 657ec73 commit 488b3b6

File tree

4 files changed

+41
-26
lines changed

4 files changed

+41
-26
lines changed

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
9393
int64_t value);
9494
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
9595
const APFloat &value);
96+
Value createIntOrIndexConstant(OpBuilder &builder, Location loc, Type type,
97+
int64_t value);
9698

9799
/// Returns the int type of the integer in ofr.
98100
/// Other attribute types are not supported.

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,16 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
302302
return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
303303
}
304304

305+
/// Creates a constant holding `value` whose result type is `type`.
///
/// `type` must satisfy `isIntOrIndex()`; this is checked with an assertion.
/// For an index type the constant is built as an `arith::ConstantIndexOp`.
/// For any other accepted type it is built as an `arith::ConstantOp` from an
/// integer attribute of that exact type, so the result type always matches
/// `type`.
///
/// \param b     Builder used to create the constant operation.
/// \param loc   Location attached to the created operation.
/// \param type  Integer or index type of the constant.
/// \param value Value of the constant.
/// \returns The created constant operation, converted to a `Value`.
Value mlir::createIntOrIndexConstant(OpBuilder &b, Location loc, Type type,
                                     int64_t value) {
  assert(type.isIntOrIndex() &&
         "unexpected type other than integers and index");
  if (type.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, value);
  // Not an index type, so given the assertion above it is an integer type:
  // build the constant from an integer attribute carrying that type.
  return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(type, value));
}
314+
305315
Type mlir::getType(OpFoldResult ofr) {
306316
if (auto value = dyn_cast_if_present<Value>(ofr))
307317
return value.getType();

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,13 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
264264
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
265265
int64_t divisor) {
266266
assert(divisor > 0 && "expected positive divisor");
267-
assert(dividend.getType().isIndex() && "expected index-typed value");
267+
assert(dividend.getType().isIntOrIndex() &&
268+
"expected integer or index-typed value");
268269

269270
Value divisorMinusOneCst =
270-
builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
271-
Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
271+
createIntOrIndexConstant(builder, loc, dividend.getType(), divisor - 1);
272+
Value divisorCst =
273+
createIntOrIndexConstant(builder, loc, dividend.getType(), divisor);
272274
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
273275
return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
274276
}
@@ -279,9 +281,9 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
279281
// where divis is rounding-to-zero division.
280282
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
281283
Value divisor) {
282-
assert(dividend.getType().isIndex() && "expected index-typed value");
283-
284-
Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
284+
assert(dividend.getType().isIntOrIndex() &&
285+
"expected integer or index-typed value");
286+
Value cstOne = createIntOrIndexConstant(builder, loc, dividend.getType(), 1);
285287
Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
286288
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
287289
return builder.create<arith::DivUIOp>(loc, sum, divisor);
@@ -388,16 +390,17 @@ LogicalResult mlir::loopUnrollByFactor(
388390
// Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
389391
generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
390392
if (generateEpilogueLoop)
391-
upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
392-
loc, upperBoundUnrolledCst);
393+
upperBoundUnrolled = createIntOrIndexConstant(
394+
boundsBuilder, loc, forOp.getUpperBound().getType(), upperBoundUnrolledCst);
393395
else
394396
upperBoundUnrolled = forOp.getUpperBound();
395397

396398
// Create constant for 'stepUnrolled'.
397-
stepUnrolled = stepCst == stepUnrolledCst
398-
? step
399-
: boundsBuilder.create<arith::ConstantIndexOp>(
400-
loc, stepUnrolledCst);
399+
stepUnrolled =
400+
stepCst == stepUnrolledCst
401+
? step
402+
: createIntOrIndexConstant(boundsBuilder, loc, step.getType(),
403+
stepUnrolledCst);
401404
} else {
402405
// Dynamic loop bounds computation.
403406
// TODO: Add dynamic asserts for negative lb/ub/step, or
@@ -407,8 +410,8 @@ LogicalResult mlir::loopUnrollByFactor(
407410
Value diff =
408411
boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
409412
Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
410-
Value unrollFactorCst =
411-
boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
413+
Value unrollFactorCst = createIntOrIndexConstant(
414+
boundsBuilder, loc, tripCount.getType(), unrollFactor);
412415
Value tripCountRem =
413416
boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
414417
// Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
@@ -455,7 +458,7 @@ LogicalResult mlir::loopUnrollByFactor(
455458
[&](unsigned i, Value iv, OpBuilder b) {
456459
// iv' = iv + step * i;
457460
auto stride = b.create<arith::MulIOp>(
458-
loc, step, b.create<arith::ConstantIndexOp>(loc, i));
461+
loc, step, createIntOrIndexConstant(b, loc, iv.getType(), i));
459462
return b.create<arith::AddIOp>(loc, iv, stride);
460463
},
461464
annotateFn, iterArgs, yieldedValues);

mlir/test/Dialect/SCF/loop-unroll.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -311,10 +311,10 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
311311
// Test that epilogue's arguments are correctly renamed.
312312
func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
313313
%0 = arith.constant 7.0 : f32
314-
%lb = arith.constant 0 : index
315-
%ub = arith.constant 20 : index
316-
%step = arith.constant 1 : index
317-
%result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) {
314+
%lb = arith.constant 0 : i32
315+
%ub = arith.constant 20 : i32
316+
%step = arith.constant 1 : i32
317+
%result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{
318318
%add = arith.addf %arg0, %arg1 : f32
319319
%mul = arith.mulf %arg0, %arg1 : f32
320320
scf.yield %add, %mul : f32, f32
@@ -324,13 +324,13 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
324324
// UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments
325325
//
326326
// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32
327-
// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : index
328-
// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : index
329-
// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : index
330-
// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : index
331-
// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : index
327+
// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32
328+
// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32
329+
// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32
330+
// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32
331+
// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32
332332
// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
333-
// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) {
333+
// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 {
334334
// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
335335
// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
336336
// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
@@ -340,7 +340,7 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
340340
// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
341341
// UNROLL-BY-3-NEXT: }
342342
// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
343-
// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) {
343+
// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 {
344344
// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
345345
// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
346346
// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32

0 commit comments

Comments
 (0)