Skip to content

Commit 85e8d62

Browse files
authored
[mlir][scf]: Expose emitNormalizedLoopBounds/denormalizeInductionVariable util functions (#94429)
Also adjusted `LoopParams` to use OpFoldResult instead of Value.
1 parent 2efe3d7 commit 85e8d62

File tree

6 files changed

+111
-63
lines changed

6 files changed

+111
-63
lines changed

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,13 @@ llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
5454
ArrayRef<int64_t> shape);
5555

5656
/// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
57-
/// a Value or creates a ConstantIndexOp if it casts to an IntegerAttribute.
57+
/// a Value or creates an arith::ConstantOp if it casts to an IntegerAttr.
58+
/// Other attribute types are not supported.
59+
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
60+
OpFoldResult ofr);
61+
62+
/// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
63+
/// a Value or creates a ConstantIndexOp if it casts to an IntegerAttr.
5864
/// Other attribute types are not supported.
5965
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
6066
OpFoldResult ofr);
@@ -88,6 +94,10 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
8894
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
8995
const APFloat &value);
9096

97+
/// Returns the type of `ofr`: the Value's type, or the IntegerAttr's type.
98+
/// Other attribute types are not supported.
99+
Type getType(OpFoldResult ofr);
100+
91101
/// Helper struct to build simple arithmetic quantities with minimal type
92102
/// inference support.
93103
struct ArithBuilder {

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,32 @@ LogicalResult loopUnrollByFactor(
120120
scf::ForOp forOp, uint64_t unrollFactor,
121121
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
122122

123+
/// This structure is to pass and return sets of loop parameters without
124+
/// confusing the order.
125+
struct LoopParams {
126+
OpFoldResult lowerBound;
127+
OpFoldResult upperBound;
128+
OpFoldResult step;
129+
};
130+
131+
/// Transform a loop with a strictly positive step
132+
/// for %i = %lb to %ub step %s
133+
/// into a 0-based loop with step 1
134+
/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
135+
/// %i = %ii * %s + %lb
136+
/// Insert the induction variable remapping in the body of `inner`, which is
137+
/// expected to be either `loop` or another loop perfectly nested under `loop`.
138+
/// Insert the definition of new bounds immediately before `outer`, which is
139+
/// expected to be either `loop` or its parent in the loop nest.
140+
LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
141+
OpFoldResult lb, OpFoldResult ub,
142+
OpFoldResult step);
143+
144+
/// Get back the original induction variable values after loop normalization.
145+
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
146+
Value normalizedIv, OpFoldResult origLb,
147+
OpFoldResult origStep);
148+
123149
/// Tile a nest of standard for loops rooted at `rootForOp` by finding such
124150
/// parametric tile sizes that the outer loops have a fixed number of iterations
125151
/// as defined in `sizes`.

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,20 @@ llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
100100
return dimsToProject;
101101
}
102102

103+
Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
104+
OpFoldResult ofr) {
105+
if (auto value = dyn_cast_if_present<Value>(ofr))
106+
return value;
107+
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
108+
return b.create<arith::ConstantOp>(
109+
loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
110+
}
111+
103112
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
104113
OpFoldResult ofr) {
105-
if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
114+
if (auto value = dyn_cast_if_present<Value>(ofr))
106115
return value;
107-
auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
108-
assert(attr && "expect the op fold result casts to an integer attribute");
116+
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
109117
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
110118
}
111119

@@ -294,6 +302,13 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
294302
return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
295303
}
296304

305+
Type mlir::getType(OpFoldResult ofr) {
306+
if (auto value = dyn_cast_if_present<Value>(ofr))
307+
return value.getType();
308+
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
309+
return attr.getType();
310+
}
311+
297312
Value ArithBuilder::_and(Value lhs, Value rhs) {
298313
return b.create<arith::AndIOp>(loc, lhs, rhs);
299314
}

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 39 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/SCF/IR/SCF.h"
1919
#include "mlir/IR/BuiltinOps.h"
2020
#include "mlir/IR/IRMapping.h"
21+
#include "mlir/IR/OpDefinition.h"
2122
#include "mlir/IR/PatternMatch.h"
2223
#include "mlir/Interfaces/SideEffectInterfaces.h"
2324
#include "mlir/Transforms/RegionUtils.h"
@@ -29,16 +30,6 @@
2930

3031
using namespace mlir;
3132

32-
namespace {
33-
// This structure is to pass and return sets of loop parameters without
34-
// confusing the order.
35-
struct LoopParams {
36-
Value lowerBound;
37-
Value upperBound;
38-
Value step;
39-
};
40-
} // namespace
41-
4233
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
4334
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
4435
ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -473,17 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
473464
return success();
474465
}
475466

476-
/// Transform a loop with a strictly positive step
477-
/// for %i = %lb to %ub step %s
478-
/// into a 0-based loop with step 1
479-
/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
480-
/// %i = %ii * %s + %lb
481-
/// Insert the induction variable remapping in the body of `inner`, which is
482-
/// expected to be either `loop` or another loop perfectly nested under `loop`.
483-
/// Insert the definition of new bounds immediate before `outer`, which is
484-
/// expected to be either `loop` or its parent in the loop nest.
485-
static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
486-
Value lb, Value ub, Value step) {
467+
LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
468+
OpFoldResult lb, OpFoldResult ub,
469+
OpFoldResult step) {
487470
// For non-index types, generate `arith` instructions
488471
// Check if the loop is already known to have a constant zero lower bound or
489472
// a constant one step.
@@ -495,45 +478,54 @@ static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
495478
if (auto stepCst = getConstantIntValue(step))
496479
isStepOne = stepCst.value() == 1;
497480

481+
Type loopParamsType = getType(lb);
482+
assert(loopParamsType == getType(ub) && loopParamsType == getType(step) &&
483+
"expected matching types");
484+
498485
// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
499486
// assuming the step is strictly positive. Update the bounds and the step
500487
// of the loop to go from 0 to the number of iterations, if necessary.
501488
if (isZeroBased && isStepOne)
502489
return {lb, ub, step};
503490

504-
Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
505-
Value newUpperBound =
506-
isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
491+
OpFoldResult diff = ub;
492+
if (!isZeroBased) {
493+
diff = rewriter.createOrFold<arith::SubIOp>(
494+
loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
495+
getValueOrCreateConstantIntOp(rewriter, loc, lb));
496+
}
497+
OpFoldResult newUpperBound = diff;
498+
if (!isStepOne) {
499+
newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
500+
loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
501+
getValueOrCreateConstantIntOp(rewriter, loc, step));
502+
}
507503

508-
Value newLowerBound = isZeroBased
509-
? lb
510-
: rewriter.create<arith::ConstantOp>(
511-
loc, rewriter.getZeroAttr(lb.getType()));
512-
Value newStep = isStepOne
513-
? step
514-
: rewriter.create<arith::ConstantOp>(
515-
loc, rewriter.getIntegerAttr(step.getType(), 1));
504+
OpFoldResult newLowerBound = rewriter.getZeroAttr(loopParamsType);
505+
OpFoldResult newStep = rewriter.getOneAttr(loopParamsType);
516506

517507
return {newLowerBound, newUpperBound, newStep};
518508
}
519509

520-
/// Get back the original induction variable values after loop normalization
521-
static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
522-
Value normalizedIv, Value origLb,
523-
Value origStep) {
510+
void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
511+
Value normalizedIv, OpFoldResult origLb,
512+
OpFoldResult origStep) {
524513
Value denormalizedIv;
525514
SmallPtrSet<Operation *, 2> preserve;
526515
bool isStepOne = isConstantIntValue(origStep, 1);
527516
bool isZeroBased = isConstantIntValue(origLb, 0);
528517

529518
Value scaled = normalizedIv;
530519
if (!isStepOne) {
531-
scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
520+
Value origStepValue =
521+
getValueOrCreateConstantIntOp(rewriter, loc, origStep);
522+
scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
532523
preserve.insert(scaled.getDefiningOp());
533524
}
534525
denormalizedIv = scaled;
535526
if (!isZeroBased) {
536-
denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
527+
Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
528+
denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
537529
preserve.insert(denormalizedIv.getDefiningOp());
538530
}
539531

@@ -638,9 +630,12 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
638630
emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
639631

640632
rewriter.modifyOpInPlace(loop, [&]() {
641-
loop.setLowerBound(newLoopParams.lowerBound);
642-
loop.setUpperBound(newLoopParams.upperBound);
643-
loop.setStep(newLoopParams.step);
633+
loop.setLowerBound(getValueOrCreateConstantIntOp(
634+
rewriter, loop.getLoc(), newLoopParams.lowerBound));
635+
loop.setUpperBound(getValueOrCreateConstantIntOp(
636+
rewriter, loop.getLoc(), newLoopParams.upperBound));
637+
loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
638+
newLoopParams.step));
644639
});
645640

646641
rewriter.setInsertionPointToStart(innermost.getBody());
@@ -778,18 +773,16 @@ void mlir::collapseParallelLoops(
778773
llvm::sort(dims);
779774

780775
// Normalize ParallelOp's iteration pattern.
781-
SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
782-
normalizedUpperBounds;
776+
SmallVector<Value, 3> normalizedUpperBounds;
783777
for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
784778
OpBuilder::InsertionGuard g2(rewriter);
785779
rewriter.setInsertionPoint(loops);
786780
Value lb = loops.getLowerBound()[i];
787781
Value ub = loops.getUpperBound()[i];
788782
Value step = loops.getStep()[i];
789783
auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
790-
normalizedLowerBounds.push_back(newLoopParams.lowerBound);
791-
normalizedUpperBounds.push_back(newLoopParams.upperBound);
792-
normalizedSteps.push_back(newLoopParams.step);
784+
normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
785+
rewriter, loops.getLoc(), newLoopParams.upperBound));
793786

794787
rewriter.setInsertionPointToStart(loops.getBody());
795788
denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,

mlir/test/Dialect/Affine/loop-coalescing.mlir

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,32 +74,27 @@ func.func @multi_use() {
7474

7575
func.func @unnormalized_loops() {
7676
// CHECK: %[[orig_step_i:.*]] = arith.constant 2
77-
// CHECK: %[[orig_step_j:.*]] = arith.constant 3
77+
78+
// CHECK: %[[orig_step_j_and_numiter_i:.*]] = arith.constant 3
7879
// CHECK: %[[orig_lb_i:.*]] = arith.constant 5
7980
// CHECK: %[[orig_lb_j:.*]] = arith.constant 7
80-
// CHECK: %[[orig_ub_i:.*]] = arith.constant 10
81-
// CHECK: %[[orig_ub_j:.*]] = arith.constant 17
8281
%c2 = arith.constant 2 : index
8382
%c3 = arith.constant 3 : index
8483
%c5 = arith.constant 5 : index
8584
%c7 = arith.constant 7 : index
8685
%c10 = arith.constant 10 : index
8786
%c17 = arith.constant 17 : index
8887

89-
// Number of iterations in the outer scf.
90-
// CHECK: %[[diff_i:.*]] = arith.subi %[[orig_ub_i]], %[[orig_lb_i]]
91-
// CHECK: %[[numiter_i:.*]] = arith.ceildivsi %[[diff_i]], %[[orig_step_i]]
92-
9388
// Normalized lower bound and step for the outer scf.
9489
// CHECK: %[[lb_i:.*]] = arith.constant 0
9590
// CHECK: %[[step_i:.*]] = arith.constant 1
9691

9792
// Number of iterations in the inner loop, the pattern is the same as above,
9893
// only capture the final result.
99-
// CHECK: %[[numiter_j:.*]] = arith.ceildivsi {{.*}}, %[[orig_step_j]]
94+
// CHECK: %[[numiter_j:.*]] = arith.constant 4
10095

10196
// New bounds of the outer scf.
102-
// CHECK: %[[range:.*]] = arith.muli %[[numiter_i]], %[[numiter_j]]
97+
// CHECK: %[[range:.*]] = arith.muli %[[orig_step_j_and_numiter_i]], %[[numiter_j]]
10398
// CHECK: scf.for %[[i:.*]] = %[[lb_i]] to %[[range]] step %[[step_i]]
10499
scf.for %i = %c5 to %c10 step %c2 {
105100
// The inner loop has been removed.
@@ -108,7 +103,7 @@ func.func @unnormalized_loops() {
108103
// The IVs are rewritten.
109104
// CHECK: %[[normalized_j:.*]] = arith.remsi %[[i]], %[[numiter_j]]
110105
// CHECK: %[[normalized_i:.*]] = arith.divsi %[[i]], %[[numiter_j]]
111-
// CHECK: %[[scaled_j:.*]] = arith.muli %[[normalized_j]], %[[orig_step_j]]
106+
// CHECK: %[[scaled_j:.*]] = arith.muli %[[normalized_j]], %[[orig_step_j_and_numiter_i]]
112107
// CHECK: %[[orig_j:.*]] = arith.addi %[[scaled_j]], %[[orig_lb_j]]
113108
// CHECK: %[[scaled_i:.*]] = arith.muli %[[normalized_i]], %[[orig_step_i]]
114109
// CHECK: %[[orig_i:.*]] = arith.addi %[[scaled_i]], %[[orig_lb_i]]

mlir/test/Dialect/SCF/transform-ops.mlir

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,22 @@ module attributes {transform.with_named_sequence} {
277277

278278
// This test checks for loop coalescing success for non-index loop boundaries and step type
279279
func.func @coalesce_i32_loops() {
280+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
281+
// CHECK: %[[VAL_1:.*]] = arith.constant 128 : i32
282+
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32
283+
// CHECK: %[[VAL_3:.*]] = arith.constant 64 : i32
280284
%0 = arith.constant 0 : i32
281285
%1 = arith.constant 128 : i32
282286
%2 = arith.constant 2 : i32
283287
%3 = arith.constant 64 : i32
284-
// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
285-
// CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32
286-
// CHECK: scf.for %[[ARG0:.*]] = %[[C0_I32]] to {{.*}} step %[[C1_I32]] : i32
288+
// CHECK: %[[VAL_4:.*]] = arith.constant 64 : i32
289+
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
290+
// CHECK: %[[ONE:.*]] = arith.constant 1 : i32
291+
// CHECK: %[[VAL_7:.*]] = arith.constant 32 : i32
292+
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
293+
// CHECK: %[[VAL_9:.*]] = arith.constant 1 : i32
294+
// CHECK: %[[UB:.*]] = arith.muli %[[VAL_4]], %[[VAL_7]] : i32
295+
// CHECK: scf.for %[[VAL_11:.*]] = %[[ZERO]] to %[[UB]] step %[[ONE]] : i32 {
287296
scf.for %i = %0 to %1 step %2 : i32 {
288297
scf.for %j = %0 to %3 step %2 : i32 {
289298
arith.addi %i, %j : i32

0 commit comments

Comments
 (0)