Skip to content

[mlir] Replace dynamic sizes in insert_slice of tensor.cast canonicalization #91352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 8, 2024

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented May 7, 2024

In some cases this pattern may ignore static information due to dynamic operands in the insert_slice sizes operands, e.g.:

%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[...] [%s0, %s1] [...] 
    : tensor<?x?xf32> into tensor<?x?xf32>

Can be rewritten into:

%1 = tensor.insert_slice %arg0 into %arg1[...] [1, %s1] [...] 
    : tensor<1x?xf32> into tensor<?x?xf32>

This PR updates the matching in the pattern to allow rewrites like this.

@llvmbot
Copy link
Member

llvmbot commented May 7, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

In some cases this pattern may ignore static information due to dynamic operands in the insert_slice sizes operands, e.g.:

%0 = tensor.cast %arg0 : tensor&lt;1x?xf32&gt; to tensor&lt;?x?xf32&gt;
%1 = tensor.insert_slice %0 into %arg1[...] [%s0, %s1] [...] 
    : tensor&lt;?x?xf32&gt; into tensor&lt;?x?xf32&gt;

Can be rewritten into:

%1 = tensor.insert_slice %arg0 into %arg1[...] [1, %s1] [...] 
    : tensor&lt;1x?xf32&gt; into tensor&lt;?x?xf32&gt;

This PR updates the matching in the pattern to allow rewrites like this.


Full diff: https://github.com/llvm/llvm-project/pull/91352.diff

4 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+7-1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+26-3)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+12-12)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+4-5)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 2361cf1371237b..5579b138668d2b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -360,9 +360,15 @@ class VectorType::Builder {
 /// which dimensions must be kept when e.g. compute MemRef strides under
 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
 /// obtained by dropping only `1` entries in `originalShape`.
+/// If `matchDynamic` is true, then dynamic dims in `originalShape` and
+/// `reducedShape` will be considered matching with non-dynamic dims, unless
+/// the non-dynamic dim is from `originalShape` and equal to 1. For example,
+/// in ([1, 3, ?], [?, 5]), the mask would be {1, 0, 0}, since 3 and 5 will
+/// match with the corresponding dynamic dims.
 std::optional<llvm::SmallDenseSet<unsigned>>
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
-                         ArrayRef<int64_t> reducedShape);
+                         ArrayRef<int64_t> reducedShape,
+                         bool matchDynamic = false);
 
 /// Enum that captures information related to verifier error conditions on
 /// slice insert/extract type of ops.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5f..d560c11464f1c1 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2711,15 +2711,38 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
     auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
     if (!srcType || !dstType)
       return failure();
+
+    // The tensor.cast source could have additional static information not seen
+    // in the insert slice op static sizes, so we ignore dynamic dims when
+    // computing the rank reduction mask.
+    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
+    auto rankReductionMask = computeRankReductionMask(
+        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
+    if (!rankReductionMask.has_value())
+      return failure();
+    // Replace dimensions in the insert slice op with corresponding static dims
+    // from the cast source type. If the insert slice sizes have static dims
+    // that are not static in the tensor.cast source (i.e., when the cast op
+    // casts a dynamic dim to static), the dim should not be replaced, and the
+    // pattern will fail later in `verifyInsertSliceOp`.
+    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+    int64_t rankReducedIdx = 0;
+    for (auto [idx, size] : enumerate(staticSizes)) {
+      if (!rankReductionMask.value().contains(idx) &&
+          !srcType.isDynamicDim(rankReducedIdx)) {
+        mixedSizes[idx] = getAsIndexOpFoldResult(
+            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
+        size = srcType.getDimSize(rankReducedIdx++);
+      }
+    }
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
-                            insertSliceOp.getStaticSizes(),
-                            insertSliceOp.getStaticStrides()) !=
+                            staticSizes, insertSliceOp.getStaticStrides()) !=
         SliceVerificationResult::Success)
       return failure();
 
     Operation *replacement = rewriter.create<InsertOpTy>(
         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
-        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+        mixedSizes, insertSliceOp.getMixedStrides());
 
     // In the parallel case there is no result and so nothing to cast.
     bool isParallelInsert =
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a2738946de410e..179797cb943a1a 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -408,24 +408,24 @@ unsigned BaseMemRefType::getMemorySpaceAsInt() const {
 // MemRefType
 //===----------------------------------------------------------------------===//
 
-/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
-/// `originalShape` with some `1` entries erased, return the set of indices
-/// that specifies which of the entries of `originalShape` are dropped to obtain
-/// `reducedShape`. The returned mask can be applied as a projection to
-/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
-/// which dimensions must be kept when e.g. compute MemRef strides under
-/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
-/// obtained by dropping only `1` entries in `originalShape`.
 std::optional<llvm::SmallDenseSet<unsigned>>
 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
-                               ArrayRef<int64_t> reducedShape) {
+                               ArrayRef<int64_t> reducedShape,
+                               bool matchDynamic) {
   size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
   llvm::SmallDenseSet<unsigned> unusedDims;
   unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
     // Greedily insert `originalIdx` if match.
-    if (reducedIdx < reducedRank &&
-        originalShape[originalIdx] == reducedShape[reducedIdx]) {
+    int64_t origSize = originalShape[originalIdx];
+    // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
+    if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
+        (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
+         ShapedType::isDynamic(origSize))) {
+      reducedIdx++;
+      continue;
+    }
+    if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
       reducedIdx++;
       continue;
     }
@@ -433,7 +433,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
     unusedDims.insert(originalIdx);
     // If no match on `originalIdx`, the `originalShape` at this dimension
     // must be 1, otherwise we bail.
-    if (originalShape[originalIdx] != 1)
+    if (origSize != 1)
       return std::nullopt;
   }
   // The whole reducedShape must be scanned, otherwise we bail.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6177fe3c752c93..53c8a65d39e633 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1890,14 +1890,13 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
 
 // -----
 
-// There was an issue in cast + insert_slice folding generating invalid ir.
-// https://github.com/llvm/llvm-project/issues/53099
 // CHECK-LABEL: func @insert_slice_cast
 func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
-  // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
   %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
-  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
-  // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
+  // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
+  // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
   %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
   // CHECK: return %[[RES]] : tensor<?x?xf32>
   return %1 : tensor<?x?xf32>

@Max191 Max191 requested a review from Hardcode84 May 7, 2024 16:01
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the improvement!

// The tensor.cast source could have additional static information not seen
// in the insert slice op static sizes, so we ignore dynamic dims when
// computing the rank reduction mask.
SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The values won't be modified, how about using ArrayRef<int64_t> here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They actually do get modified later on L2735.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, somehow I missed it when I was searching the use of the variable.

@Max191 Max191 merged commit 7e35a9a into llvm:main May 8, 2024
@@ -1890,21 +1918,6 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {

// -----

// There was an issue in cast + insert_slice folding generating invalid ir.
// https://github.com/llvm/llvm-project/issues/53099
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm this test was guarding against an issue that was previously fixed.
Why is it removed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is moved to l.758. I think the revision adds the support for the case. It is generating valid IR now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes, thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:tensor mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants