Skip to content

Commit ece4e12

Browse files
[mlir][Affine] Split off delinearize parts that depend on last component (#117015)
If we have %0 = affine.linearize_index disjoint [%a, %b] by (A, B) %1:3 = affine.delinearize_index %0 into (A, B1, B2) where B = B1 * B2 (or some mor complex product), we can simplify this to %0 = affine.linearize_index disjoint [%a] by (A) %1a:1 = affine.delinearize_index %0 into (A) %1b:2 = affine.delinearize_index %b into (B1, B2) This, and more complex cases, prevent us from adding terms together only to divide them away from each other. --------- Co-authored-by: Abhishek Varma <[email protected]>
1 parent 935da49 commit ece4e12

File tree

2 files changed

+154
-2
lines changed

2 files changed

+154
-2
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4729,12 +4729,98 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
47294729
return success();
47304730
}
47314731
};
4732+
4733+
/// If the input to a delinearization is a disjoint linearization, and the
4734+
/// last k > 1 components of the delinearization basis multiply to the
4735+
/// last component of the linearization basis, break the linearization and
4736+
/// delinearization into two parts, peeling off the last input to linearization.
4737+
///
4738+
/// For example:
4739+
/// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
4740+
/// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
4741+
/// becomes
4742+
/// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
4743+
/// %1:2 = affine.delinearize_index %0 by (2, 3) : index
4744+
/// %2:2 = affine.delinearize_index %x by (8, 4) : index
4745+
/// where the original %1:4 is replaced by %1:2 ++ %2:2
4746+
struct SplitDelinearizeSpanningLastLinearizeArg final
4747+
: OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4748+
using OpRewritePattern::OpRewritePattern;
4749+
4750+
LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4751+
PatternRewriter &rewriter) const override {
4752+
auto linearizeOp = delinearizeOp.getLinearIndex()
4753+
.getDefiningOp<affine::AffineLinearizeIndexOp>();
4754+
if (!linearizeOp)
4755+
return rewriter.notifyMatchFailure(delinearizeOp,
4756+
"index doesn't come from linearize");
4757+
4758+
if (!linearizeOp.getDisjoint())
4759+
return rewriter.notifyMatchFailure(linearizeOp,
4760+
"linearize isn't disjoint");
4761+
4762+
int64_t target = linearizeOp.getStaticBasis().back();
4763+
if (ShapedType::isDynamic(target))
4764+
return rewriter.notifyMatchFailure(
4765+
linearizeOp, "linearize ends with dynamic basis value");
4766+
4767+
int64_t sizeToSplit = 1;
4768+
size_t elemsToSplit = 0;
4769+
ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
4770+
for (int64_t basisElem : llvm::reverse(basis)) {
4771+
if (ShapedType::isDynamic(basisElem))
4772+
return rewriter.notifyMatchFailure(
4773+
delinearizeOp, "dynamic basis element while scanning for split");
4774+
sizeToSplit *= basisElem;
4775+
elemsToSplit += 1;
4776+
4777+
if (sizeToSplit > target)
4778+
return rewriter.notifyMatchFailure(delinearizeOp,
4779+
"overshot last argument size");
4780+
if (sizeToSplit == target)
4781+
break;
4782+
}
4783+
4784+
if (sizeToSplit < target)
4785+
return rewriter.notifyMatchFailure(
4786+
delinearizeOp, "product of known basis elements doesn't exceed last "
4787+
"linearize argument");
4788+
4789+
if (elemsToSplit < 2)
4790+
return rewriter.notifyMatchFailure(
4791+
delinearizeOp,
4792+
"need at least two elements to form the basis product");
4793+
4794+
Value linearizeWithoutBack =
4795+
rewriter.create<affine::AffineLinearizeIndexOp>(
4796+
linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4797+
linearizeOp.getDynamicBasis(),
4798+
linearizeOp.getStaticBasis().drop_back(),
4799+
linearizeOp.getDisjoint());
4800+
auto delinearizeWithoutSplitPart =
4801+
rewriter.create<affine::AffineDelinearizeIndexOp>(
4802+
delinearizeOp.getLoc(), linearizeWithoutBack,
4803+
delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4804+
delinearizeOp.hasOuterBound());
4805+
auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
4806+
delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4807+
basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
4808+
SmallVector<Value> results = llvm::to_vector(
4809+
llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4810+
delinearizeBack.getResults()));
4811+
rewriter.replaceOp(delinearizeOp, results);
4812+
4813+
return success();
4814+
}
4815+
};
47324816
} // namespace
47334817

47344818
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
47354819
RewritePatternSet &patterns, MLIRContext *context) {
4736-
patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4737-
DropUnitExtentBasis>(context);
4820+
patterns
4821+
.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4822+
DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4823+
context);
47384824
}
47394825

47404826
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,6 +1795,72 @@ func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1:
17951795

17961796
// -----
17971797

1798+
// CHECK-LABEL: func @split_delinearize_spanning_final_part
1799+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1800+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1801+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1802+
// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 4)
1803+
// CHECK: %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2)
1804+
// CHECK: %[[DELIN2:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
1805+
// CHECK: return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1
1806+
func.func @split_delinearize_spanning_final_part(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
1807+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
1808+
%1:4 = affine.delinearize_index %0 into (2, 8, 8)
1809+
: index, index, index, index
1810+
return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
1811+
}
1812+
1813+
// -----
1814+
1815+
// CHECK-LABEL: func @split_delinearize_spanning_final_part_and_cancel
1816+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1817+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1818+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1819+
// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
1820+
// CHECK: return %[[ARG0]], %[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1
1821+
func.func @split_delinearize_spanning_final_part_and_cancel(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
1822+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
1823+
%1:4 = affine.delinearize_index %0 into (2, 4, 8, 8)
1824+
: index, index, index, index
1825+
return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
1826+
}
1827+
1828+
// -----
1829+
1830+
// The delinearize basis doesn't match the last basis element before
1831+
// overshooting it, don't simplify.
1832+
// CHECK-LABEL: func @dont_split_delinearize_overshooting_target
1833+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1834+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1835+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
1836+
// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 4, 64)
1837+
// CHECK: %[[DELIN:.+]]:4 = affine.delinearize_index %[[LIN]] into (2, 16, 8)
1838+
// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2, %[[DELIN]]#3
1839+
func.func @dont_split_delinearize_overshooting_target(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
1840+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
1841+
%1:4 = affine.delinearize_index %0 into (2, 16, 8)
1842+
: index, index, index, index
1843+
return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
1844+
}
1845+
1846+
// -----
1847+
1848+
// The delinearize basis doesn't fully multiply to the final basis element.
1849+
// CHECK-LABEL: func @dont_split_delinearize_undershooting_target
1850+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1851+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
1852+
// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 64)
1853+
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (4, 8)
1854+
// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1
1855+
func.func @dont_split_delinearize_undershooting_target(%arg0: index, %arg1: index) -> (index, index, index) {
1856+
%0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 64) : index
1857+
%1:3 = affine.delinearize_index %0 into (4, 8)
1858+
: index, index, index
1859+
return %1#0, %1#1, %1#2 : index, index, index
1860+
}
1861+
1862+
// -----
1863+
17981864
// CHECK-LABEL: @linearize_unit_basis_disjoint
17991865
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
18001866
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index

0 commit comments

Comments
 (0)