Skip to content

Commit cfbc153

Browse files
Add canonicalization/folders for affine.delinearize_index op.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent bc7fb7c commit cfbc153

File tree

5 files changed

+189
-7
lines changed

5 files changed

+189
-7
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
10961096
];
10971097

10981098
let hasVerifier = 1;
1099+
let hasCanonicalizer = 1;
10991100
}
11001101

11011102
#endif // AFFINE_OPS

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4534,6 +4534,133 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
45344534
return success();
45354535
}
45364536

4537+
namespace {
4538+
4539+
// Drops delinearization indices that correspond to unit-extent basis
4540+
struct DropUnitExtentBasis
4541+
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4542+
using OpRewritePattern::OpRewritePattern;
4543+
4544+
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4545+
PatternRewriter &rewriter) const override {
4546+
SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4547+
std::optional<Value> zero = std::nullopt;
4548+
Location loc = delinearizeOp->getLoc();
4549+
auto getZero = [&]() -> Value {
4550+
if (!zero)
4551+
zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4552+
return zero.value();
4553+
};
4554+
4555+
// Replace all indices corresponding to unit-extent basis with 0.
4556+
// Remaining basis can be used to get a new `affine.delinearize_index` op.
4557+
SmallVector<Value> newOperands;
4558+
for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) {
4559+
if (matchPattern(basis, m_One()))
4560+
replacements[index] = getZero();
4561+
else
4562+
newOperands.push_back(basis);
4563+
}
4564+
4565+
if (newOperands.size() == delinearizeOp.getBasis().size())
4566+
return failure();
4567+
4568+
if (!newOperands.empty()) {
4569+
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4570+
loc, delinearizeOp.getLinearIndex(), newOperands);
4571+
int newIndex = 0;
4572+
// Map back the new delinearized indices to the values they replace.
4573+
for (auto &replacement : replacements) {
4574+
if (replacement)
4575+
continue;
4576+
replacement = newDelinearizeOp->getResult(newIndex++);
4577+
}
4578+
}
4579+
4580+
rewriter.replaceOp(delinearizeOp, replacements);
4581+
return success();
4582+
}
4583+
};
4584+
4585+
/// Drop delinearization pattern related to loops in the following way
4586+
///
4587+
/// ```
4588+
/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
4589+
/// %0 = affine.delinearize_index %iv into (%ub) : index
4590+
/// <some_use>(%0)
4591+
/// }
4592+
/// ```
4593+
///
4594+
/// can be canonicalized to
4595+
///
4596+
/// ```
4597+
/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
4598+
/// <some_use>(%iv)
4599+
/// }
4600+
/// ```
4601+
struct DropDelinearizeOfSingleLoop
4602+
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4603+
using OpRewritePattern::OpRewritePattern;
4604+
4605+
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4606+
PatternRewriter &rewriter) const override {
4607+
auto basis = delinearizeOp.getBasis();
4608+
if (basis.size() != 1)
4609+
return failure();
4610+
4611+
// Check that the `linear_index` is an induction variable.
4612+
auto inductionVar = cast<BlockArgument>(delinearizeOp.getLinearIndex());
4613+
if (!inductionVar)
4614+
return failure();
4615+
4616+
// Check that the parent is a `LoopLikeOpInterface`.
4617+
auto loopLikeOp = cast<LoopLikeOpInterface>(
4618+
inductionVar.getParentRegion()->getParentOp());
4619+
if (!loopLikeOp)
4620+
return failure();
4621+
4622+
// Check that loop is unit-rank and that the `linear_index` is the induction
4623+
// variable.
4624+
auto inductionVars = loopLikeOp.getLoopInductionVars();
4625+
if (!inductionVars || inductionVars->size() != 1 ||
4626+
inductionVars->front() != inductionVar) {
4627+
return rewriter.notifyMatchFailure(
4628+
delinearizeOp, "`linear_index` is not loop induction variable");
4629+
}
4630+
4631+
// Check that the upper-bound is the basis.
4632+
auto upperBounds = loopLikeOp.getLoopUpperBounds();
4633+
if (!upperBounds || upperBounds->size() != 1 ||
4634+
upperBounds->front() != getAsOpFoldResult(basis.front())) {
4635+
return rewriter.notifyMatchFailure(delinearizeOp,
4636+
"`basis` is not upper bound");
4637+
}
4638+
4639+
// Check that the lower bound is zero.
4640+
auto lowerBounds = loopLikeOp.getLoopLowerBounds();
4641+
if (!lowerBounds || lowerBounds->size() != 1 ||
4642+
!isZeroIndex(lowerBounds->front())) {
4643+
return rewriter.notifyMatchFailure(delinearizeOp,
4644+
"loop lower bound is not zero");
4645+
}
4646+
4647+
// Check that the step is one.
4648+
auto steps = loopLikeOp.getLoopSteps();
4649+
if (!steps || steps->size() != 1 || !isConstantIntValue(steps->front(), 1))
4650+
return rewriter.notifyMatchFailure(delinearizeOp, "loop step is not one");
4651+
4652+
rewriter.replaceOp(delinearizeOp, inductionVar);
4653+
return success();
4654+
}
4655+
};
4656+
4657+
} // namespace
4658+
4659+
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4660+
RewritePatternSet &patterns, MLIRContext *context) {
4661+
patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
4662+
}
4663+
45374664
//===----------------------------------------------------------------------===//
45384665
// TableGen'd op method definitions
45394666
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
689689
Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
690690
OpFoldResult lb, OpFoldResult ub,
691691
OpFoldResult step) {
692-
if (getType(lb) == rewriter.getIndexType()) {
692+
if (getType(lb).isIndex()) {
693693
return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
694694
}
695695
// For non-index types, generate `arith` instructions
@@ -748,7 +748,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
748748
SmallPtrSet<Operation *, 1> preservedUses;
749749
// If an `affine.apply` operation is generated for denormalization, the use
750750
// of `origLb` in those ops must not be replaced. These arent not generated
751-
// when `orig_lb == 0` and `orig_step == 1`.
751+
// when `origLb == 0` and `origStep == 1`.
752752
if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
753753
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
754754
preservedUses.insert(preservedUse);
@@ -760,7 +760,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
760760
void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
761761
Value normalizedIv, OpFoldResult origLb,
762762
OpFoldResult origStep) {
763-
if (getType(origLb) == rewriter.getIndexType()) {
763+
if (getType(origLb).isIndex()) {
764764
return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
765765
origLb, origStep);
766766
}
@@ -804,7 +804,7 @@ static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
804804
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
805805
ArrayRef<Value> values) {
806806
assert(!values.empty() && "unexpected empty list");
807-
if (getType(values.front()) == rewriter.getIndexType()) {
807+
if (getType(values.front()).isIndex()) {
808808
SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
809809
OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
810810
return getValueOrCreateConstantIndexOp(rewriter, loc, product);
@@ -841,7 +841,7 @@ static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
841841
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
842842
Value linearizedIv, ArrayRef<Value> ubs) {
843843

844-
if (linearizedIv.getType() == rewriter.getIndexType()) {
844+
if (linearizedIv.getType().isIndex()) {
845845
Operation *delinearizedOp =
846846
rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
847847
ubs);

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,3 +1466,51 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
14661466
}
14671467
return
14681468
}
1469+
1470+
// -----
1471+
1472+
func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
1473+
(index, index, index, index, index, index) {
1474+
%c1 = arith.constant 1 : index
1475+
%0:6 = affine.delinearize_index %arg0 into (%c1, %arg1, %c1, %c1, %arg2, %c1)
1476+
: index, index, index, index, index, index
1477+
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : index, index, index, index, index, index
1478+
}
1479+
// CHECK-LABEL: func @drop_unit_basis_in_delinearize(
1480+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1481+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1482+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1483+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1484+
// CHECK-DAG: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], %[[ARG2]])
1485+
// CHECK: return %[[C0]], %[[DELINEARIZE]]#0, %[[C0]], %[[C0]], %[[DELINEARIZE]]#1, %[[C0]]
1486+
1487+
// -----
1488+
1489+
func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) {
1490+
%c1 = arith.constant 1 : index
1491+
%0:2 = affine.delinearize_index %arg0 into (%c1, %c1) : index, index
1492+
return %0#0, %0#1 : index, index
1493+
}
1494+
// CHECK-LABEL: func @drop_all_unit_bases(
1495+
// CHECK-SAME: %[[ARG0:.+]]: index)
1496+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1497+
// CHECK-NOT: affine.delinearize_index
1498+
// CHECK: return %[[C0]], %[[C0]]
1499+
1500+
// -----
1501+
1502+
func.func @drop_single_loop_delinearize(%arg0 : index, %arg1 : index) -> index {
1503+
%c0 = arith.constant 0 : index
1504+
%c1 = arith.constant 1 : index
1505+
%2 = scf.for %iv = %c0 to %arg1 step %c1 iter_args(%arg2 = %c0) -> index {
1506+
%0 = affine.delinearize_index %iv into (%arg1) : index
1507+
%1 = "some_use"(%arg2, %0) : (index, index) -> (index)
1508+
scf.yield %1 : index
1509+
}
1510+
return %2 : index
1511+
}
1512+
// CHECK-LABEL: func @drop_single_loop_delinearize(
1513+
// CHECK-SAME: %[[ARG0:.+]]: index)
1514+
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
1515+
// CHECK-NOT: affine.delinearize_index
1516+
// CHECK: "some_use"(%{{.+}}, %[[IV]])

mlir/test/Dialect/SCF/transform-op-coalesce.mlir

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,9 @@ module attributes {transform.with_named_sequence} {
313313
%0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
314314
%1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
315315
%2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
316+
transform.apply_patterns to %2 {
317+
transform.apply_patterns.canonicalization
318+
} : !transform.op<"scf.for">
316319
transform.yield
317320
}
318321
}
@@ -323,8 +326,8 @@ module attributes {transform.with_named_sequence} {
323326
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
324327
// CHECK: %[[UB:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%[[ARG1]], %[[ARG2]]]
325328
// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[UB]] step %[[C1]]
326-
// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IV]](%[[ARG1]], %[[ARG2]])
327-
// CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV2]], %[[C0]], %[[IV1]])
329+
// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IV]] into (%[[ARG1]], %[[ARG2]])
330+
// CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[DELINEARIZE]]#0, %[[C0]], %[[DELINEARIZE]]#1)
328331

329332
// -----
330333

@@ -350,6 +353,9 @@ module attributes {transform.with_named_sequence} {
350353
%0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
351354
%1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
352355
%2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
356+
transform.apply_patterns to %2 {
357+
transform.apply_patterns.canonicalization
358+
} : !transform.op<"scf.for">
353359
transform.yield
354360
}
355361
}

0 commit comments

Comments
 (0)