Skip to content

Commit cb9481d

Browse files
authored
[mlir][affine] Add folders for delinearize_index and linearize_index (llvm#115766)
This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully constant basis and constant argument(s). This commit also adds a canonicalization pattern to linearize_index that causes it to drop leading-zero inputs.
1 parent 71ae021 commit cb9481d

File tree

3 files changed

+151
-1
lines changed

3 files changed

+151
-1
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
11101110
}];
11111111

11121112
let hasVerifier = 1;
1113+
let hasFolder = 1;
11131114
let hasCanonicalizer = 1;
11141115
}
11151116

@@ -1179,6 +1180,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
11791180
}];
11801181

11811182
let hasVerifier = 1;
1183+
let hasFolder = 1;
11821184
let hasCanonicalizer = 1;
11831185
}
11841186

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4556,6 +4556,26 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
45564556
return success();
45574557
}
45584558

4559+
LogicalResult
4560+
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4561+
SmallVectorImpl<OpFoldResult> &result) {
4562+
if (adaptor.getLinearIndex() == nullptr)
4563+
return failure();
4564+
4565+
if (!adaptor.getDynamicBasis().empty())
4566+
return failure();
4567+
4568+
int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4569+
Type attrType = getLinearIndex().getType();
4570+
for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
4571+
result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4572+
highPart = llvm::divideFloorSigned(highPart, modulus);
4573+
}
4574+
result.push_back(IntegerAttr::get(attrType, highPart));
4575+
std::reverse(result.begin(), result.end());
4576+
return success();
4577+
}
4578+
45594579
namespace {
45604580

45614581
// Drops delinearization indices that correspond to unit-extent basis
@@ -4715,6 +4735,26 @@ LogicalResult AffineLinearizeIndexOp::verify() {
47154735
return success();
47164736
}
47174737

4738+
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
4739+
if (llvm::any_of(adaptor.getMultiIndex(),
4740+
[](Attribute a) { return a == nullptr; }))
4741+
return nullptr;
4742+
4743+
if (!adaptor.getDynamicBasis().empty())
4744+
return nullptr;
4745+
4746+
int64_t result = 0;
4747+
int64_t stride = 1;
4748+
for (auto [indexAttr, length] :
4749+
llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
4750+
llvm::reverse(getStaticBasis()))) {
4751+
result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
4752+
stride = stride * length;
4753+
}
4754+
4755+
return IntegerAttr::get(getResult().getType(), result);
4756+
}
4757+
47184758
namespace {
47194759
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
47204760
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4820,11 +4860,39 @@ struct CancelLinearizeOfDelinearizeExact final
48204860
return success();
48214861
}
48224862
};
4863+
4864+
/// Strip leading zero from affine.linearize_index.
4865+
///
4866+
/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
4867+
/// to `affine.linearize_index [...a] by (...b)` in all cases.
4868+
struct DropLinearizeLeadingZero final
4869+
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
4870+
using OpRewritePattern::OpRewritePattern;
4871+
4872+
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
4873+
PatternRewriter &rewriter) const override {
4874+
Value leadingIdx = op.getMultiIndex().front();
4875+
if (!matchPattern(leadingIdx, m_Zero()))
4876+
return failure();
4877+
4878+
if (op.getMultiIndex().size() == 1) {
4879+
rewriter.replaceOp(op, leadingIdx);
4880+
return success();
4881+
}
4882+
4883+
SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
4884+
rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
4885+
op, op.getMultiIndex().drop_front(),
4886+
ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
4887+
return success();
4888+
}
4889+
};
48234890
} // namespace
48244891

48254892
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
48264893
RewritePatternSet &patterns, MLIRContext *context) {
4827-
patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeOneBasisElement,
4894+
patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
4895+
DropLinearizeOneBasisElement,
48284896
DropLinearizeUnitComponentsIfDisjointOrZero>(context);
48294897
}
48304898

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,6 +1469,45 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
14691469

14701470
// -----
14711471

1472+
// CHECK-LABEL: @delinearize_fold_constant
1473+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1474+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
1475+
// CHECK-NOT: affine.delinearize_index
1476+
// CHECK: return %[[C1]], %[[C1]], %[[C2]]
1477+
func.func @delinearize_fold_constant() -> (index, index, index) {
1478+
%c22 = arith.constant 22 : index
1479+
%0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
1480+
return %0#0, %0#1, %0#2 : index, index, index
1481+
}
1482+
1483+
// -----
1484+
1485+
// CHECK-LABEL: @delinearize_fold_negative_constant
1486+
// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
1487+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1488+
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
1489+
// CHECK-NOT: affine.delinearize_index
1490+
// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
1491+
func.func @delinearize_fold_negative_constant() -> (index, index, index) {
1492+
%c_22 = arith.constant -22 : index
1493+
%0:3 = affine.delinearize_index %c_22 into (2, 3, 5) : index, index, index
1494+
return %0#0, %0#1, %0#2 : index, index, index
1495+
}
1496+
1497+
// -----
1498+
1499+
// CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
1500+
// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
1501+
// CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
1502+
// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2
1503+
func.func @delinearize_dont_fold_constant_dynamic_basis(%arg0: index) -> (index, index, index) {
1504+
%c22 = arith.constant 22 : index
1505+
%0:3 = affine.delinearize_index %c22 into (2, %arg0, 5) : index, index, index
1506+
return %0#0, %0#1, %0#2 : index, index, index
1507+
}
1508+
1509+
// -----
1510+
14721511
func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
14731512
(index, index, index, index, index, index) {
14741513
%c1 = arith.constant 1 : index
@@ -1535,6 +1574,33 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
15351574

15361575
// -----
15371576

1577+
// CHECK-LABEL: @linearize_fold_constants
1578+
// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
1579+
// CHECK-NOT: affine.linearize
1580+
// CHECK: return %[[C22]]
1581+
func.func @linearize_fold_constants() -> index {
1582+
%c2 = arith.constant 2 : index
1583+
%c1 = arith.constant 1 : index
1584+
1585+
%ret = affine.linearize_index [%c1, %c1, %c2] by (2, 3, 5) : index
1586+
return %ret : index
1587+
}
1588+
1589+
// -----
1590+
1591+
// CHECK-LABEL: @linearize_dont_fold_dynamic_basis
1592+
// CHECK: %[[RET:.+]] = affine.linearize_index
1593+
// CHECK: return %[[RET]]
1594+
func.func @linearize_dont_fold_dynamic_basis(%arg0: index) -> index {
1595+
%c2 = arith.constant 2 : index
1596+
%c1 = arith.constant 1 : index
1597+
1598+
%ret = affine.linearize_index [%c1, %c1, %c2] by (2, %arg0, 5) : index
1599+
return %ret : index
1600+
}
1601+
1602+
// -----
1603+
15381604
// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_exact(
15391605
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
15401606
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
@@ -1676,3 +1742,17 @@ func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: in
16761742
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
16771743
return %1 : index
16781744
}
1745+
1746+
// -----
1747+
1748+
// CHECK-LABEL: func @affine_leading_zero(
1749+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1750+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
1751+
// CHECK: %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
1752+
// CHECK: return %[[RET]]
1753+
func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
1754+
%c0 = arith.constant 0 : index
1755+
%ret = affine.linearize_index [%c0, %arg0, %arg1] by (2, 3, 5) : index
1756+
return %ret : index
1757+
}
1758+

0 commit comments

Comments
 (0)