Skip to content

Commit 49f90e7

Browse files
authored
[mlir][affine] Cancel exactly-matching delinearize/linearize pairs (#115758)
If we linearize values (with an assertion tha they are disjoint) and then delinearize that linear index with th exact same basis, we know that these operations are exact inverses of each other and can be replaced with the original inputs to the linearization. Similarly, if we take a linear index, delinearize it with some bases, and then re-linearize it with that same basis (noting that the outputs of the delinearization are guaranteed to by `disjoint`, even if this is not asserted on the linearize_index operation), the re-linearization is the inverse of the delinearization, so those two operations can also be canceled out. This commit adds canonicalization patterns for these simple cancelations.
1 parent fe83a72 commit 49f90e7

File tree

3 files changed

+176
-8
lines changed

3 files changed

+176
-8
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
10901090
let results = (outs Variadic<Index>:$multi_index);
10911091

10921092
let assemblyFormat = [{
1093-
$linear_index `into` ` `
1093+
$linear_index `into`
10941094
custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
10951095
attr-dict `:` type($multi_index)
10961096
}];

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4586,7 +4586,8 @@ struct DropUnitExtentBasis
45864586
}
45874587

45884588
if (newOperands.size() == delinearizeOp.getStaticBasis().size())
4589-
return failure();
4589+
return rewriter.notifyMatchFailure(delinearizeOp,
4590+
"no unit basis elements");
45904591

45914592
if (!newOperands.empty()) {
45924593
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
@@ -4619,17 +4620,48 @@ struct DropDelinearizeOneBasisElement
46194620
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
46204621
PatternRewriter &rewriter) const override {
46214622
if (delinearizeOp.getStaticBasis().size() != 1)
4622-
return failure();
4623+
return rewriter.notifyMatchFailure(delinearizeOp,
4624+
"doesn't have a length-1 basis");
46234625
rewriter.replaceOp(delinearizeOp, delinearizeOp.getLinearIndex());
46244626
return success();
46254627
}
46264628
};
46274629

4630+
/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
4631+
/// disjoint` and the two operations have the same basis, replace the
4632+
/// delinearizeation results with the inputs of the `affine.linearize_index`
4633+
/// since they are exact inverses of each other.
4634+
///
4635+
/// The `disjoint` flag is needed on the `affine.linearize_index` because
4636+
/// otherwise, there is no guarantee that the inputs to the linearization are
4637+
/// in-bounds the way the outputs of the delinearization would be.
4638+
struct CancelDelinearizeOfLinearizeDisjointExact
4639+
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4640+
using OpRewritePattern::OpRewritePattern;
4641+
4642+
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4643+
PatternRewriter &rewriter) const override {
4644+
auto linearizeOp = delinearizeOp.getLinearIndex()
4645+
.getDefiningOp<affine::AffineLinearizeIndexOp>();
4646+
if (!linearizeOp)
4647+
return rewriter.notifyMatchFailure(delinearizeOp,
4648+
"index doesn't come from linearize");
4649+
4650+
if (!linearizeOp.getDisjoint() ||
4651+
linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
4652+
return rewriter.notifyMatchFailure(
4653+
linearizeOp, "not disjoint or basis doesn't match delinearize");
4654+
4655+
rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4656+
return success();
4657+
}
4658+
};
46284659
} // namespace
46294660

46304661
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
46314662
RewritePatternSet &patterns, MLIRContext *context) {
4632-
patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
4663+
patterns.insert<CancelDelinearizeOfLinearizeDisjointExact,
4664+
DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
46334665
}
46344666

46354667
//===----------------------------------------------------------------------===//
@@ -4723,7 +4755,8 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
47234755
}
47244756
}
47254757
if (newIndices.size() == numIndices)
4726-
return failure();
4758+
return rewriter.notifyMatchFailure(op,
4759+
"no unit basis entries to replace");
47274760

47284761
if (newIndices.size() == 0) {
47294762
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
@@ -4746,17 +4779,53 @@ struct DropLinearizeOneBasisElement final
47464779
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
47474780
PatternRewriter &rewriter) const override {
47484781
if (op.getStaticBasis().size() != 1 || op.getMultiIndex().size() != 1)
4749-
return failure();
4782+
return rewriter.notifyMatchFailure(op, "doesn't have a a length-1 basis");
47504783
rewriter.replaceOp(op, op.getMultiIndex().front());
47514784
return success();
47524785
}
47534786
};
4787+
4788+
/// Cancel out linearize_index(delinearize_index(x, B), B).
4789+
///
4790+
/// That is, rewrite
4791+
/// ```
4792+
/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
4793+
/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
4794+
/// %bN)
4795+
/// ```
4796+
/// to replacing `%y` with `%x`.
4797+
struct CancelLinearizeOfDelinearizeExact final
4798+
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
4799+
using OpRewritePattern::OpRewritePattern;
4800+
4801+
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
4802+
PatternRewriter &rewriter) const override {
4803+
auto delinearizeOp = linearizeOp.getMultiIndex()
4804+
.front()
4805+
.getDefiningOp<affine::AffineDelinearizeIndexOp>();
4806+
if (!delinearizeOp)
4807+
return rewriter.notifyMatchFailure(
4808+
linearizeOp, "last entry doesn't come from a delinearize");
4809+
4810+
if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
4811+
return rewriter.notifyMatchFailure(
4812+
linearizeOp,
4813+
"basis of linearize and delinearize don't match exactly");
4814+
4815+
if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
4816+
return rewriter.notifyMatchFailure(
4817+
linearizeOp, "not all indices come from delinearize");
4818+
4819+
rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
4820+
return success();
4821+
}
4822+
};
47544823
} // namespace
47554824

47564825
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
47574826
RewritePatternSet &patterns, MLIRContext *context) {
4758-
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
4759-
DropLinearizeOneBasisElement>(context);
4827+
patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeOneBasisElement,
4828+
DropLinearizeUnitComponentsIfDisjointOrZero>(context);
47604829
}
47614830

47624831
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,6 +1535,60 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
15351535

15361536
// -----
15371537

1538+
// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_exact(
1539+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1540+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1541+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
1542+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
1543+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
1544+
// CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG2]]
1545+
func.func @cancel_delinearize_linearize_disjoint_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
1546+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
1547+
%1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
1548+
: index, index, index
1549+
return %1#0, %1#1, %1#2 : index, index, index
1550+
}
1551+
1552+
// -----
1553+
1554+
// Without `disjoint`, the cancelation isn't guaranteed to be the identity.
1555+
// CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
1556+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1557+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1558+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
1559+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
1560+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
1561+
// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
1562+
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 4, %[[ARG4]])
1563+
// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
1564+
func.func @no_cancel_delinearize_linearize_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
1565+
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
1566+
%1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
1567+
: index, index, index
1568+
return %1#0, %1#1, %1#2 : index, index, index
1569+
}
1570+
1571+
// -----
1572+
1573+
// These don't cancel because the delinearize and linearize have a different basis.
1574+
// CHECK-LABEL: func @no_cancel_delinearize_linearize_different_basis(
1575+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1576+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1577+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
1578+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
1579+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
1580+
// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
1581+
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 8, %[[ARG4]])
1582+
// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
1583+
func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
1584+
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
1585+
%1:3 = affine.delinearize_index %0 into (%arg3, 8, %arg4)
1586+
: index, index, index
1587+
return %1#0, %1#1, %1#2 : index, index, index
1588+
}
1589+
1590+
// -----
1591+
15381592
// CHECK-LABEL: @linearize_unit_basis_disjoint
15391593
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
15401594
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
@@ -1577,3 +1631,48 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
15771631
%ret = affine.linearize_index [%arg0] by (%arg1) : index
15781632
return %ret : index
15791633
}
1634+
1635+
// -----
1636+
1637+
// CHECK-LABEL: func @cancel_linearize_denearize_exact(
1638+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1639+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1640+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1641+
// CHECK: return %[[ARG0]]
1642+
func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
1643+
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
1644+
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
1645+
return %1 : index
1646+
}
1647+
1648+
// -----
1649+
1650+
// Don't cancel because the values from the delinearize aren't used in order
1651+
// CHECK-LABEL: func @no_cancel_linearize_denearize_permuted(
1652+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1653+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1654+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1655+
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
1656+
// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], 4, %[[ARG2]])
1657+
// CHECK: return %[[LIN]]
1658+
func.func @no_cancel_linearize_denearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
1659+
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
1660+
%1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, 4, %arg2) : index
1661+
return %1 : index
1662+
}
1663+
1664+
// -----
1665+
1666+
// Won't cancel because the linearize and delinearize are using a different basis
1667+
// CHECK-LABEL: func @no_cancel_linearize_denearize_different_basis(
1668+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1669+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1670+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1671+
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
1672+
// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] by (%[[ARG1]], 8, %[[ARG2]])
1673+
// CHECK: return %[[LIN]]
1674+
func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
1675+
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
1676+
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
1677+
return %1 : index
1678+
}

0 commit comments

Comments
 (0)