@@ -429,17 +429,19 @@ struct BubbleUpExpandShapeThroughExtractSlice
  }
};

-/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
-/// `tensor.extract_slice(tensor.collapse_shape)`.
+/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
+/// `tensor.collapse_shape(tensor.extract_slice)`.
///
-/// For this transformation to be possible, the slice must be representable as a
-/// contiguous slice within each reassociation group of the src.
+/// For this transformation to be possible, the extraction that results from
+/// bubbling up must be representable as a single contiguous
+/// `tensor.extract_slice` within each reassociation group of the src.
///
/// In case the size and offset extracted are static then this is possible if
-/// the following conditions are met:
-/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
-/// be the shape of a desired slice. A slice of shape S can be extracted as a
-/// contiguous block of memory if and only if there exists an index k in {0, 1,
+/// the following conditions are met within each reassociation group:
+/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
+/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
+/// shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous span of elements if and only if there exists an index k in {0, 1,
/// ..., n} such that:
///  S_i = 1 for all i < k (that is, all leading dimensions are singleton),
///  1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
@@ -475,6 +477,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
/// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
/// ```
+///
+/// Negative example:
+/// The transformation is not possible because the 15-element slice requested
+/// below from the collapsed reassociation group [2x3x10 -> 60] cannot be
+/// represented as a single slice of the 2x3x10 src of the
+/// `tensor.collapse_shape`:
+/// ```
+/// %collapse = tensor.collapse_shape %src [[0, 1, 2]]
+///     : tensor<2x3x10xf32> into tensor<60xf32>
+/// %extract = tensor.extract_slice %collapse[0][15][1]
+///     : tensor<60xf32> to tensor<15xf32>
+/// ```
+/// If we wanted the collapse to happen after the extraction, a possible
+/// alternative transformation would be to extract multiple slices and concat
+/// them together:
+/// ```
+/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
+///     tensor<2x3x10xf32> to tensor<1x1x10xf32>
+/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
+///     tensor<2x3x10xf32> to tensor<1x1x5xf32>
+/// %concat = tosa.concat %extract_1, %extract_2 {axis = 2 : i32} :
+///     (tensor<1x1x10xf32>, tensor<1x1x5xf32>) -> tensor<1x1x15xf32>
+/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
+///     into tensor<15xf32>
+/// ```
+/// But this is not the intended purpose of the transformation.
struct BubbleUpCollapseShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -552,47 +579,69 @@ struct BubbleUpCollapseShapeThroughExtractSlice
      // Case #2 = size and offset are static.
      // Verify that the slice can be represented as a contiguous slice of the
      // src of the collapse_shape.
-      // Checking this must be done on order of most
-      // internal dimensions first, so traversal is done in reverse order of the
-      // reassociation group.
+      // Checking this is done in order of the most internal dimensions first,
+      // so traversal is done in reverse order of the reassociation group.
+      // If the expected slice shape is [1, 1, ..., 1, S_k, A_(k+1), ..., A_n],
+      // we first find the sizes and offsets for dimensions n...k+1, then for
+      // dimension k, and then for dimensions k-1...0.
      int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
      int64_t collapsedOffsetValue =
          getConstantIntValue(collapsedOffset).value();

      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;

-      for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
-        int64_t expandedShapeSize = srcShape[expandedShapeIdx];
+      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+                                                  reassocIndices.rend());
+      int64_t idx = 0;
+      int64_t reassocGroupSize = reassocIndices.size();
+
+      // First handle the trailing dimensions where the slice size should be
+      // equal to the tensor shape and the offset should be 0 (n...k+1).
+      for (; idx < reassocGroupSize; ++idx) {
+        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];

-        // This is a dimension that slicing will occur on, so need to make sure
-        // that the slice size can be set to the shape size and the offset to 0.
-        if (collapsedSizeValue >= expandedShapeSize &&
-            (collapsedSizeValue % expandedShapeSize != 0 ||
-             collapsedOffsetValue % expandedShapeSize != 0)) {
+        if (collapsedSizeValue < expandedShapeSize)
+          break;
+
+        // We need to make sure that the slice size can be set to the shape
+        // size and the offset to 0.
+        if ((collapsedSizeValue % expandedShapeSize) != 0 ||
+            (collapsedOffsetValue % expandedShapeSize) != 0)
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
                       "of the src of the collapse_shape");
-        }

-        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
+        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+
+        collapsedSizeValue /= expandedShapeSize;
+        collapsedOffsetValue /= expandedShapeSize;
+      }

-        // This is the dimension that slicing will occur along, so need to make
-        // sure that the slice size + offset will not exceed the shape size.
-        if (collapsedSizeValue < expandedShapeSize &&
-            (collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
+      // Now handle the first dimension on which slicing occurs (k).
+      if (idx < reassocGroupSize) {
+        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+        // We need to make sure that the slice size in this dim plus the offset
+        // will not exceed the shape size.
+        if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
                       "slice of the src of the collapse_shape");
-        }

-        groupExpandedSizes.push_back(rewriter.getIndexAttr(
-            std::min(collapsedSizeValue, expandedShapeSize)));
+        groupExpandedSizes.push_back(rewriter.getIndexAttr(collapsedSizeValue));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));

-        // Remove the size and offset of trailing dimensions from the size and
-        // offset of the slice.
-        collapsedSizeValue /= expandedShapeSize;
-        collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
+        collapsedOffsetValue /= expandedShapeSize;
+      }
+
+      // Now handle the leading dimensions where the slice size is equal to 1
+      // (k-1...0).
+      for (idx++; idx < reassocGroupSize; ++idx) {
+        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+        groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
+        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
        collapsedOffsetValue /= expandedShapeSize;
      }

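The static-case contiguity test that the new code implements can also be exercised outside MLIR. Below is a minimal standalone sketch, not taken from the patch: the function name `isContiguousGroupSlice`, the use of `std::vector`, and the `main` driver are illustrative assumptions, and it only answers whether a slice of one reassociation group is contiguous, without computing the per-dimension sizes and offsets the pattern also records.

```
// Standalone sketch of the per-reassociation-group contiguity check used by
// the pattern in the static size/offset case. Plain C++, no MLIR dependency.
#include <cstdint>
#include <iostream>
#include <vector>

// groupShape holds the expanded sizes of one reassociation group, outermost
// dimension first (e.g. {2, 3, 10} for a group that collapses into 60).
// size/offset describe the slice taken from the collapsed dimension.
static bool isContiguousGroupSlice(const std::vector<int64_t> &groupShape,
                                   int64_t size, int64_t offset) {
  int64_t i = static_cast<int64_t>(groupShape.size()) - 1;

  // Trailing dimensions (n...k+1): they must be taken whole, so the remaining
  // size and offset must both be divisible by the dimension size.
  for (; i >= 0 && size >= groupShape[i]; --i) {
    if (size % groupShape[i] != 0 || offset % groupShape[i] != 0)
      return false;
    size /= groupShape[i];
    offset /= groupShape[i];
  }

  // Dimension k: the single dimension that is sliced non-trivially. The patch
  // requires size plus the offset within this dimension to stay below the
  // dimension size.
  if (i >= 0 && size + (offset % groupShape[i]) >= groupShape[i])
    return false;

  // Leading dimensions (k-1...0) only contribute offsets (their slice size is
  // 1), so they cannot cause a failure.
  return true;
}

int main() {
  // The negative example from the doc comment: 15 elements out of a collapsed
  // 2x3x10 group cannot be a single contiguous slice of the source.
  std::cout << isContiguousGroupSlice({2, 3, 10}, /*size=*/15, /*offset=*/0)
            << "\n"; // prints 0
  // A contiguous case: 20 elements starting at 0 correspond to the slice
  // %src[0, 0, 0][1, 2, 10] of the 2x3x10 source.
  std::cout << isContiguousGroupSlice({2, 3, 10}, /*size=*/20, /*offset=*/0)
            << "\n"; // prints 1
}
```

Compiled and run, the sketch prints 0 for the 15-element slice of the 2x3x10 group (the negative example above) and 1 for the 20-element slice, mirroring the cases in which the pattern emits or avoids its match-failure messages.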