Skip to content

[mlir][affine] Add folders for delinearize_index and linearize_index #115766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 13, 2024

Conversation

krzysz00
Copy link
Contributor

This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully constant basis and constant argument(s).

This commit also adds a canonicalization pattern to linearize_index that causes it to drop leading-zero inputs.

@llvmbot
Copy link
Member

llvmbot commented Nov 11, 2024

@llvm/pr-subscribers-mlir-affine

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully constant basis and constant argument(s).

This commit also adds a canonicalization pattern to linearize_index that causes it to drop leading-zero inputs.


Full diff: https://github.com/llvm/llvm-project/pull/115766.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+2)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+68-1)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+80)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..753b8951fb084b 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1110,6 +1110,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   }];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -1179,6 +1180,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
   }];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..37316632a6a06f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4556,6 +4556,26 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
+LogicalResult
+AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
+                               SmallVectorImpl<OpFoldResult> &result) {
+  if (adaptor.getLinearIndex() == nullptr)
+    return failure();
+
+  if (!adaptor.getDynamicBasis().empty())
+    return failure();
+
+  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
+  Type attrType = getLinearIndex().getType();
+  for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
+    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
+    highPart = llvm::divideFloorSigned(highPart, modulus);
+  }
+  result.push_back(IntegerAttr::get(attrType, highPart));
+  std::reverse(result.begin(), result.end());
+  return success();
+}
+
 namespace {
 
 // Drops delinearization indices that correspond to unit-extent basis
@@ -4683,6 +4703,26 @@ LogicalResult AffineLinearizeIndexOp::verify() {
   return success();
 }
 
+OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  if (llvm::any_of(adaptor.getMultiIndex(),
+                   [](Attribute a) { return a == nullptr; }))
+    return nullptr;
+
+  if (!adaptor.getDynamicBasis().empty())
+    return nullptr;
+
+  int64_t result = 0;
+  int64_t stride = 1;
+  for (auto [indexAttr, length] :
+       llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
+                       llvm::reverse(getStaticBasis()))) {
+    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
+    stride = stride * length;
+  }
+
+  return IntegerAttr::get(getResult().getType(), result);
+}
+
 namespace {
 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4751,12 +4791,39 @@ struct DropLinearizeOneBasisElement final
     return success();
   }
 };
+
+/// Strip leading zero from affine.linearize_index.
+///
+/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
+/// to `affine.linearize_index [...a] by (...b)` in all cases.
+struct DropLinearizeLeadingZero final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    Value leadingIdx = op.getMultiIndex().front();
+    if (!matchPattern(leadingIdx, m_Zero()))
+      return failure();
+
+    if (op.getMultiIndex().size() == 1) {
+      rewriter.replaceOp(op, leadingIdx);
+      return success();
+    }
+
+    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, op.getMultiIndex().drop_front(),
+        ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
-               DropLinearizeOneBasisElement>(context);
+               DropLinearizeOneBasisElement, DropLinearizeLeadingZero>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..8ed6ab5c965ca0 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1469,6 +1469,45 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @delinearize_fold_constant
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C1]], %[[C1]], %[[C2]]
+func.func @delinearize_fold_constant() -> (index, index, index) {
+  %c22 = arith.constant 22 : index
+  %0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_fold_negative_constant
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant() -> (index, index, index) {
+  %c_22 = arith.constant -22 : index
+  %0:3 = affine.delinearize_index %c_22 into (2, 3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2
+func.func @delinearize_dont_fold_constant_dynamic_basis(%arg0: index) -> (index, index, index) {
+  %c22 = arith.constant 22 : index
+  %0:3 = affine.delinearize_index %c22 into (2, %arg0, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
 func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
     (index, index, index, index, index, index) {
   %c1 = arith.constant 1 : index
@@ -1535,6 +1574,33 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
 
 // -----
 
+// CHECK-LABEL: @linearize_fold_constants
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants() -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (2, 3, 5) : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_dont_fold_dynamic_basis
+// CHECK: %[[RET:.+]] = affine.linearize_index
+// CHECK: return %[[RET]]
+func.func @linearize_dont_fold_dynamic_basis(%arg0: index) -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (2, %arg0, 5) : index
+  return %ret : index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_unit_basis_disjoint
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
 // CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
@@ -1577,3 +1643,17 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
   %ret = affine.linearize_index [%arg0] by (%arg1) : index
   return %ret : index
 }
+
+// -----
+
+// CHECK-LABEL: func @affine_leading_zero(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+//       CHECK:     return %[[RET]]
+func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%c0, %arg0, %arg1] by (2, 3, 5) : index
+  return %ret : index
+}
+

@llvmbot
Copy link
Member

llvmbot commented Nov 11, 2024

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully constant basis and constant argument(s).

This commit also adds a canonicalization pattern to linearize_index that causes it to drop leading-zero inputs.


Full diff: https://github.com/llvm/llvm-project/pull/115766.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+2)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+68-1)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+80)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..753b8951fb084b 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1110,6 +1110,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   }];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -1179,6 +1180,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
   }];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..37316632a6a06f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4556,6 +4556,26 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
+LogicalResult
+AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
+                               SmallVectorImpl<OpFoldResult> &result) {
+  if (adaptor.getLinearIndex() == nullptr)
+    return failure();
+
+  if (!adaptor.getDynamicBasis().empty())
+    return failure();
+
+  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
+  Type attrType = getLinearIndex().getType();
+  for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
+    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
+    highPart = llvm::divideFloorSigned(highPart, modulus);
+  }
+  result.push_back(IntegerAttr::get(attrType, highPart));
+  std::reverse(result.begin(), result.end());
+  return success();
+}
+
 namespace {
 
 // Drops delinearization indices that correspond to unit-extent basis
@@ -4683,6 +4703,26 @@ LogicalResult AffineLinearizeIndexOp::verify() {
   return success();
 }
 
+OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  if (llvm::any_of(adaptor.getMultiIndex(),
+                   [](Attribute a) { return a == nullptr; }))
+    return nullptr;
+
+  if (!adaptor.getDynamicBasis().empty())
+    return nullptr;
+
+  int64_t result = 0;
+  int64_t stride = 1;
+  for (auto [indexAttr, length] :
+       llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
+                       llvm::reverse(getStaticBasis()))) {
+    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
+    stride = stride * length;
+  }
+
+  return IntegerAttr::get(getResult().getType(), result);
+}
+
 namespace {
 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4751,12 +4791,39 @@ struct DropLinearizeOneBasisElement final
     return success();
   }
 };
+
+/// Strip leading zero from affine.linearize_index.
+///
+/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
+/// to `affine.linearize_index [...a] by (...b)` in all cases.
+struct DropLinearizeLeadingZero final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    Value leadingIdx = op.getMultiIndex().front();
+    if (!matchPattern(leadingIdx, m_Zero()))
+      return failure();
+
+    if (op.getMultiIndex().size() == 1) {
+      rewriter.replaceOp(op, leadingIdx);
+      return success();
+    }
+
+    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, op.getMultiIndex().drop_front(),
+        ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
-               DropLinearizeOneBasisElement>(context);
+               DropLinearizeOneBasisElement, DropLinearizeLeadingZero>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..8ed6ab5c965ca0 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1469,6 +1469,45 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @delinearize_fold_constant
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C1]], %[[C1]], %[[C2]]
+func.func @delinearize_fold_constant() -> (index, index, index) {
+  %c22 = arith.constant 22 : index
+  %0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_fold_negative_constant
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant() -> (index, index, index) {
+  %c_22 = arith.constant -22 : index
+  %0:3 = affine.delinearize_index %c_22 into (2, 3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2
+func.func @delinearize_dont_fold_constant_dynamic_basis(%arg0: index) -> (index, index, index) {
+  %c22 = arith.constant 22 : index
+  %0:3 = affine.delinearize_index %c22 into (2, %arg0, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
 func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
     (index, index, index, index, index, index) {
   %c1 = arith.constant 1 : index
@@ -1535,6 +1574,33 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
 
 // -----
 
+// CHECK-LABEL: @linearize_fold_constants
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants() -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (2, 3, 5) : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_dont_fold_dynamic_basis
+// CHECK: %[[RET:.+]] = affine.linearize_index
+// CHECK: return %[[RET]]
+func.func @linearize_dont_fold_dynamic_basis(%arg0: index) -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (2, %arg0, 5) : index
+  return %ret : index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_unit_basis_disjoint
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
 // CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
@@ -1577,3 +1643,17 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
   %ret = affine.linearize_index [%arg0] by (%arg1) : index
   return %ret : index
 }
+
+// -----
+
+// CHECK-LABEL: func @affine_leading_zero(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+//       CHECK:     return %[[RET]]
+func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%c0, %arg0, %arg1] by (2, 3, 5) : index
+  return %ret : index
+}
+

This commit adds implementations of fold() for delinearize_index and
linearize_index to constant-fold them away when they have a fully
constant basis and constant argument(s).

This commit also adds a canonicalization pattern to linearize_index
that causes it to drop leading-zero inputs.
@krzysz00 krzysz00 force-pushed the affine-idx-ops-folders branch from e3448a1 to 1a88bf1 Compare November 12, 2024 21:59

LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
PatternRewriter &rewriter) const override {
Value leadingIdx = op.getMultiIndex().front();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC you should be able to drop any zero-value basis (and the corresponding index) right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is

%0 = affine.linearize_index [%c0, %x, %y] by (A, B, C)

going to

%0 = affine.linearize_index [%x, %y] by (B, C)

We can't do

%0 = affine.linearize_index [%x, %c0, %y] by (A, B, C)

to

%0 = affine.linearize_index [%x, %y] by (A, C)

since that has an %x * C term and not a %x * BC term.

Whether or not canonicalizing that to

%0 = affine.linearize_index [%x, %y] (A, BC)

is profitable is something I'm not sure about yet.

(We definitely can't drop trailing zeroes, since they'd be used to express %y * C-type operations)

@krzysz00 krzysz00 merged commit cb9481d into llvm:main Nov 13, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants