
[mlir] Replace dynamic sizes in insert_slice of tensor.cast canonicalization #91352


Merged: 3 commits, May 8, 2024
mlir/include/mlir/IR/BuiltinTypes.h (7 additions & 1 deletion)
@@ -360,9 +360,15 @@ class VectorType::Builder {
/// which dimensions must be kept when e.g. compute MemRef strides under
/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
/// obtained by dropping only `1` entries in `originalShape`.
/// If `matchDynamic` is true, then dynamic dims in `originalShape` and
/// `reducedShape` will be considered matching with non-dynamic dims, unless
/// the non-dynamic dim is from `originalShape` and equal to 1. For example,
/// in ([1, 3, ?], [?, 5]), the mask would be {1, 0, 0}, since 3 and 5 will
/// match with the corresponding dynamic dims.
std::optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> reducedShape);
ArrayRef<int64_t> reducedShape,
bool matchDynamic = false);

/// Enum that captures information related to verifier error conditions on
/// slice insert/extract type of ops.
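For reference, a minimal standalone sketch (an assumed test snippet, not part of this PR) of calling the new overload on the shapes from the doc comment above: with `matchDynamic` the static 3 matches the dynamic dim and the dynamic dim matches 5, so only the leading unit dimension is reported as dropped, while the default behavior still returns `std::nullopt` for the same shapes.

```cpp
// Assumed standalone snippet exercising mlir::computeRankReductionMask;
// the main()/assert harness is illustrative only and not part of the PR.
#include "mlir/IR/BuiltinTypes.h"
#include <cassert>

using namespace mlir;

int main() {
  // Doc-comment example: originalShape = [1, 3, ?], reducedShape = [?, 5].
  int64_t original[] = {1, 3, ShapedType::kDynamic};
  int64_t reduced[] = {ShapedType::kDynamic, 5};

  // With matchDynamic=true, 3 matches '?' and '?' matches 5, so only the
  // leading unit dimension (index 0) ends up in the dropped-dims set.
  auto mask =
      computeRankReductionMask(original, reduced, /*matchDynamic=*/true);
  assert(mask && mask->size() == 1 && mask->contains(0));

  // With the default matchDynamic=false, 3 does not match '?', so no valid
  // rank-reduction mask exists and std::nullopt is returned.
  assert(!computeRankReductionMask(original, reduced).has_value());
  return 0;
}
```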
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (26 additions & 3 deletions)
@@ -2711,15 +2711,38 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
if (!srcType || !dstType)
return failure();

// The tensor.cast source could have additional static information not seen
// in the insert slice op static sizes, so we ignore dynamic dims when
// computing the rank reduction mask.
SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
Review thread on this line:

Contributor: The values won't be modified, how about using ArrayRef<int64_t> here?

Contributor (author): They actually do get modified later on L2735.

Contributor: I see, somehow I missed it when I was searching the use of the variable.

auto rankReductionMask = computeRankReductionMask(
staticSizes, srcType.getShape(), /*matchDynamic=*/true);
if (!rankReductionMask.has_value())
return failure();
// Replace dimensions in the insert slice op with corresponding static dims
// from the cast source type. If the insert slice sizes have static dims
// that are not static in the tensor.cast source (i.e., when the cast op
// casts a dynamic dim to static), the dim should not be replaced, and the
// pattern will fail later in `verifyInsertSliceOp`.
SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
int64_t rankReducedIdx = 0;
for (auto [idx, size] : enumerate(staticSizes)) {
if (!rankReductionMask.value().contains(idx) &&
!srcType.isDynamicDim(rankReducedIdx)) {
mixedSizes[idx] = getAsIndexOpFoldResult(
rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
size = srcType.getDimSize(rankReducedIdx++);
}
}
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
insertSliceOp.getStaticSizes(),
insertSliceOp.getStaticStrides()) !=
staticSizes, insertSliceOp.getStaticStrides()) !=
SliceVerificationResult::Success)
return failure();

Operation *replacement = rewriter.create<InsertOpTy>(
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
mixedSizes, insertSliceOp.getMixedStrides());

// In the parallel case there is no result and so nothing to cast.
bool isParallelInsert =
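To make the bookkeeping in the size-replacement loop above easier to follow, here is a standalone sketch that mirrors it on plain `int64_t` shapes instead of the op's mixed `OpFoldResult` sizes. `replaceSizesFromCastSource` is a hypothetical helper written purely for illustration and is not code from this PR.

```cpp
// Hypothetical helper mirroring the pattern's loop: sizes that survive rank
// reduction and whose cast-source extent is static are overwritten with that
// static extent; everything else is left untouched.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static SmallVector<int64_t>
replaceSizesFromCastSource(ArrayRef<int64_t> insertSliceSizes,
                           ArrayRef<int64_t> castSrcShape) {
  SmallVector<int64_t> sizes(insertSliceSizes.begin(), insertSliceSizes.end());
  auto mask =
      computeRankReductionMask(sizes, castSrcShape, /*matchDynamic=*/true);
  if (!mask)
    return sizes; // The real pattern returns failure() at this point.
  int64_t rankReducedIdx = 0;
  for (auto [idx, size] : llvm::enumerate(sizes)) {
    if (!mask->contains(idx) &&
        !ShapedType::isDynamic(castSrcShape[rankReducedIdx]))
      size = castSrcShape[rankReducedIdx++];
  }
  return sizes;
}

// E.g. insertSliceSizes = {?, ?} (the [%arg4, %arg5] operands) with
// castSrcShape = {1, ?} from a tensor<1x?xf32> cast source yields {1, ?}:
// only the known-static extent is folded into the insert_slice sizes,
// matching the @insert_slice_cast test below.
```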
mlir/lib/IR/BuiltinTypes.cpp (12 additions & 12 deletions)
@@ -408,32 +408,32 @@ unsigned BaseMemRefType::getMemorySpaceAsInt() const {
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. compute MemRef strides under
/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
/// obtained by dropping only `1` entries in `originalShape`.
std::optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> reducedShape) {
ArrayRef<int64_t> reducedShape,
bool matchDynamic) {
size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
llvm::SmallDenseSet<unsigned> unusedDims;
unsigned reducedIdx = 0;
for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
// Greedily insert `originalIdx` if match.
if (reducedIdx < reducedRank &&
originalShape[originalIdx] == reducedShape[reducedIdx]) {
int64_t origSize = originalShape[originalIdx];
// if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
(ShapedType::isDynamic(reducedShape[reducedIdx]) ||
ShapedType::isDynamic(origSize))) {
reducedIdx++;
continue;
}
if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
reducedIdx++;
continue;
}

unusedDims.insert(originalIdx);
// If no match on `originalIdx`, the `originalShape` at this dimension
// must be 1, otherwise we bail.
if (originalShape[originalIdx] != 1)
if (origSize != 1)
return std::nullopt;
}
// The whole reducedShape must be scanned, otherwise we bail.
mlir/test/Dialect/Tensor/canonicalize.mlir (28 additions & 15 deletions)
@@ -755,6 +755,34 @@ func.func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {

// -----

// CHECK-LABEL: func @insert_slice_cast
func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
// CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
// CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: return %[[RES]] : tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @insert_slice_cast_no_fold
func.func @insert_slice_cast_no_fold(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x5xf32>
// CHECK: %[[CAST:.*]] = tensor.cast
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
// CHECK-SAME: [{{.*}}, {{.*}}] [{{.*}}, 5] [{{.*}}, {{.*}}]
// CHECK-SAME: : tensor<?x5xf32> into tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, 5] [%arg6, %arg7] : tensor<?x5xf32> into tensor<?x?xf32>
// CHECK: return %[[RES]] : tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
@@ -1890,21 +1918,6 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {

// -----

// There was an issue in cast + insert_slice folding generating invalid ir.
// https://github.com/llvm/llvm-project/issues/53099
Review thread on the removed test:

Contributor: hmm this test was guarding against an issue that was previously fixed. Why is it removed?

Contributor: It is moved to l.758. I think the revision adds the support for the case. It is generating valid IR now.

Contributor: ah yes, thank you!


// CHECK-LABEL: func @insert_slice_cast
func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
// CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: return %[[RES]] : tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @cast_extract_slice
func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
-> tensor<16x512xf32> {