-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][affine] Add folders for delinearize_index and linearize_index #115766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-affine Author: Krzysztof Drewniak (krzysz00) Changes: This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully constant basis and constant argument(s). This commit also adds a canonicalization pattern to linearize_index that causes it to drop leading-zero inputs. Full diff: https://github.com/llvm/llvm-project/pull/115766.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..753b8951fb084b 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1110,6 +1110,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
}];
let hasVerifier = 1;
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
@@ -1179,6 +1180,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
}];
let hasVerifier = 1;
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..37316632a6a06f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4556,6 +4556,26 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
return success();
}
+LogicalResult
+AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &result) {
+ if (adaptor.getLinearIndex() == nullptr)
+ return failure();
+
+ if (!adaptor.getDynamicBasis().empty())
+ return failure();
+
+ int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
+ Type attrType = getLinearIndex().getType();
+ for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
+ result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
+ highPart = llvm::divideFloorSigned(highPart, modulus);
+ }
+ result.push_back(IntegerAttr::get(attrType, highPart));
+ std::reverse(result.begin(), result.end());
+ return success();
+}
+
namespace {
// Drops delinearization indices that correspond to unit-extent basis
@@ -4683,6 +4703,26 @@ LogicalResult AffineLinearizeIndexOp::verify() {
return success();
}
+OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+ if (llvm::any_of(adaptor.getMultiIndex(),
+ [](Attribute a) { return a == nullptr; }))
+ return nullptr;
+
+ if (!adaptor.getDynamicBasis().empty())
+ return nullptr;
+
+ int64_t result = 0;
+ int64_t stride = 1;
+ for (auto [indexAttr, length] :
+ llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
+ llvm::reverse(getStaticBasis()))) {
+ result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
+ stride = stride * length;
+ }
+
+ return IntegerAttr::get(getResult().getType(), result);
+}
+
namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4751,12 +4791,39 @@ struct DropLinearizeOneBasisElement final
return success();
}
};
+
+/// Strip leading zero from affine.linearize_index.
+///
+/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
+/// to `affine.linearize_index [...a] by (...b)` in all cases.
+struct DropLinearizeLeadingZero final
+ : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+ PatternRewriter &rewriter) const override {
+ Value leadingIdx = op.getMultiIndex().front();
+ if (!matchPattern(leadingIdx, m_Zero()))
+ return failure();
+
+ if (op.getMultiIndex().size() == 1) {
+ rewriter.replaceOp(op, leadingIdx);
+ return success();
+ }
+
+ SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+ rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+ op, op.getMultiIndex().drop_front(),
+ ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+ return success();
+ }
+};
} // namespace
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
- DropLinearizeOneBasisElement>(context);
+ DropLinearizeOneBasisElement, DropLinearizeLeadingZero>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..8ed6ab5c965ca0 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1469,6 +1469,45 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
// -----
+// CHECK-LABEL: @delinearize_fold_constant
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C1]], %[[C1]], %[[C2]]
+func.func @delinearize_fold_constant() -> (index, index, index) {
+ %c22 = arith.constant 22 : index
+ %0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_fold_negative_constant
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant() -> (index, index, index) {
+ %c_22 = arith.constant -22 : index
+ %0:3 = affine.delinearize_index %c_22 into (2, 3, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2
+func.func @delinearize_dont_fold_constant_dynamic_basis(%arg0: index) -> (index, index, index) {
+ %c22 = arith.constant 22 : index
+ %0:3 = affine.delinearize_index %c22 into (2, %arg0, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
(index, index, index, index, index, index) {
%c1 = arith.constant 1 : index
@@ -1535,6 +1574,33 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
// -----
+// CHECK-LABEL: @linearize_fold_constants
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants() -> index {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+
+ %ret = affine.linearize_index [%c1, %c1, %c2] by (2, 3, 5) : index
+ return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_dont_fold_dynamic_basis
+// CHECK: %[[RET:.+]] = affine.linearize_index
+// CHECK: return %[[RET]]
+func.func @linearize_dont_fold_dynamic_basis(%arg0: index) -> index {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+
+ %ret = affine.linearize_index [%c1, %c1, %c2] by (2, %arg0, 5) : index
+ return %ret : index
+}
+
+// -----
+
// CHECK-LABEL: @linearize_unit_basis_disjoint
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
@@ -1577,3 +1643,17 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
%ret = affine.linearize_index [%arg0] by (%arg1) : index
return %ret : index
}
+
+// -----
+
+// CHECK-LABEL: func @affine_leading_zero(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+// CHECK: return %[[RET]]
+func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
+ %c0 = arith.constant 0 : index
+ %ret = affine.linearize_index [%c0, %arg0, %arg1] by (2, 3, 5) : index
+ return %ret : index
+}
+
|
@llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) Changes: This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully constant basis and constant argument(s). This commit also adds a canonicalization pattern to linearize_index that causes it to drop leading-zero inputs. Full diff: https://github.com/llvm/llvm-project/pull/115766.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..753b8951fb084b 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1110,6 +1110,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
}];
let hasVerifier = 1;
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
@@ -1179,6 +1180,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
}];
let hasVerifier = 1;
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..37316632a6a06f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4556,6 +4556,26 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
return success();
}
+LogicalResult
+AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &result) {
+ if (adaptor.getLinearIndex() == nullptr)
+ return failure();
+
+ if (!adaptor.getDynamicBasis().empty())
+ return failure();
+
+ int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
+ Type attrType = getLinearIndex().getType();
+ for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
+ result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
+ highPart = llvm::divideFloorSigned(highPart, modulus);
+ }
+ result.push_back(IntegerAttr::get(attrType, highPart));
+ std::reverse(result.begin(), result.end());
+ return success();
+}
+
namespace {
// Drops delinearization indices that correspond to unit-extent basis
@@ -4683,6 +4703,26 @@ LogicalResult AffineLinearizeIndexOp::verify() {
return success();
}
+OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+ if (llvm::any_of(adaptor.getMultiIndex(),
+ [](Attribute a) { return a == nullptr; }))
+ return nullptr;
+
+ if (!adaptor.getDynamicBasis().empty())
+ return nullptr;
+
+ int64_t result = 0;
+ int64_t stride = 1;
+ for (auto [indexAttr, length] :
+ llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
+ llvm::reverse(getStaticBasis()))) {
+ result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
+ stride = stride * length;
+ }
+
+ return IntegerAttr::get(getResult().getType(), result);
+}
+
namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4751,12 +4791,39 @@ struct DropLinearizeOneBasisElement final
return success();
}
};
+
+/// Strip leading zero from affine.linearize_index.
+///
+/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
+/// to `affine.linearize_index [...a] by (...b)` in all cases.
+struct DropLinearizeLeadingZero final
+ : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+ PatternRewriter &rewriter) const override {
+ Value leadingIdx = op.getMultiIndex().front();
+ if (!matchPattern(leadingIdx, m_Zero()))
+ return failure();
+
+ if (op.getMultiIndex().size() == 1) {
+ rewriter.replaceOp(op, leadingIdx);
+ return success();
+ }
+
+ SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+ rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+ op, op.getMultiIndex().drop_front(),
+ ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+ return success();
+ }
+};
} // namespace
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
- DropLinearizeOneBasisElement>(context);
+ DropLinearizeOneBasisElement, DropLinearizeLeadingZero>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..8ed6ab5c965ca0 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1469,6 +1469,45 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
// -----
+// CHECK-LABEL: @delinearize_fold_constant
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C1]], %[[C1]], %[[C2]]
+func.func @delinearize_fold_constant() -> (index, index, index) {
+ %c22 = arith.constant 22 : index
+ %0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_fold_negative_constant
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant() -> (index, index, index) {
+ %c_22 = arith.constant -22 : index
+ %0:3 = affine.delinearize_index %c_22 into (2, 3, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2
+func.func @delinearize_dont_fold_constant_dynamic_basis(%arg0: index) -> (index, index, index) {
+ %c22 = arith.constant 22 : index
+ %0:3 = affine.delinearize_index %c22 into (2, %arg0, 5) : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
(index, index, index, index, index, index) {
%c1 = arith.constant 1 : index
@@ -1535,6 +1574,33 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
// -----
+// CHECK-LABEL: @linearize_fold_constants
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants() -> index {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+
+ %ret = affine.linearize_index [%c1, %c1, %c2] by (2, 3, 5) : index
+ return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_dont_fold_dynamic_basis
+// CHECK: %[[RET:.+]] = affine.linearize_index
+// CHECK: return %[[RET]]
+func.func @linearize_dont_fold_dynamic_basis(%arg0: index) -> index {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+
+ %ret = affine.linearize_index [%c1, %c1, %c2] by (2, %arg0, 5) : index
+ return %ret : index
+}
+
+// -----
+
// CHECK-LABEL: @linearize_unit_basis_disjoint
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
@@ -1577,3 +1643,17 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
%ret = affine.linearize_index [%arg0] by (%arg1) : index
return %ret : index
}
+
+// -----
+
+// CHECK-LABEL: func @affine_leading_zero(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+// CHECK: return %[[RET]]
+func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
+ %c0 = arith.constant 0 : index
+ %ret = affine.linearize_index [%c0, %arg0, %arg1] by (2, 3, 5) : index
+ return %ret : index
+}
+
|
This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully constant basis and constant argument(s). This commit also adds a canonicalization pattern to linearize_index that causes it to drop leading-zero inputs.
e3448a1
to
1a88bf1
Compare
|
||
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op, | ||
PatternRewriter &rewriter) const override { | ||
Value leadingIdx = op.getMultiIndex().front(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, you should be able to drop any zero-value basis (and the corresponding index), right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is
%0 = affine.linearize_index [%c0, %x, %y] by (A, B, C)
going to
%0 = affine.linearize_index [%x, %y] by (B, C)
We can't do
%0 = affine.linearize_index [%x, %c0, %y] by (A, B, C)
to
%0 = affine.linearize_index [%x, %y] by (A, C)
since that has an %x * C term and not a %x * BC term.
Whether or not canonicalizing that to
%0 = affine.linearize_index [%x, %y] by (A, BC)
is profitable is something I'm not sure about yet.
(We definitely can't drop trailing zeroes, since they'd be used to express %y * C-type operations.)
This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully constant basis and constant argument(s).
This commit also adds a canonicalization pattern to linearize_index that causes it to drop leading-zero inputs.