[mlir] Replace dynamic sizes in insert_slice of tensor.cast canonicalization #91352
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir

Author: None (Max191)

Changes: In some cases this pattern may ignore static information because of dynamic values among the insert_slice sizes operands, e.g.:
Can be rewritten into:
This PR updates the matching in the pattern to allow rewrites like this.

Full diff: https://github.com/llvm/llvm-project/pull/91352.diff

4 Files Affected:
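As a concrete illustration (reconstructed from the test case updated by this PR, not from the elided snippets in the description), the pattern can now fold the cast even though the insert_slice sizes operands are dynamic:

```mlir
// Before: the unit dim of %arg0 is hidden behind the cast, and the
// corresponding insert_slice size %arg4 is a dynamic value.
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7]
    : tensor<?x?xf32> into tensor<?x?xf32>

// After: the cast folds away and the static size 1 is recovered.
%1 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3] [1, %arg5] [%arg6, %arg7]
    : tensor<1x?xf32> into tensor<?x?xf32>
```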
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 2361cf1371237b..5579b138668d2b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -360,9 +360,15 @@ class VectorType::Builder {
/// which dimensions must be kept when e.g. compute MemRef strides under
/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
/// obtained by dropping only `1` entries in `originalShape`.
+/// If `matchDynamic` is true, then dynamic dims in `originalShape` and
+/// `reducedShape` will be considered matching with non-dynamic dims, unless
+/// the non-dynamic dim is from `originalShape` and equal to 1. For example,
+/// in ([1, 3, ?], [?, 5]), the mask would be {1, 0, 0}, since 3 and 5 will
+/// match with the corresponding dynamic dims.
std::optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
- ArrayRef<int64_t> reducedShape);
+ ArrayRef<int64_t> reducedShape,
+ bool matchDynamic = false);
/// Enum that captures information related to verifier error conditions on
/// slice insert/extract type of ops.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5f..d560c11464f1c1 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2711,15 +2711,38 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
if (!srcType || !dstType)
return failure();
+
+ // The tensor.cast source could have additional static information not seen
+ // in the insert slice op static sizes, so we ignore dynamic dims when
+ // computing the rank reduction mask.
+ SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
+ auto rankReductionMask = computeRankReductionMask(
+ staticSizes, srcType.getShape(), /*matchDynamic=*/true);
+ if (!rankReductionMask.has_value())
+ return failure();
+ // Replace dimensions in the insert slice op with corresponding static dims
+ // from the cast source type. If the insert slice sizes have static dims
+ // that are not static in the tensor.cast source (i.e., when the cast op
+ // casts a dynamic dim to static), the dim should not be replaced, and the
+ // pattern will fail later in `verifyInsertSliceOp`.
+ SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+ int64_t rankReducedIdx = 0;
+ for (auto [idx, size] : enumerate(staticSizes)) {
+ if (!rankReductionMask.value().contains(idx) &&
+ !srcType.isDynamicDim(rankReducedIdx)) {
+ mixedSizes[idx] = getAsIndexOpFoldResult(
+ rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
+ size = srcType.getDimSize(rankReducedIdx++);
+ }
+ }
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
- insertSliceOp.getStaticSizes(),
- insertSliceOp.getStaticStrides()) !=
+ staticSizes, insertSliceOp.getStaticStrides()) !=
SliceVerificationResult::Success)
return failure();
Operation *replacement = rewriter.create<InsertOpTy>(
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
- insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+ mixedSizes, insertSliceOp.getMixedStrides());
// In the parallel case there is no result and so nothing to cast.
bool isParallelInsert =
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a2738946de410e..179797cb943a1a 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -408,24 +408,24 @@ unsigned BaseMemRefType::getMemorySpaceAsInt() const {
// MemRefType
//===----------------------------------------------------------------------===//
-/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
-/// `originalShape` with some `1` entries erased, return the set of indices
-/// that specifies which of the entries of `originalShape` are dropped to obtain
-/// `reducedShape`. The returned mask can be applied as a projection to
-/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
-/// which dimensions must be kept when e.g. compute MemRef strides under
-/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
-/// obtained by dropping only `1` entries in `originalShape`.
std::optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
- ArrayRef<int64_t> reducedShape) {
+ ArrayRef<int64_t> reducedShape,
+ bool matchDynamic) {
size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
llvm::SmallDenseSet<unsigned> unusedDims;
unsigned reducedIdx = 0;
for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
// Greedily insert `originalIdx` if match.
- if (reducedIdx < reducedRank &&
- originalShape[originalIdx] == reducedShape[reducedIdx]) {
+ int64_t origSize = originalShape[originalIdx];
+ // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
+ if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
+ (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
+ ShapedType::isDynamic(origSize))) {
+ reducedIdx++;
+ continue;
+ }
+ if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
reducedIdx++;
continue;
}
@@ -433,7 +433,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
unusedDims.insert(originalIdx);
// If no match on `originalIdx`, the `originalShape` at this dimension
// must be 1, otherwise we bail.
- if (originalShape[originalIdx] != 1)
+ if (origSize != 1)
return std::nullopt;
}
// The whole reducedShape must be scanned, otherwise we bail.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6177fe3c752c93..53c8a65d39e633 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1890,14 +1890,13 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
// -----
-// There was an issue in cast + insert_slice folding generating invalid ir.
-// https://github.com/llvm/llvm-project/issues/53099
// CHECK-LABEL: func @insert_slice_cast
func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
- // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
- // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
- // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
+ // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
+ // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: return %[[RES]] : tensor<?x?xf32>
return %1 : tensor<?x?xf32>
thanks for the improvement!
// The tensor.cast source could have additional static information not seen
// in the insert slice op static sizes, so we ignore dynamic dims when
// computing the rank reduction mask.
SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
The values won't be modified, how about using ArrayRef<int64_t> here?
They actually do get modified later on L2735.
I see, somehow I missed it when I was searching the use of the variable.
@@ -1890,21 +1918,6 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {

// -----

// There was an issue in cast + insert_slice folding generating invalid ir.
// https://github.com/llvm/llvm-project/issues/53099
Hmm, this test was guarding against an issue that was previously fixed. Why is it removed?
It is moved to l.758. I think the revision adds support for the case; it generates valid IR now.
ah yes, thank you!