Skip to content

[mlir][SCF] Use Affine ops for indexing math. #108450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
];

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

#endif // AFFINE_OPS
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Affine/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
let summary = "Coalesce nested loops with independent bounds into a single "
"loop";
let constructor = "mlir::affine::createLoopCoalescingPass()";
let dependentDialects = ["arith::ArithDialect"];
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
}

def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
def TestSCFParallelLoopCollapsing : Pass<"test-scf-parallel-loop-collapsing"> {
let summary = "Test parallel loops collapsing transformation";
let constructor = "mlir::createTestSCFParallelLoopCollapsingPass()";
let dependentDialects = ["affine::AffineDialect"];
let description = [{
This pass is purely for testing the scf::collapseParallelLoops
transformation. The transformation does not have opinions on how a
Expand Down
127 changes: 127 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4534,6 +4534,133 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
return success();
}

namespace {

// Drops delinearization indices that correspond to unit-extent basis
struct DropUnitExtentBasis
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
PatternRewriter &rewriter) const override {
SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
std::optional<Value> zero = std::nullopt;
Location loc = delinearizeOp->getLoc();
auto getZero = [&]() -> Value {
if (!zero)
zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
return zero.value();
};

// Replace all indices corresponding to unit-extent basis with 0.
// Remaining basis can be used to get a new `affine.delinearize_index` op.
SmallVector<Value> newOperands;
for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) {
if (matchPattern(basis, m_One()))
replacements[index] = getZero();
else
newOperands.push_back(basis);
}

if (newOperands.size() == delinearizeOp.getBasis().size())
return failure();

if (!newOperands.empty()) {
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, delinearizeOp.getLinearIndex(), newOperands);
int newIndex = 0;
// Map back the new delinearized indices to the values they replace.
for (auto &replacement : replacements) {
if (replacement)
continue;
replacement = newDelinearizeOp->getResult(newIndex++);
}
}

rewriter.replaceOp(delinearizeOp, replacements);
return success();
}
};

/// Drop delinearization pattern related to loops in the following way
///
/// ```
/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
/// %0 = affine.delinearize_index %iv into (%ub) : index
/// <some_use>(%0)
/// }
/// ```
///
/// can be canonicalized to
///
/// ```
/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
/// <some_use>(%iv)
/// }
/// ```
struct DropDelinearizeOfSingleLoop
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
PatternRewriter &rewriter) const override {
auto basis = delinearizeOp.getBasis();
if (basis.size() != 1)
return failure();

// Check that the `linear_index` is an induction variable.
auto inductionVar = cast<BlockArgument>(delinearizeOp.getLinearIndex());
if (!inductionVar)
return failure();

// Check that the parent is a `LoopLikeOpInterface`.
auto loopLikeOp = cast<LoopLikeOpInterface>(
inductionVar.getParentRegion()->getParentOp());
if (!loopLikeOp)
return failure();

// Check that loop is unit-rank and that the `linear_index` is the induction
// variable.
auto inductionVars = loopLikeOp.getLoopInductionVars();
if (!inductionVars || inductionVars->size() != 1 ||
inductionVars->front() != inductionVar) {
return rewriter.notifyMatchFailure(
delinearizeOp, "`linear_index` is not loop induction variable");
}

// Check that the upper-bound is the basis.
auto upperBounds = loopLikeOp.getLoopUpperBounds();
if (!upperBounds || upperBounds->size() != 1 ||
upperBounds->front() != getAsOpFoldResult(basis.front())) {
return rewriter.notifyMatchFailure(delinearizeOp,
"`basis` is not upper bound");
}

// Check that the lower bound is zero.
auto lowerBounds = loopLikeOp.getLoopLowerBounds();
if (!lowerBounds || lowerBounds->size() != 1 ||
!isZeroIndex(lowerBounds->front())) {
return rewriter.notifyMatchFailure(delinearizeOp,
"loop lower bound is not zero");
}

// Check that the step is one.
auto steps = loopLikeOp.getLoopSteps();
if (!steps || steps->size() != 1 || !isConstantIntValue(steps->front(), 1))
return rewriter.notifyMatchFailure(delinearizeOp, "loop step is not one");

rewriter.replaceOp(delinearizeOp, inductionVar);
return success();
}
};

} // namespace

void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Transforms/RegionUtils.h"
Expand Down
78 changes: 77 additions & 1 deletion mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -671,9 +672,26 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
return success();
}

Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) {
Range normalizedLoopBounds;
normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
AffineExpr s0, s1, s2;
bindSymbols(rewriter.getContext(), s0, s1, s2);
AffineExpr e = (s1 - s0).ceilDiv(s2);
normalizedLoopBounds.size =
affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
return normalizedLoopBounds;
}

Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) {
if (getType(lb).isIndex()) {
return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
}
// For non-index types, generate `arith` instructions
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
Expand Down Expand Up @@ -714,9 +732,38 @@ Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
return {newLowerBound, newUpperBound, newStep};
}

static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
Location loc,
Value normalizedIv,
OpFoldResult origLb,
OpFoldResult origStep) {
AffineExpr d0, s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
bindDims(rewriter.getContext(), d0);
AffineExpr e = d0 * s1 + s0;
OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
Value denormalizedIvVal =
getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
SmallPtrSet<Operation *, 1> preservedUses;
// If an `affine.apply` operation is generated for denormalization, the use
// of `origLb` in those ops must not be replaced. These arent not generated
// when `origLb == 0` and `origStep == 1`.
if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
preservedUses.insert(preservedUse);
}
}
rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
}

void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
Value normalizedIv, OpFoldResult origLb,
OpFoldResult origStep) {
if (getType(origLb).isIndex()) {
return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
origLb, origStep);
}
Value denormalizedIv;
SmallPtrSet<Operation *, 2> preserve;
bool isStepOne = isConstantIntValue(origStep, 1);
Expand All @@ -739,10 +786,29 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
}

static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> values) {
assert(!values.empty() && "unexecpted empty array");
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
AffineExpr mul = s0 * s1;
OpFoldResult products = rewriter.getIndexAttr(1);
for (auto v : values) {
products = affine::makeComposedFoldedAffineApply(
rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
}
return products;
}

/// Helper function to multiply a sequence of values.
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
ArrayRef<Value> values) {
assert(!values.empty() && "unexpected empty list");
if (getType(values.front()).isIndex()) {
SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
return getValueOrCreateConstantIndexOp(rewriter, loc, product);
}
std::optional<Value> productOf;
for (auto v : values) {
auto vOne = getConstantIntValue(v);
Expand All @@ -757,7 +823,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
if (!productOf) {
productOf = rewriter
.create<arith::ConstantOp>(
loc, rewriter.getOneAttr(values.front().getType()))
loc, rewriter.getOneAttr(getType(values.front())))
.getResult();
}
return productOf.value();
Expand All @@ -774,6 +840,16 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
Value linearizedIv, ArrayRef<Value> ubs) {

if (linearizedIv.getType().isIndex()) {
Operation *delinearizedOp =
rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
ubs);
auto resultVals = llvm::map_to_vector(
delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
}

SmallVector<Value> delinearizedIvs(ubs.size());
SmallPtrSet<Operation *, 2> preservedUsers;

Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1466,3 +1466,51 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
}
return
}

// -----

func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
(index, index, index, index, index, index) {
%c1 = arith.constant 1 : index
%0:6 = affine.delinearize_index %arg0 into (%c1, %arg1, %c1, %c1, %arg2, %c1)
: index, index, index, index, index, index
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : index, index, index, index, index, index
}
// CHECK-LABEL: func @drop_unit_basis_in_delinearize(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], %[[ARG2]])
// CHECK: return %[[C0]], %[[DELINEARIZE]]#0, %[[C0]], %[[C0]], %[[DELINEARIZE]]#1, %[[C0]]

// -----

func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) {
%c1 = arith.constant 1 : index
%0:2 = affine.delinearize_index %arg0 into (%c1, %c1) : index, index
return %0#0, %0#1 : index, index
}
// CHECK-LABEL: func @drop_all_unit_bases(
// CHECK-SAME: %[[ARG0:.+]]: index)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-NOT: affine.delinearize_index
// CHECK: return %[[C0]], %[[C0]]

// -----

func.func @drop_single_loop_delinearize(%arg0 : index, %arg1 : index) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%2 = scf.for %iv = %c0 to %arg1 step %c1 iter_args(%arg2 = %c0) -> index {
%0 = affine.delinearize_index %iv into (%arg1) : index
%1 = "some_use"(%arg2, %0) : (index, index) -> (index)
scf.yield %1 : index
}
return %2 : index
}
// CHECK-LABEL: func @drop_single_loop_delinearize(
// CHECK-SAME: %[[ARG0:.+]]: index)
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-NOT: affine.delinearize_index
// CHECK: "some_use"(%{{.+}}, %[[IV]])
Loading
Loading