Skip to content

[mlir][scf]: Expose emitNormalizedLoopBounds/denormalizeInductionVariable util functions (NFC) #94429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion mlir/include/mlir/Dialect/Arith/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
ArrayRef<int64_t> shape);

/// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
/// a Value or creates a ConstantIndexOp if it casts to an IntegerAttribute.
/// a Value or creates a ConstantOp if it casts to an Integer Attribute.
/// Other attribute types are not supported.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
OpFoldResult ofr);

/// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
/// a Value or creates a ConstantIndexOp if it casts to an Integer Attribute.
/// Other attribute types are not supported.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
OpFoldResult ofr);
Expand Down Expand Up @@ -88,6 +94,10 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
const APFloat &value);

/// Returns the int type of the integer in ofr.
/// Other attribute types are not supported.
Type getType(OpFoldResult ofr);

/// Helper struct to build simple arithmetic quantities with minimal type
/// inference support.
struct ArithBuilder {
Expand Down
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,32 @@ LogicalResult loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);

/// This structure is to pass and return sets of loop parameters without
/// confusing the order.
struct LoopParams {
OpFoldResult lowerBound;
OpFoldResult upperBound;
OpFoldResult step;
};

/// Transform a loop with a strictly positive step
/// for %i = %lb to %ub step %s
/// into a 0-based loop with step 1
/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
/// %i = %ii * %s + %lb
/// Insert the induction variable remapping in the body of `inner`, which is
/// expected to be either `loop` or another loop perfectly nested under `loop`.
/// Insert the definition of new bounds immediate before `outer`, which is
/// expected to be either `loop` or its parent in the loop nest.
LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step);

/// Get back the original induction variable values after loop normalization.
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
Value normalizedIv, OpFoldResult origLb,
OpFoldResult origStep);

/// Tile a nest of standard for loops rooted at `rootForOp` by finding such
/// parametric tile sizes that the outer loops have a fixed number of iterations
/// as defined in `sizes`.
Expand Down
21 changes: 18 additions & 3 deletions mlir/lib/Dialect/Arith/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,20 @@ llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
return dimsToProject;
}

Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
OpFoldResult ofr) {
if (auto value = dyn_cast_if_present<Value>(ofr))
return value;
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
return b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
}

Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
OpFoldResult ofr) {
if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
if (auto value = dyn_cast_if_present<Value>(ofr))
return value;
auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
assert(attr && "expect the op fold result casts to an integer attribute");
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
}

Expand Down Expand Up @@ -294,6 +302,13 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
}

Type mlir::getType(OpFoldResult ofr) {
if (auto value = dyn_cast_if_present<Value>(ofr))
return value.getType();
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
return attr.getType();
}

Value ArithBuilder::_and(Value lhs, Value rhs) {
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
Expand Down
85 changes: 39 additions & 46 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/MathExtras.h"
Expand All @@ -29,16 +30,6 @@

using namespace mlir;

namespace {
// This structure is to pass and return sets of loop parameters without
// confusing the order.
struct LoopParams {
Value lowerBound;
Value upperBound;
Value step;
};
} // namespace

SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
Expand Down Expand Up @@ -473,17 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}

/// Transform a loop with a strictly positive step
/// for %i = %lb to %ub step %s
/// into a 0-based loop with step 1
/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
/// %i = %ii * %s + %lb
/// Insert the induction variable remapping in the body of `inner`, which is
/// expected to be either `loop` or another loop perfectly nested under `loop`.
/// Insert the definition of new bounds immediate before `outer`, which is
/// expected to be either `loop` or its parent in the loop nest.
static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
Value lb, Value ub, Value step) {
LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) {
// For non-index types, generate `arith` instructions
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
Expand All @@ -495,45 +478,54 @@ static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
if (auto stepCst = getConstantIntValue(step))
isStepOne = stepCst.value() == 1;

Type loopParamsType = getType(lb);
assert(loopParamsType == getType(ub) && loopParamsType == getType(step) &&
"expected matching types");

// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
// assuming the step is strictly positive. Update the bounds and the step
// of the loop to go from 0 to the number of iterations, if necessary.
if (isZeroBased && isStepOne)
return {lb, ub, step};

Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
Value newUpperBound =
isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
OpFoldResult diff = ub;
if (!isZeroBased) {
diff = rewriter.createOrFold<arith::SubIOp>(
loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
getValueOrCreateConstantIntOp(rewriter, loc, lb));
}
OpFoldResult newUpperBound = diff;
if (!isStepOne) {
newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
getValueOrCreateConstantIntOp(rewriter, loc, step));
}

Value newLowerBound = isZeroBased
? lb
: rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(lb.getType()));
Value newStep = isStepOne
? step
: rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(step.getType(), 1));
OpFoldResult newLowerBound = rewriter.getZeroAttr(loopParamsType);
OpFoldResult newStep = rewriter.getOneAttr(loopParamsType);

return {newLowerBound, newUpperBound, newStep};
}

/// Get back the original induction variable values after loop normalization
static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
Value normalizedIv, Value origLb,
Value origStep) {
void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
Value normalizedIv, OpFoldResult origLb,
OpFoldResult origStep) {
Value denormalizedIv;
SmallPtrSet<Operation *, 2> preserve;
bool isStepOne = isConstantIntValue(origStep, 1);
bool isZeroBased = isConstantIntValue(origLb, 0);

Value scaled = normalizedIv;
if (!isStepOne) {
scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
Value origStepValue =
getValueOrCreateConstantIntOp(rewriter, loc, origStep);
scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
preserve.insert(scaled.getDefiningOp());
}
denormalizedIv = scaled;
if (!isZeroBased) {
denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
preserve.insert(denormalizedIv.getDefiningOp());
}

Expand Down Expand Up @@ -638,9 +630,12 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);

rewriter.modifyOpInPlace(loop, [&]() {
loop.setLowerBound(newLoopParams.lowerBound);
loop.setUpperBound(newLoopParams.upperBound);
loop.setStep(newLoopParams.step);
loop.setLowerBound(getValueOrCreateConstantIntOp(
rewriter, loop.getLoc(), newLoopParams.lowerBound));
loop.setUpperBound(getValueOrCreateConstantIntOp(
rewriter, loop.getLoc(), newLoopParams.upperBound));
loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
newLoopParams.step));
});

rewriter.setInsertionPointToStart(innermost.getBody());
Expand Down Expand Up @@ -778,18 +773,16 @@ void mlir::collapseParallelLoops(
llvm::sort(dims);

// Normalize ParallelOp's iteration pattern.
SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
normalizedUpperBounds;
SmallVector<Value, 3> normalizedUpperBounds;
for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
OpBuilder::InsertionGuard g2(rewriter);
rewriter.setInsertionPoint(loops);
Value lb = loops.getLowerBound()[i];
Value ub = loops.getUpperBound()[i];
Value step = loops.getStep()[i];
auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
normalizedLowerBounds.push_back(newLoopParams.lowerBound);
normalizedUpperBounds.push_back(newLoopParams.upperBound);
normalizedSteps.push_back(newLoopParams.step);
normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
rewriter, loops.getLoc(), newLoopParams.upperBound));

rewriter.setInsertionPointToStart(loops.getBody());
denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
Expand Down
15 changes: 5 additions & 10 deletions mlir/test/Dialect/Affine/loop-coalescing.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -74,32 +74,27 @@ func.func @multi_use() {

func.func @unnormalized_loops() {
// CHECK: %[[orig_step_i:.*]] = arith.constant 2
// CHECK: %[[orig_step_j:.*]] = arith.constant 3

// CHECK: %[[orig_step_j_and_numiter_i:.*]] = arith.constant 3
// CHECK: %[[orig_lb_i:.*]] = arith.constant 5
// CHECK: %[[orig_lb_j:.*]] = arith.constant 7
// CHECK: %[[orig_ub_i:.*]] = arith.constant 10
// CHECK: %[[orig_ub_j:.*]] = arith.constant 17
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c5 = arith.constant 5 : index
%c7 = arith.constant 7 : index
%c10 = arith.constant 10 : index
%c17 = arith.constant 17 : index

// Number of iterations in the outer scf.
// CHECK: %[[diff_i:.*]] = arith.subi %[[orig_ub_i]], %[[orig_lb_i]]
// CHECK: %[[numiter_i:.*]] = arith.ceildivsi %[[diff_i]], %[[orig_step_i]]

// Normalized lower bound and step for the outer scf.
// CHECK: %[[lb_i:.*]] = arith.constant 0
// CHECK: %[[step_i:.*]] = arith.constant 1

// Number of iterations in the inner loop, the pattern is the same as above,
// only capture the final result.
// CHECK: %[[numiter_j:.*]] = arith.ceildivsi {{.*}}, %[[orig_step_j]]
// CHECK: %[[numiter_j:.*]] = arith.constant 4

// New bounds of the outer scf.
// CHECK: %[[range:.*]] = arith.muli %[[numiter_i]], %[[numiter_j]]
// CHECK: %[[range:.*]] = arith.muli %[[orig_step_j_and_numiter_i:.*]], %[[numiter_j]]
// CHECK: scf.for %[[i:.*]] = %[[lb_i]] to %[[range]] step %[[step_i]]
scf.for %i = %c5 to %c10 step %c2 {
// The inner loop has been removed.
Expand All @@ -108,7 +103,7 @@ func.func @unnormalized_loops() {
// The IVs are rewritten.
// CHECK: %[[normalized_j:.*]] = arith.remsi %[[i]], %[[numiter_j]]
// CHECK: %[[normalized_i:.*]] = arith.divsi %[[i]], %[[numiter_j]]
// CHECK: %[[scaled_j:.*]] = arith.muli %[[normalized_j]], %[[orig_step_j]]
// CHECK: %[[scaled_j:.*]] = arith.muli %[[normalized_j]], %[[orig_step_j_and_numiter_i]]
// CHECK: %[[orig_j:.*]] = arith.addi %[[scaled_j]], %[[orig_lb_j]]
// CHECK: %[[scaled_i:.*]] = arith.muli %[[normalized_i]], %[[orig_step_i]]
// CHECK: %[[orig_i:.*]] = arith.addi %[[scaled_i]], %[[orig_lb_i]]
Expand Down
15 changes: 12 additions & 3 deletions mlir/test/Dialect/SCF/transform-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,22 @@ module attributes {transform.with_named_sequence} {

// This test checks for loop coalescing success for non-index loop boundaries and step type
func.func @coalesce_i32_loops() {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_1:.*]] = arith.constant 128 : i32
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32
// CHECK: %[[VAL_3:.*]] = arith.constant 64 : i32
%0 = arith.constant 0 : i32
%1 = arith.constant 128 : i32
%2 = arith.constant 2 : i32
%3 = arith.constant 64 : i32
// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32
// CHECK: scf.for %[[ARG0:.*]] = %[[C0_I32]] to {{.*}} step %[[C1_I32]] : i32
// CHECK: %[[VAL_4:.*]] = arith.constant 64 : i32
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
// CHECK: %[[ONE:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_7:.*]] = arith.constant 32 : i32
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_9:.*]] = arith.constant 1 : i32
// CHECK: %[[UB:.*]] = arith.muli %[[VAL_4]], %[[VAL_7]] : i32
// CHECK: scf.for %[[VAL_11:.*]] = %[[ZERO]] to %[[UB]] step %[[ONE]] : i32 {
scf.for %i = %0 to %1 step %2 : i32 {
scf.for %j = %0 to %3 step %2 : i32 {
arith.addi %i, %j : i32
Expand Down
Loading