Skip to content

[mlir][Affine] Split off delinearize parts that depend on last component #117015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 25, 2024

Conversation

krzysz00
Copy link
Contributor

@krzysz00 krzysz00 commented Nov 20, 2024

If we have

%0 = affine.linearize_index disjoint [%a, %b] by (A, B)
%1:3 = affine.delinearize_index %0 into (A, B1, B2)

where B = B1 * B2 (or some mor complex product), we can simplify this to

%0 = affine.linearize_index disjoint [%a] by (A)
%1a:1 = affine.delinearize_index %0 into (A)
%1b:2 = affine.delinearize_index %b into (B1, B2)

This, and more complex cases, prevent us from adding terms together only to divide them away from each other.

If we have
    %0 = affine.linearize_index disjoint [%a, %b] by (A, B)
    %1:3 = affine.delinearize_index %0 into (A, B1, B2)
where B = B1 * B2 (or some mor complex product), we can simplify this
to
    %0 = affine.linearize_index disjoint [%a] by (A)
    %1a:1 = affine.delinearize_index %0 into (A)
    %1b:2 = affine.delinearize_index %b into (B1, B2)

This, and more complex cases, prevent us from adding terms together
only to divide them away from each other.
@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2024

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

If we have
%0 = affine.linearize_index disjoint [%a, %b] by (A, B)
%1:3 = affine.delinearize_index %0 into (A, B1, B2)
where B = B1 * B2 (or some mor complex product), we can simplify this to
%0 = affine.linearize_index disjoint [%a] by (A)
%1a:1 = affine.delinearize_index %0 into (A)
%1b:2 = affine.delinearize_index %b into (B1, B2)

This, and more complex cases, prevent us from adding terms together only to divide them away from each other.


Full diff: https://github.com/llvm/llvm-project/pull/117015.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+85-2)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+66)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4cf07bc167eab9..b13331abc32ada 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4694,13 +4694,96 @@ struct CancelDelinearizeOfLinearizeDisjointExact
     return success();
   }
 };
+
+/// If the input to a delinearization is a disjoint linearization, and the
+/// last k > 1 components of the delinearization basis multiply to the
+/// last component of the linearization basis, break the linearization and
+/// delinearization into two parts, peeling off the last input to linearization.
+///
+/// For example:
+///    %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
+///    %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
+/// becomes
+///    %0 = affine.linearize_index [%z, %y] by (3, 2) : index
+///    %1:2 = affine.delinearize_index %0 by (2, 3) : index
+///    %2:2 = affine.delinearize_index %x by (8, 4) : index
+/// where the original %1:4 is replaced by %1:2 ++ %2:2
+struct SplitDelinearizeSpanningLastLinearizeArg final
+    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp = delinearizeOp.getLinearIndex()
+                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "index doesn't come from linearize");
+
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(linearizeOp,
+                                         "linearize isn't disjoint");
+
+    int64_t target = linearizeOp.getStaticBasis().back();
+    if (ShapedType::isDynamic(target))
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "linearize ends with dynamic basis value");
+
+    int64_t sizeToSplit = 1;
+    size_t elemsToSplit = 0;
+    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
+    for (int64_t basisElem : llvm::reverse(basis)) {
+      if (ShapedType::isDynamic(basisElem))
+        return rewriter.notifyMatchFailure(
+            delinearizeOp, "dynamic basis element while scanning for split");
+      sizeToSplit *= basisElem;
+      elemsToSplit += 1;
+
+      if (sizeToSplit > target)
+        return rewriter.notifyMatchFailure(delinearizeOp,
+                                           "overshot last argument size");
+      if (sizeToSplit == target)
+        break;
+    }
+
+    if (sizeToSplit < target)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "product of known basis elements doesn't exceed last "
+                         "linearize argument");
+
+    if (elemsToSplit < 2)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "don't have a non-trivial basis product");
+
+    Value linearizeWithoutBack =
+        rewriter.create<affine::AffineLinearizeIndexOp>(
+            linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
+            linearizeOp.getDynamicBasis(),
+            linearizeOp.getStaticBasis().drop_back(),
+            linearizeOp.getDisjoint());
+    auto delinearizeWithoutSplitPart =
+        rewriter.create<affine::AffineDelinearizeIndexOp>(
+            delinearizeOp.getLoc(), linearizeWithoutBack,
+            delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
+            delinearizeOp.hasOuterBound());
+    auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
+        basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
+    SmallVector<Value> results = llvm::to_vector(
+        llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
+                            delinearizeBack.getResults()));
+    rewriter.replaceOp(delinearizeOp, results);
+
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns
-      .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
-          context);
+      .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis,
+              SplitDelinearizeSpanningLastLinearizeArg>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index b54a13cffe7771..efeea7eb2af530 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1777,6 +1777,72 @@ func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1:
 
 // -----
 
+// CHECK-LABEL: func @split_delinearize_spanning_final_part
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 4)
+//       CHECK:     %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2)
+//       CHECK:     %[[DELIN2:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+//       CHECK:     return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1
+func.func @split_delinearize_spanning_final_part(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+  %1:4 = affine.delinearize_index %0 into (2, 8, 8)
+      : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @split_delinearize_spanning_final_part_and_cancel
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @split_delinearize_spanning_final_part_and_cancel(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+  %1:4 = affine.delinearize_index %0 into (2, 4, 8, 8)
+      : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The delinearize basis doesn't match the last basis element before
+// overshooting it, don't simplify.
+// CHECK-LABEL: func @dont_split_delinearize_overshooting_target
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 4, 64)
+//       CHECK:     %[[DELIN:.+]]:4 = affine.delinearize_index %[[LIN]] into (2, 16, 8)
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2, %[[DELIN]]#3
+func.func @dont_split_delinearize_overshooting_target(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+  %1:4 = affine.delinearize_index %0 into (2, 16, 8)
+      : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The delinearize basis doesn't fully multiply to the final basis element.
+// CHECK-LABEL: func @dont_split_delinearize_undershooting_target
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 64)
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (4, 8)
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1
+func.func @dont_split_delinearize_undershooting_target(%arg0: index, %arg1: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 64) : index
+  %1:3 = affine.delinearize_index %0 into (4, 8)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_unit_basis_disjoint
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
 // CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index

@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2024

@llvm/pr-subscribers-mlir-affine

Author: Krzysztof Drewniak (krzysz00)

Changes

If we have
%0 = affine.linearize_index disjoint [%a, %b] by (A, B)
%1:3 = affine.delinearize_index %0 into (A, B1, B2)
where B = B1 * B2 (or some mor complex product), we can simplify this to
%0 = affine.linearize_index disjoint [%a] by (A)
%1a:1 = affine.delinearize_index %0 into (A)
%1b:2 = affine.delinearize_index %b into (B1, B2)

This, and more complex cases, prevent us from adding terms together only to divide them away from each other.


Full diff: https://github.com/llvm/llvm-project/pull/117015.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+85-2)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+66)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4cf07bc167eab9..b13331abc32ada 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4694,13 +4694,96 @@ struct CancelDelinearizeOfLinearizeDisjointExact
     return success();
   }
 };
+
+/// If the input to a delinearization is a disjoint linearization, and the
+/// last k > 1 components of the delinearization basis multiply to the
+/// last component of the linearization basis, break the linearization and
+/// delinearization into two parts, peeling off the last input to linearization.
+///
+/// For example:
+///    %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
+///    %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
+/// becomes
+///    %0 = affine.linearize_index [%z, %y] by (3, 2) : index
+///    %1:2 = affine.delinearize_index %0 by (2, 3) : index
+///    %2:2 = affine.delinearize_index %x by (8, 4) : index
+/// where the original %1:4 is replaced by %1:2 ++ %2:2
+struct SplitDelinearizeSpanningLastLinearizeArg final
+    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp = delinearizeOp.getLinearIndex()
+                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "index doesn't come from linearize");
+
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(linearizeOp,
+                                         "linearize isn't disjoint");
+
+    int64_t target = linearizeOp.getStaticBasis().back();
+    if (ShapedType::isDynamic(target))
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "linearize ends with dynamic basis value");
+
+    int64_t sizeToSplit = 1;
+    size_t elemsToSplit = 0;
+    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
+    for (int64_t basisElem : llvm::reverse(basis)) {
+      if (ShapedType::isDynamic(basisElem))
+        return rewriter.notifyMatchFailure(
+            delinearizeOp, "dynamic basis element while scanning for split");
+      sizeToSplit *= basisElem;
+      elemsToSplit += 1;
+
+      if (sizeToSplit > target)
+        return rewriter.notifyMatchFailure(delinearizeOp,
+                                           "overshot last argument size");
+      if (sizeToSplit == target)
+        break;
+    }
+
+    if (sizeToSplit < target)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "product of known basis elements doesn't exceed last "
+                         "linearize argument");
+
+    if (elemsToSplit < 2)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "don't have a non-trivial basis product");
+
+    Value linearizeWithoutBack =
+        rewriter.create<affine::AffineLinearizeIndexOp>(
+            linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
+            linearizeOp.getDynamicBasis(),
+            linearizeOp.getStaticBasis().drop_back(),
+            linearizeOp.getDisjoint());
+    auto delinearizeWithoutSplitPart =
+        rewriter.create<affine::AffineDelinearizeIndexOp>(
+            delinearizeOp.getLoc(), linearizeWithoutBack,
+            delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
+            delinearizeOp.hasOuterBound());
+    auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
+        basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
+    SmallVector<Value> results = llvm::to_vector(
+        llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
+                            delinearizeBack.getResults()));
+    rewriter.replaceOp(delinearizeOp, results);
+
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns
-      .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
-          context);
+      .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis,
+              SplitDelinearizeSpanningLastLinearizeArg>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index b54a13cffe7771..efeea7eb2af530 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1777,6 +1777,72 @@ func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1:
 
 // -----
 
+// CHECK-LABEL: func @split_delinearize_spanning_final_part
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 4)
+//       CHECK:     %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2)
+//       CHECK:     %[[DELIN2:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+//       CHECK:     return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1
+func.func @split_delinearize_spanning_final_part(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+  %1:4 = affine.delinearize_index %0 into (2, 8, 8)
+      : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @split_delinearize_spanning_final_part_and_cancel
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @split_delinearize_spanning_final_part_and_cancel(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+  %1:4 = affine.delinearize_index %0 into (2, 4, 8, 8)
+      : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The delinearize basis doesn't match the last basis element before
+// overshooting it, don't simplify.
+// CHECK-LABEL: func @dont_split_delinearize_overshooting_target
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 4, 64)
+//       CHECK:     %[[DELIN:.+]]:4 = affine.delinearize_index %[[LIN]] into (2, 16, 8)
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2, %[[DELIN]]#3
+func.func @dont_split_delinearize_overshooting_target(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+  %1:4 = affine.delinearize_index %0 into (2, 16, 8)
+      : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The delinearize basis doesn't fully multiply to the final basis element.
+// CHECK-LABEL: func @dont_split_delinearize_undershooting_target
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 64)
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (4, 8)
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1
+func.func @dont_split_delinearize_undershooting_target(%arg0: index, %arg1: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 64) : index
+  %1:3 = affine.delinearize_index %0 into (4, 8)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_unit_basis_disjoint
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
 // CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index

@krzysz00 krzysz00 requested a review from Groverkss November 20, 2024 17:46
Copy link
Contributor

@Abhishek-Varma Abhishek-Varma left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing PR @krzysz00 ! Approving with two requests! Thanks!

Comment on lines 4789 to 4791
if (elemsToSplit < 2)
return rewriter.notifyMatchFailure(
delinearizeOp, "don't have a non-trivial basis product");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, maybe add a lit test to test this path as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's implicit in an existing test that permutes but I'll go add another one if it isn't.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... Ok, yeah, the case where the final components agree exactly gets sent down a different pattern, is what's going on

@krzysz00 krzysz00 merged commit ece4e12 into llvm:main Nov 25, 2024
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants