Skip to content

Commit 59569eb

Browse files
authored
[mlir] Fix support for loop normalization with integer indices (#76566)
Choose the correct type for the updated loop bounds and step after scf loop normalization; do not force the chosen type to be IndexType.
1 parent a0e6b7c commit 59569eb

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,9 +502,12 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
502502

503503
Value newLowerBound =
504504
isZeroBased ? lowerBound
505-
: boundsBuilder.create<arith::ConstantIndexOp>(loc, 0);
505+
: boundsBuilder.create<arith::ConstantOp>(
506+
loc, boundsBuilder.getZeroAttr(lowerBound.getType()));
506507
Value newStep =
507-
isStepOne ? step : boundsBuilder.create<arith::ConstantIndexOp>(loc, 1);
508+
isStepOne ? step
509+
: boundsBuilder.create<arith::ConstantOp>(
510+
loc, boundsBuilder.getIntegerAttr(step.getType(), 1));
508511

509512
// Insert code computing the value of the original loop induction variable
510513
// from the "normalized" one.

mlir/test/Dialect/SCF/transform-ops.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,33 @@ module attributes {transform.with_named_sequence} {
270270
transform.yield
271271
}
272272
}
273+
274+
// -----
275+
276+
// CHECK-LABEL: func @coalesce_i32_loops(
277+
278+
// This test checks that loop coalescing succeeds when the loop bounds and step have a non-index integer type.
279+
func.func @coalesce_i32_loops() {
280+
%0 = arith.constant 0 : i32
281+
%1 = arith.constant 128 : i32
282+
%2 = arith.constant 2 : i32
283+
%3 = arith.constant 64 : i32
284+
// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
285+
// CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32
286+
// CHECK: scf.for %[[ARG0:.*]] = %[[C0_I32]] to {{.*}} step %[[C1_I32]] : i32
287+
scf.for %i = %0 to %1 step %2 : i32 {
288+
scf.for %j = %0 to %3 step %2 : i32 {
289+
arith.addi %i, %j : i32
290+
}
291+
} {coalesce}
292+
return
293+
}
294+
295+
module attributes {transform.with_named_sequence} {
296+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
297+
%0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
298+
%1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
299+
%2 = transform.loop.coalesce %1: (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
300+
transform.yield
301+
}
302+
}

0 commit comments

Comments
 (0)