Skip to content

Commit bc7fb7c

Browse files
[mlir][SCF] Use Affine ops for indexing math.
For induction variables of index type, the indexing math is better represented using affine ops such as `affine.delinearize_index`. This also further demonstrates that some of these `affine` ops might need to move to a different dialect. For one, these ops only support `IndexType` when they should be able to work with any integer type. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 970e2c1 commit bc7fb7c

File tree

8 files changed

+233
-206
lines changed

8 files changed

+233
-206
lines changed

mlir/include/mlir/Dialect/Affine/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
394394
let summary = "Coalesce nested loops with independent bounds into a single "
395395
"loop";
396396
let constructor = "mlir::affine::createLoopCoalescingPass()";
397-
let dependentDialects = ["arith::ArithDialect"];
397+
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
398398
}
399399

400400
def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
5656
def TestSCFParallelLoopCollapsing : Pass<"test-scf-parallel-loop-collapsing"> {
5757
let summary = "Test parallel loops collapsing transformation";
5858
let constructor = "mlir::createTestSCFParallelLoopCollapsingPass()";
59+
let dependentDialects = ["affine::AffineDialect"];
5960
let description = [{
6061
This pass is purely for testing the scf::collapseParallelLoops
6162
transformation. The transformation does not have opinions on how a

mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/SCF/Transforms/Passes.h"
1010

11+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1112
#include "mlir/Dialect/SCF/IR/SCF.h"
1213
#include "mlir/Dialect/SCF/Utils/Utils.h"
1314
#include "mlir/Transforms/RegionUtils.h"

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/SCF/Utils/Utils.h"
1414
#include "mlir/Analysis/SliceAnalysis.h"
15+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/Arith/Utils/Utils.h"
1718
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -671,9 +672,26 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
671672
return success();
672673
}
673674

675+
/// Builds the normalized bounds of a loop for the index-type case, using
/// affine ops instead of `arith` ops.
///
/// The returned `Range` has a constant offset of 0 and a constant stride of 1.
/// Its size is the affine expression `(ub - lb) ceildiv step`, created (and
/// possibly folded) through `affine::makeComposedFoldedAffineApply`.
///
/// NOTE(review): unlike the other index-type helpers in this file
/// (`denormalizeInductionVariableForIndexType`, `getProductOfIndexes`), this
/// function is not declared `static` — confirm whether external linkage is
/// intended.
Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
676+
OpFoldResult lb, OpFoldResult ub,
677+
OpFoldResult step) {
678+
Range normalizedLoopBounds;
679+
// The normalized loop always starts at 0 ...
normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
680+
// ... and always advances by 1.
normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
681+
AffineExpr s0, s1, s2;
682+
bindSymbols(rewriter.getContext(), s0, s1, s2);
683+
// With the operand list `{lb, ub, step}` passed below: s0 = lb, s1 = ub,
// s2 = step, so `e` is `(ub - lb) ceildiv step`.
AffineExpr e = (s1 - s0).ceilDiv(s2);
684+
normalizedLoopBounds.size =
685+
affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
686+
return normalizedLoopBounds;
687+
}
688+
674689
Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
675690
OpFoldResult lb, OpFoldResult ub,
676691
OpFoldResult step) {
692+
if (getType(lb) == rewriter.getIndexType()) {
693+
return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
694+
}
677695
// For non-index types, generate `arith` instructions
678696
// Check if the loop is already known to have a constant zero lower bound or
679697
// a constant one step.
@@ -714,9 +732,38 @@ Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
714732
return {newLowerBound, newUpperBound, newStep};
715733
}
716734

735+
/// Index-type variant of induction variable denormalization.
///
/// Computes `normalizedIv * origStep + origLb` as an affine expression via
/// `affine::makeComposedFoldedAffineApply`, obtains the result as a `Value`,
/// and replaces all uses of `normalizedIv` with it, except uses inside the
/// operation that defines the denormalized value.
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
736+
Location loc,
737+
Value normalizedIv,
738+
OpFoldResult origLb,
739+
OpFoldResult origStep) {
740+
AffineExpr d0, s0, s1;
741+
bindSymbols(rewriter.getContext(), s0, s1);
742+
bindDims(rewriter.getContext(), d0);
743+
// With the operand list `{normalizedIv, origLb, origStep}` passed below:
// d0 = normalizedIv, s0 = origLb, s1 = origStep.
AffineExpr e = d0 * s1 + s0;
744+
OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
745+
rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
746+
Value denormalizedIvVal =
747+
getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
748+
SmallPtrSet<Operation *, 1> preservedUses;
749+
// If an `affine.apply` operation is generated for denormalization, the use
750+
// of `normalizedIv` in that op must not be replaced. Such an op is not
751+
// generated when `origLb == 0` and `origStep == 1`.
752+
if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
753+
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
754+
preservedUses.insert(preservedUse);
755+
}
756+
}
757+
rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
758+
}
759+
717760
void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
718761
Value normalizedIv, OpFoldResult origLb,
719762
OpFoldResult origStep) {
763+
if (getType(origLb) == rewriter.getIndexType()) {
764+
return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
765+
origLb, origStep);
766+
}
720767
Value denormalizedIv;
721768
SmallPtrSet<Operation *, 2> preserve;
722769
bool isStepOne = isConstantIntValue(origStep, 1);
@@ -739,10 +786,29 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
739786
rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
740787
}
741788

789+
/// Helper function to multiply a sequence of index values using affine ops.
///
/// Starts from the index attribute 1 and, for each entry of `values`, applies
/// the affine expression `s0 * s1` to the running product and that entry via
/// `affine::makeComposedFoldedAffineApply`. Returns the final running product.
/// Asserts that `values` is not empty.
///
/// NOTE(review): the assert message below contains a typo ("unexecpted"); it
/// is a runtime string and is left unchanged here.
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
790+
ArrayRef<OpFoldResult> values) {
791+
assert(!values.empty() && "unexecpted empty array");
792+
AffineExpr s0, s1;
793+
bindSymbols(rewriter.getContext(), s0, s1);
794+
AffineExpr mul = s0 * s1;
795+
// Multiplicative identity used as the initial running product.
OpFoldResult products = rewriter.getIndexAttr(1);
796+
for (auto v : values) {
797+
// s0 = running product, s1 = current value.
products = affine::makeComposedFoldedAffineApply(
798+
rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
799+
}
800+
return products;
801+
}
802+
742803
/// Helper function to multiply a sequence of values.
743804
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
744805
ArrayRef<Value> values) {
745806
assert(!values.empty() && "unexpected empty list");
807+
if (getType(values.front()) == rewriter.getIndexType()) {
808+
SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
809+
OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
810+
return getValueOrCreateConstantIndexOp(rewriter, loc, product);
811+
}
746812
std::optional<Value> productOf;
747813
for (auto v : values) {
748814
auto vOne = getConstantIntValue(v);
@@ -757,7 +823,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
757823
if (!productOf) {
758824
productOf = rewriter
759825
.create<arith::ConstantOp>(
760-
loc, rewriter.getOneAttr(values.front().getType()))
826+
loc, rewriter.getOneAttr(getType(values.front())))
761827
.getResult();
762828
}
763829
return productOf.value();
@@ -774,6 +840,16 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
774840
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
775841
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
776842
Value linearizedIv, ArrayRef<Value> ubs) {
843+
844+
if (linearizedIv.getType() == rewriter.getIndexType()) {
845+
Operation *delinearizedOp =
846+
rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
847+
ubs);
848+
auto resultVals = llvm::map_to_vector(
849+
delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
850+
return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
851+
}
852+
777853
SmallVector<Value> delinearizedIvs(ubs.size());
778854
SmallPtrSet<Operation *, 2> preservedUsers;
779855

0 commit comments

Comments
 (0)