@@ -550,7 +550,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice
550
550
enumerate(collapseShapeOp.getReassociationIndices ())) {
551
551
OpFoldResult collapsedSize = collapsedSizes[groupIdx];
552
552
OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
553
- // Case #1 - size and/or offset are dynamic.
553
+ // CASE #1 - size and/or offset are dynamic.
554
554
// In this case, the slice can be represented as a contiguous slice only
555
555
// if there is a single dimension in the reassociation group that has a
556
556
// size not equal to 1.
@@ -576,16 +576,24 @@ struct BubbleUpCollapseShapeThroughExtractSlice
576
576
continue ;
577
577
}
578
578
579
- // Case #2 = size and offset are static.
579
+ // CASE #2 = size and offset are static.
580
580
// Verify that the slice can be represented as a contiguous slice of the
581
581
// src of the collapse_shape.
582
582
// Checking this is done on order of most internal dimensions first,
583
583
// so traversal is done in reverse order of the reassociation group.
584
584
// If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
585
585
// ...,An] then we first find the size and offset for n...k+1 then for k
586
586
// and then for k-1...0.
587
- int64_t collapsedSizeValue = getConstantIntValue (collapsedSize).value ();
588
- int64_t collapsedOffsetValue =
587
+
588
+ // currentCollapsedsize and currentCollapsedOffset are initialized with
589
+ // the original collapsed size and offset and divided by the expanded
590
+ // shape size in each dimension as we go along the reassociation group.
591
+ // In essence we are spreading the original collapsed size and offset over
592
+ // the various expanded slice dimensions.
593
+ // The variables are used both to check the validity of the slice and to
594
+ // compute the expanded sizes and offsets.
595
+ int64_t currentCollapsedsize = getConstantIntValue (collapsedSize).value ();
596
+ int64_t currentCollapsedOffset =
589
597
getConstantIntValue (collapsedOffset).value ();
590
598
591
599
SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
@@ -600,49 +608,55 @@ struct BubbleUpCollapseShapeThroughExtractSlice
600
608
for (; idx < reassocGroupSize; ++idx) {
601
609
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
602
610
603
- if (collapsedSizeValue < expandedShapeSize)
611
+ if (currentCollapsedsize < expandedShapeSize)
604
612
break ;
605
613
606
614
// We need to make sure that the slice size can be set to the shape size
607
615
// and the offset to 0.
608
- if ((collapsedSizeValue % expandedShapeSize) != 0 ||
609
- (collapsedOffsetValue % expandedShapeSize) != 0 )
616
+ if ((currentCollapsedsize % expandedShapeSize) != 0 ||
617
+ (currentCollapsedOffset % expandedShapeSize) != 0 )
610
618
return rewriter.notifyMatchFailure (
611
619
sliceOp, " unsupported: cannot be extracted as a contiguous slice "
612
620
" of the src of the collapse_shape" );
613
621
614
622
groupExpandedSizes.push_back (rewriter.getIndexAttr (expandedShapeSize));
615
623
groupExpandedOffsets.push_back (rewriter.getIndexAttr (0 ));
616
624
617
- collapsedSizeValue /= expandedShapeSize;
618
- collapsedOffsetValue /= expandedShapeSize;
625
+ currentCollapsedsize /= expandedShapeSize;
626
+ currentCollapsedOffset /= expandedShapeSize;
619
627
}
620
628
621
629
// Now handle the first dim where slicing occurs on (k).
622
630
if (idx < reassocGroupSize) {
623
631
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
624
- int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
632
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
625
633
// We need to make sure that the slice size in this dim + offset will
626
634
// not exceed the shape size.
627
- if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
635
+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize)
628
636
return rewriter.notifyMatchFailure (
629
637
sliceOp, " unsupported: slice cannot be extracted as a contiguous "
630
638
" slice of the src of the collapse_shape" );
631
639
632
- groupExpandedSizes.push_back (rewriter.getIndexAttr (collapsedSizeValue));
640
+ groupExpandedSizes.push_back (
641
+ rewriter.getIndexAttr (currentCollapsedsize));
633
642
groupExpandedOffsets.push_back (rewriter.getIndexAttr (offsetInDim));
634
643
635
- collapsedOffsetValue /= expandedShapeSize;
644
+ currentCollapsedOffset /= expandedShapeSize;
636
645
}
637
646
638
647
// Now handle the leading dimensions where the slice size is equal to 1
639
648
// (k-1...0).
649
+ // The size for these dimensions must be 1 because of how we constructed
650
+ // the slice size of the expanded shape. We spread the original collapsed
651
+ // size over the expanded shape sizes until we reached dimension k where
652
+ // the remaining size was smaller than the expanded shape size, and spread
653
+ // the remaining size on it. So, now we are left with only 1s.
640
654
for (idx++; idx < reassocGroupSize; ++idx) {
641
655
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
642
- int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
656
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
643
657
groupExpandedSizes.push_back (rewriter.getIndexAttr (1 ));
644
658
groupExpandedOffsets.push_back (rewriter.getIndexAttr (offsetInDim));
645
- collapsedOffsetValue /= expandedShapeSize;
659
+ currentCollapsedOffset /= expandedShapeSize;
646
660
}
647
661
648
662
expandedSizes.append (groupExpandedSizes.rbegin (),
0 commit comments