Skip to content

[mlir][Affine] Expand affine.[de]linearize_index without affine maps #116703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Affine/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ separateFullTiles(MutableArrayRef<AffineForOp> nest,
/// Walk an affine.for to find a band to coalesce.
LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op);

/// Count the number of loops surrounding `operand` such that `operand` could
/// be hoisted above them.
/// Stop counting at the first loop over which the operand cannot be hoisted.
/// This counts any LoopLikeOpInterface, not just affine.for.
int64_t numEnclosingInvariantLoops(OpOperand &operand);
} // namespace affine
} // namespace mlir

Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> createPipelineDataTransferPass();
/// operations (not necessarily restricted to Affine dialect).
std::unique_ptr<Pass> createAffineExpandIndexOpsPass();

/// Creates a pass to expand affine index operations into affine.apply
/// operations.
std::unique_ptr<Pass> createAffineExpandIndexOpsAsAffinePass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -408,4 +408,9 @@ def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";
}

def AffineExpandIndexOpsAsAffine : Pass<"affine-expand-index-ops-as-affine"> {
let summary = "Lower affine operations operating on indices into affine.apply operations";
let constructor = "mlir::affine::createAffineExpandIndexOpsAsAffinePass()";
}

#endif // MLIR_DIALECT_AFFINE_PASSES
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class AffineApplyOp;
/// operations (not necessarily restricted to Affine dialect).
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);

/// Populate patterns that expand affine index operations into their equivalent
/// `affine.apply` representations.
void populateAffineExpandIndexOpsAsAffinePatterns(RewritePatternSet &patterns);

/// Helper function to rewrite `op`'s affine map and reorder its operands such
/// that they are in increasing order of hoistability (i.e. the least hoistable)
/// operands come first in the operand list.
Expand Down
148 changes: 134 additions & 14 deletions mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// fundamental operations.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand All @@ -28,6 +29,50 @@ namespace affine {
using namespace mlir;
using namespace mlir::affine;

/// Given a basis (in static and dynamic components), return the sequence of
/// suffix products of the basis, including the product of the entire basis,
/// which must **not** contain an outer bound.
///
/// That is, for a basis (e_1, ..., e_N), this returns, as SSA values,
/// (e_1 * ... * e_N, e_2 * ... * e_N, ..., e_{N-1} * e_N, e_N).
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
/// needing to manipulate the dynamic value array.
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
                                         ValueRange dynamicBasis,
                                         ArrayRef<int64_t> staticBasis) {
  if (staticBasis.empty())
    return {};

  SmallVector<Value> result;
  result.reserve(staticBasis.size());
  // Walk the basis from the innermost (last) element outward, maintaining the
  // running product split into a compile-time-constant factor (`staticPart`)
  // and a chain of multiplies over the dynamic factors (`dynamicPart`).
  size_t dynamicIndex = dynamicBasis.size();
  Value dynamicPart = nullptr;
  int64_t staticPart = 1;
  for (int64_t elem : llvm::reverse(staticBasis)) {
    if (ShapedType::isDynamic(elem)) {
      // Dynamic values are consumed from the back, so any excess entries at
      // the *front* of `dynamicBasis` (e.g. a dropped outer bound) are never
      // read.
      if (dynamicPart)
        dynamicPart = rewriter.create<arith::MulIOp>(
            loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
      else
        dynamicPart = dynamicBasis[dynamicIndex - 1];
      --dynamicIndex;
    } else {
      staticPart *= elem;
    }

    if (dynamicPart && staticPart == 1) {
      // Purely dynamic running product: avoid emitting a multiply by 1.
      result.push_back(dynamicPart);
    } else {
      Value stride =
          rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
      if (dynamicPart)
        stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
      result.push_back(stride);
    }
  }
  // Strides were accumulated innermost-first; flip them to basis order.
  std::reverse(result.begin(), result.end());
  return result;
}

namespace {
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
Expand All @@ -36,18 +81,62 @@ struct LowerDelinearizeIndexOps
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
PatternRewriter &rewriter) const override {
FailureOr<SmallVector<Value>> multiIndex =
delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
op.getEffectiveBasis(), /*hasOuterBound=*/false);
if (failed(multiIndex))
return failure();
rewriter.replaceOp(op, *multiIndex);
Location loc = op.getLoc();
Value linearIdx = op.getLinearIndex();
unsigned numResults = op.getNumResults();
ArrayRef<int64_t> staticBasis = op.getStaticBasis();
if (numResults == staticBasis.size())
staticBasis = staticBasis.drop_front();

if (numResults == 1) {
rewriter.replaceOp(op, linearIdx);
return success();
}

SmallVector<Value> results;
results.reserve(numResults);
SmallVector<Value> strides =
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);

Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);

Value initialPart =
rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
results.push_back(initialPart);

auto emitModTerm = [&](Value stride) -> Value {
Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
Value remainderNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, remainder, zero);
Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
corrected, remainder);
return mod;
};

// Generate all the intermediate parts
for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
Value thisStride = strides[i];
Value nextStride = strides[i + 1];
Value modulus = emitModTerm(thisStride);
// We know both inputs are positive, so floorDiv == div.
// This could potentially be a divui, but it's not clear if that would
// cause issues.
Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
results.push_back(divided);
}

results.push_back(emitModTerm(strides.back()));

rewriter.replaceOp(op, results);
return success();
}
};

/// Lowers `affine.linearize_index` into a sequence of multiplications and
/// additions.
/// additions. Make a best effort to sort the input indices so that
/// the most loop-invariant terms are at the left of the additions
/// to enable loop-invariant code motion.
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
Expand All @@ -58,13 +147,44 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
return success();
}

SmallVector<OpFoldResult> multiIndex =
getAsOpFoldResult(op.getMultiIndex());
OpFoldResult linearIndex =
linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
Value linearIndexValue =
getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
rewriter.replaceOp(op, linearIndexValue);
Location loc = op.getLoc();
ValueRange multiIndex = op.getMultiIndex();
size_t numIndexes = multiIndex.size();
ArrayRef<int64_t> staticBasis = op.getStaticBasis();
if (numIndexes == staticBasis.size())
staticBasis = staticBasis.drop_front();

SmallVector<Value> strides =
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
SmallVector<std::pair<Value, int64_t>> scaledValues;
scaledValues.reserve(numIndexes);

// Note: strides doesn't contain a value for the final element (stride 1)
// and everything else lines up. We use the "mutable" accessor so we can get
// our hands on an `OpOperand&` for the loop invariant counting function.
for (auto [stride, idxOp] :
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
Value scaledIdx =
rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
}
scaledValues.emplace_back(
multiIndex.back(),
numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what the number of enclosing invariant loops is for. I don't think we should need this; i.e., it seems like a mixing of unrelated concerns.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I see this from the commit message

In addition, the lowering of affine.linearize_index now sorts the operands by loop-independence, allowing an increased amount of loop-invariant code motion after lowering.

I think I understand what this does... that is interesting. I wonder if we can do this separately as an optimizations instead of linking it to linearization lowering. Take a sequence of arith.mul operations and reorder them to allow for better hoisting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's definitely an option - reordering mul mul mul add add type sequences as a separate pass

I don't have strong feelings about doing this optimization here, but it felt like a somewhat natural spot for it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its OK to do here on the lowering (though it makes the lowering more involved).


// Sort by how many enclosing loops there are, ties implicitly broken by
// size of the stride.
llvm::stable_sort(scaledValues,
[&](auto l, auto r) { return l.second > r.second; });

Value result = scaledValues.front().first;
for (auto [scaledValue, numHoistableLoops] :
llvm::drop_begin(scaledValues)) {
std::ignore = numHoistableLoops;
result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
}
rewriter.replaceOp(op, result);
return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
//===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to expand affine index ops into one or more
// fundamental operations.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

namespace {
/// Rewrites `affine.delinearize_index` via the shared `delinearizeIndex`
/// utility, which emits the equivalent division/remainder affine expressions.
struct LowerDelinearizeIndexOps
    : public OpRewritePattern<AffineDelinearizeIndexOp> {
  using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<SmallVector<Value>> decomposed =
        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
    if (failed(decomposed))
      return failure();
    rewriter.replaceOp(op, *decomposed);
    return success();
  }
};

/// Rewrites `affine.linearize_index` into the equivalent combination of its
/// operands via the shared `linearizeIndex` utility.
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    // An empty multi-index linearizes to the constant 0. The folder normally
    // removes this case before lowering; handled again here for safety.
    if (op.getMultiIndex().empty()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
      return success();
    }

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> indices = getAsOpFoldResult(op.getMultiIndex());
    OpFoldResult combined =
        linearizeIndex(rewriter, loc, indices, op.getMixedBasis());
    rewriter.replaceOp(
        op, getValueOrCreateConstantIntOp(rewriter, loc, combined));
    return success();
  }
};

/// Pass that greedily applies the `affine.apply`-producing expansion patterns
/// for `affine.delinearize_index` and `affine.linearize_index`.
class ExpandAffineIndexOpsAsAffinePass
    : public affine::impl::AffineExpandIndexOpsAsAffineBase<
          ExpandAffineIndexOpsAsAffinePass> {
public:
  ExpandAffineIndexOpsAsAffinePass() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateAffineExpandIndexOpsAsAffinePatterns(patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

/// Populate `patterns` with the `affine.apply`-based expansions of
/// `affine.delinearize_index` and `affine.linearize_index`.
void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns(
    RewritePatternSet &patterns) {
  // Prefer `add` over the legacy `insert` spelling, matching current MLIR
  // RewritePatternSet convention.
  patterns.add<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
      patterns.getContext());
}

/// Public factory for the expansion pass declared in Passes.h; the pass
/// definition lives in Passes.td as `AffineExpandIndexOpsAsAffine`.
std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsAsAffinePass() {
  return std::make_unique<ExpandAffineIndexOpsAsAffinePass>();
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRAffineTransforms
AffineDataCopyGeneration.cpp
AffineExpandIndexOps.cpp
AffineExpandIndexOpsAsAffine.cpp
AffineLoopInvariantCodeMotion.cpp
AffineLoopNormalize.cpp
AffineParallelize.cpp
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2772,3 +2772,15 @@ LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
}
return result;
}

/// Walk outward from `operand`'s owner, counting enclosing loop-like ops for
/// which `operand`'s value is defined outside the loop; stop at the first
/// loop that does not satisfy this (i.e. above which the operand cannot be
/// hoisted).
int64_t mlir::affine::numEnclosingInvariantLoops(OpOperand &operand) {
  int64_t numInvariant = 0;
  for (auto loop =
           operand.getOwner()->getParentOfType<LoopLikeOpInterface>();
       loop && loop.isDefinedOutsideOfLoop(operand.get());
       loop = loop->getParentOfType<LoopLikeOpInterface>())
    ++numInvariant;
  return numInvariant;
}
Loading
Loading