
Commit 7e35a9a

[mlir] Replace dynamic sizes in insert_slice of tensor.cast canonicalization (llvm#91352)
In some cases this pattern may lose static information because the sizes operands of the insert_slice are dynamic. For example:

```mlir
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[...] [%s0, %s1] [...] : tensor<?x?xf32> into tensor<?x?xf32>
```

can be rewritten into:

```mlir
%1 = tensor.insert_slice %arg0 into %arg1[...] [1, %s1] [...] : tensor<1x?xf32> into tensor<?x?xf32>
```

This PR updates the matching in the pattern to allow rewrites like this.
Parent: 2f956a3

4 files changed: +73 −31 lines changed


mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 7 additions & 1 deletion
```diff
@@ -360,9 +360,15 @@ class VectorType::Builder {
 /// which dimensions must be kept when e.g. compute MemRef strides under
 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
 /// obtained by dropping only `1` entries in `originalShape`.
+/// If `matchDynamic` is true, then dynamic dims in `originalShape` and
+/// `reducedShape` will be considered matching with non-dynamic dims, unless
+/// the non-dynamic dim is from `originalShape` and equal to 1. For example,
+/// in ([1, 3, ?], [?, 5]), the mask would be {1, 0, 0}, since 3 and 5 will
+/// match with the corresponding dynamic dims.
 std::optional<llvm::SmallDenseSet<unsigned>>
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
-                         ArrayRef<int64_t> reducedShape);
+                         ArrayRef<int64_t> reducedShape,
+                         bool matchDynamic = false);
 
 /// Enum that captures information related to verifier error conditions on
 /// slice insert/extract type of ops.
```
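To make the new `matchDynamic` semantics concrete, here is a minimal sketch (not part of the patch; the wrapper function name is illustrative) that runs the doc comment's ([1, 3, ?], [?, 5]) example through `computeRankReductionMask` with the flag off and on:

```cpp
#include <cassert>

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"

// Doc-comment example: originalShape = [1, 3, ?], reducedShape = [?, 5].
void rankReductionMaskExample() {
  llvm::SmallVector<int64_t> original = {1, 3, mlir::ShapedType::kDynamic};
  llvm::SmallVector<int64_t> reduced = {mlir::ShapedType::kDynamic, 5};

  // Without matchDynamic, 3 cannot pair with `?`; since 3 is also not a
  // droppable unit dim, the function bails out with std::nullopt.
  assert(!mlir::computeRankReductionMask(original, reduced).has_value());

  // With matchDynamic, 3 pairs with `?` and `?` pairs with 5, so only the
  // leading unit dim is dropped. The returned set of dropped indices is {0},
  // which the doc comment writes as the bit-mask {1, 0, 0}.
  auto mask =
      mlir::computeRankReductionMask(original, reduced, /*matchDynamic=*/true);
  assert(mask.has_value() && mask->size() == 1 && mask->contains(0));
}
```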

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 26 additions & 3 deletions
```diff
@@ -2713,15 +2713,38 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
     auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
     if (!srcType || !dstType)
       return failure();
+
+    // The tensor.cast source could have additional static information not seen
+    // in the insert slice op static sizes, so we ignore dynamic dims when
+    // computing the rank reduction mask.
+    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
+    auto rankReductionMask = computeRankReductionMask(
+        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
+    if (!rankReductionMask.has_value())
+      return failure();
+    // Replace dimensions in the insert slice op with corresponding static dims
+    // from the cast source type. If the insert slice sizes have static dims
+    // that are not static in the tensor.cast source (i.e., when the cast op
+    // casts a dynamic dim to static), the dim should not be replaced, and the
+    // pattern will fail later in `verifyInsertSliceOp`.
+    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+    int64_t rankReducedIdx = 0;
+    for (auto [idx, size] : enumerate(staticSizes)) {
+      if (!rankReductionMask.value().contains(idx) &&
+          !srcType.isDynamicDim(rankReducedIdx)) {
+        mixedSizes[idx] = getAsIndexOpFoldResult(
+            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
+        size = srcType.getDimSize(rankReducedIdx++);
+      }
+    }
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
-                            insertSliceOp.getStaticSizes(),
-                            insertSliceOp.getStaticStrides()) !=
+                            staticSizes, insertSliceOp.getStaticStrides()) !=
         SliceVerificationResult::Success)
       return failure();
 
     Operation *replacement = rewriter.create<InsertOpTy>(
         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
-        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+        mixedSizes, insertSliceOp.getMixedStrides());
 
     // In the parallel case there is no result and so nothing to cast.
     bool isParallelInsert =
```
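The size-replacement loop is the core of the change. Below is a minimal standalone sketch of its behavior, traced on the commit-message example; `propagateStaticSizes` and `kDynamic` are illustrative stand-ins, not MLIR API:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <set>
#include <vector>

// Stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Mirror of the loop above: for each slice dimension that has a counterpart
// in the cast-source shape (i.e. is not in the dropped-dims mask) and whose
// source extent is static, overwrite the slice size with that static extent.
// As in the patch, the source cursor only advances when a replacement happens.
std::vector<int64_t>
propagateStaticSizes(std::vector<int64_t> sliceSizes,
                     const std::vector<int64_t> &srcShape,
                     const std::set<unsigned> &droppedDims) {
  unsigned rankReducedIdx = 0;
  for (unsigned idx = 0; idx < sliceSizes.size(); ++idx) {
    if (!droppedDims.count(idx) && srcShape[rankReducedIdx] != kDynamic)
      sliceSizes[idx] = srcShape[rankReducedIdx++];
  }
  return sliceSizes;
}

int main() {
  // Commit-message example: sizes [%s0, %s1] are both dynamic, the cast
  // source is tensor<1x?xf32>, and the rank-reduction mask is empty.
  std::vector<int64_t> sizes =
      propagateStaticSizes({kDynamic, kDynamic}, {1, kDynamic}, {});
  // Dim 0 becomes the static 1; dim 1 stays dynamic (%s1).
  assert(sizes[0] == 1 && sizes[1] == kDynamic);
  return 0;
}
```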

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 12 additions & 12 deletions
```diff
@@ -408,32 +408,32 @@ unsigned BaseMemRefType::getMemorySpaceAsInt() const {
 // MemRefType
 //===----------------------------------------------------------------------===//
 
-/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
-/// `originalShape` with some `1` entries erased, return the set of indices
-/// that specifies which of the entries of `originalShape` are dropped to obtain
-/// `reducedShape`. The returned mask can be applied as a projection to
-/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
-/// which dimensions must be kept when e.g. compute MemRef strides under
-/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
-/// obtained by dropping only `1` entries in `originalShape`.
 std::optional<llvm::SmallDenseSet<unsigned>>
 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
-                               ArrayRef<int64_t> reducedShape) {
+                               ArrayRef<int64_t> reducedShape,
+                               bool matchDynamic) {
   size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
   llvm::SmallDenseSet<unsigned> unusedDims;
   unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
     // Greedily insert `originalIdx` if match.
-    if (reducedIdx < reducedRank &&
-        originalShape[originalIdx] == reducedShape[reducedIdx]) {
+    int64_t origSize = originalShape[originalIdx];
+    // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
+    if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
+        (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
+         ShapedType::isDynamic(origSize))) {
+      reducedIdx++;
+      continue;
+    }
+    if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
       reducedIdx++;
       continue;
     }
 
     unusedDims.insert(originalIdx);
     // If no match on `originalIdx`, the `originalShape` at this dimension
     // must be 1, otherwise we bail.
-    if (originalShape[originalIdx] != 1)
+    if (origSize != 1)
       return std::nullopt;
   }
   // The whole reducedShape must be scanned, otherwise we bail.
```

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 28 additions & 15 deletions
```diff
@@ -755,6 +755,34 @@ func.func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
 
 // -----
 
+// CHECK-LABEL: func @insert_slice_cast
+func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
+  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
+  // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
+  // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
+  // CHECK: return %[[RES]] : tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_cast_no_fold
+func.func @insert_slice_cast_no_fold(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
+  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x5xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
+  // CHECK-SAME: [{{.*}}, {{.*}}] [{{.*}}, 5] [{{.*}}, {{.*}}]
+  // CHECK-SAME: : tensor<?x5xf32> into tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, 5] [%arg6, %arg7] : tensor<?x5xf32> into tensor<?x?xf32>
+  // CHECK: return %[[RES]] : tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
 // CHECK-SAME:  %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
 // CHECK:       %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
@@ -1890,21 +1918,6 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
 
 // -----
 
-// There was an issue in cast + insert_slice folding generating invalid ir.
-// https://github.com/llvm/llvm-project/issues/53099
-// CHECK-LABEL: func @insert_slice_cast
-func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
-  // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
-  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
-  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
-  // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
-  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
-  // CHECK: return %[[RES]] : tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @cast_extract_slice
 func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
     -> tensor<16x512xf32> {
```
