Skip to content

Commit 72b0be3

Browse files
committed
Updates for CR
1 parent c0291d0 commit 72b0be3

File tree

2 files changed

+80
-32
lines changed

2 files changed

+80
-32
lines changed

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -429,17 +429,19 @@ struct BubbleUpExpandShapeThroughExtractSlice
429429
}
430430
};
431431

432-
/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
433-
/// `tensor.extract_slice(tensor.collapse_shape)`.
432+
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
433+
/// `tensor.collapse_shape(tensor.extract_slice)`.
434434
///
435-
/// For this transformation to be possible, the slice must be representable as a
436-
/// contiguous slice within each reassociation group of the src.
435+
/// For this transformation to be possible, the slice extracted after bubbling
436+
/// up must be representable as a single contiguous slice, obtained via one
437+
/// tensor.extract_slice, within each reassociation group of the src.
437438
///
438439
/// In case the size and offset extracted are static then this is possible if
439-
/// the following conditions are met:
440-
/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
441-
/// be the shape of a desired slice. A slice of shape S can be extracted as a
442-
/// contiguous block of memory if and only if there exists an index k in {0, 1,
440+
/// the following conditions are met within each reassociation group:
441+
/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
442+
/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
443+
/// shape of a desired slice. A slice of shape S can be extracted as a
444+
/// contiguous span of elements if and only if there exists an index k in {0, 1,
443445
/// ..., n} such that:
444446
/// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
445447
/// 1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
@@ -475,6 +477,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
475477
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
476478
/// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
477479
/// ```
480+
///
481+
/// Negative example:
482+
/// The transformation is not possible because we cannot use a single slice to
483+
/// represent the reassociation group [2x3x10->???]. If we wanted the
484+
/// collapse to be after the extraction, we would need to extract multiple
485+
/// slices and concat them together.
486+
/// ```
487+
/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] :
488+
///     tensor<2x3x10xf32> into tensor<60xf32>
489+
/// %extract = tensor.extract_slice %collapse[0][15][1] :
///     tensor<60xf32> to tensor<15xf32>
490+
/// ```
491+
/// If we wanted the collapse to be after the extraction, a possible
492+
/// alternate transformation could be to extract multiple slices and concat them
493+
/// together:
494+
/// ```
495+
/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
496+
/// tensor<2x3x10xf32> to tensor<1x1x10xf32>
497+
/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
498+
/// tensor<2x3x10xf32> to tensor<1x1x5xf32>
499+
/// %concat = tosa.concat %extract_1, %extract_2 {axis = 2 : i32} :
500+
/// (tensor<1x1x10xf32>, tensor<1x1x5xf32>) -> tensor<1x1x15xf32>
501+
/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
502+
/// into tensor<15xf32>
503+
/// ```
504+
/// But this is not the intended purpose of the transformation.
478505
struct BubbleUpCollapseShapeThroughExtractSlice
479506
: public OpRewritePattern<tensor::ExtractSliceOp> {
480507
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -552,47 +579,69 @@ struct BubbleUpCollapseShapeThroughExtractSlice
552579
// Case #2 = size and offset are static.
553580
// Verify that the slice can be represented as a contiguous slice of the
554581
// src of the collapse_shape.
555-
// Checking this must be done on order of most
556-
// internal dimensions first, so traversal is done in reverse order of the
557-
// reassociation group.
582+
// Checking this is done in order of most internal dimensions first,
583+
// so traversal is done in reverse order of the reassociation group.
584+
// If the expected slice shape is [1, 1, ..., 1, Sk, A(k+1), A(k+2),
585+
// ..., An] then we first find the size and offset for n...k+1 then for k
586+
// and then for k-1...0.
558587
int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
559588
int64_t collapsedOffsetValue =
560589
getConstantIntValue(collapsedOffset).value();
561590

562591
SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
563592

564-
for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
565-
int64_t expandedShapeSize = srcShape[expandedShapeIdx];
593+
ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
594+
reassocIndices.rend());
595+
int64_t idx = 0;
596+
int64_t reassocGroupSize = reassocIndices.size();
597+
598+
// First handle the trailing dimensions where the slice size should be
599+
// equal to the tensor shape and the offset should be 0 (n...k+1).
600+
for (; idx < reassocGroupSize; ++idx) {
601+
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
566602

567-
// This is a dimension that slicing will occur on, so need to make sure
568-
// that the slice size can be set to the shape size and the offset to 0.
569-
if (collapsedSizeValue >= expandedShapeSize &&
570-
(collapsedSizeValue % expandedShapeSize != 0 ||
571-
collapsedOffsetValue % expandedShapeSize != 0)) {
603+
if (collapsedSizeValue < expandedShapeSize)
604+
break;
605+
606+
// We need to make sure that the slice size can be set to the shape size
607+
// and the offset to 0.
608+
if ((collapsedSizeValue % expandedShapeSize) != 0 ||
609+
(collapsedOffsetValue % expandedShapeSize) != 0)
572610
return rewriter.notifyMatchFailure(
573611
sliceOp, "unsupported: cannot be extracted as a contiguous slice "
574612
"of the src of the collapse_shape");
575-
}
576613

577-
int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
614+
groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
615+
groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
616+
617+
collapsedSizeValue /= expandedShapeSize;
618+
collapsedOffsetValue /= expandedShapeSize;
619+
}
578620

579-
// This is the dimension that slicing will occur along, so need to make
580-
// sure that the slice size + offset will not exceed the shape size.
581-
if (collapsedSizeValue < expandedShapeSize &&
582-
(collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
621+
// Now handle the first dim on which slicing occurs (k).
622+
if (idx < reassocGroupSize) {
623+
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
624+
int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
625+
// We need to make sure that the slice size in this dim + offset will
626+
// not exceed the shape size.
627+
if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
583628
return rewriter.notifyMatchFailure(
584629
sliceOp, "unsupported: slice cannot be extracted as a contiguous "
585630
"slice of the src of the collapse_shape");
586-
}
587631

588-
groupExpandedSizes.push_back(rewriter.getIndexAttr(
589-
std::min(collapsedSizeValue, expandedShapeSize)));
632+
groupExpandedSizes.push_back(rewriter.getIndexAttr(collapsedSizeValue));
590633
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
591634

592-
// Remove the size and offset of trailing dimensions from the size and
593-
// offset of the slice.
594-
collapsedSizeValue /= expandedShapeSize;
595-
collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
635+
collapsedOffsetValue /= expandedShapeSize;
636+
}
637+
638+
// Now handle the leading dimensions where the slice size is equal to 1
639+
// (k-1...0).
640+
for (idx++; idx < reassocGroupSize; ++idx) {
641+
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
642+
int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
643+
groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
644+
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
596645
collapsedOffsetValue /= expandedShapeSize;
597646
}
598647

mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,6 @@ module attributes {transform.with_named_sequence} {
462462
}
463463
}
464464

465-
466465
// -----
467466

468467
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(

0 commit comments

Comments
 (0)