Skip to content

Commit 5845db6

Browse files
committed
Additional clarifications
1 parent 1aaf3c9 commit 5845db6

File tree

1 file changed

+29
-15
lines changed

1 file changed

+29
-15
lines changed

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice
550550
enumerate(collapseShapeOp.getReassociationIndices())) {
551551
OpFoldResult collapsedSize = collapsedSizes[groupIdx];
552552
OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
553-
// Case #1 - size and/or offset are dynamic.
553+
// CASE #1 - size and/or offset are dynamic.
554554
// In this case, the slice can be represented as a contiguous slice only
555555
// if there is a single dimension in the reassociation group that has a
556556
// size not equal to 1.
@@ -576,16 +576,24 @@ struct BubbleUpCollapseShapeThroughExtractSlice
576576
continue;
577577
}
578578

579-
// Case #2 = size and offset are static.
579+
// CASE #2 = size and offset are static.
580580
// Verify that the slice can be represented as a contiguous slice of the
581581
// src of the collapse_shape.
582582
// Checking this is done on order of most internal dimensions first,
583583
// so traversal is done in reverse order of the reassociation group.
584584
// If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
585585
// ...,An] then we first find the size and offset for n...k+1 then for k
586586
// and then for k-1...0.
587-
int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
588-
int64_t collapsedOffsetValue =
587+
588+
// currentCollapsedsize and currentCollapsedOffset are initialized with
589+
// the original collapsed size and offset and divided by the expanded
590+
// shape size in each dimension as we go along the reassociation group.
591+
// In essence we are spreading the original collapsed size and offset over
592+
// the various expanded slice dimensions.
593+
// The variables are used both to check the validity of the slice and to
594+
// compute the expanded sizes and offsets.
595+
int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
596+
int64_t currentCollapsedOffset =
589597
getConstantIntValue(collapsedOffset).value();
590598

591599
SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
@@ -600,49 +608,55 @@ struct BubbleUpCollapseShapeThroughExtractSlice
600608
for (; idx < reassocGroupSize; ++idx) {
601609
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
602610

603-
if (collapsedSizeValue < expandedShapeSize)
611+
if (currentCollapsedsize < expandedShapeSize)
604612
break;
605613

606614
// We need to make sure that the slice size can be set to the shape size
607615
// and the offset to 0.
608-
if ((collapsedSizeValue % expandedShapeSize) != 0 ||
609-
(collapsedOffsetValue % expandedShapeSize) != 0)
616+
if ((currentCollapsedsize % expandedShapeSize) != 0 ||
617+
(currentCollapsedOffset % expandedShapeSize) != 0)
610618
return rewriter.notifyMatchFailure(
611619
sliceOp, "unsupported: cannot be extracted as a contiguous slice "
612620
"of the src of the collapse_shape");
613621

614622
groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
615623
groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
616624

617-
collapsedSizeValue /= expandedShapeSize;
618-
collapsedOffsetValue /= expandedShapeSize;
625+
currentCollapsedsize /= expandedShapeSize;
626+
currentCollapsedOffset /= expandedShapeSize;
619627
}
620628

621629
// Now handle the first dim where slicing occurs on (k).
622630
if (idx < reassocGroupSize) {
623631
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
624-
int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
632+
int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
625633
// We need to make sure that the slice size in this dim + offset will
626634
// not exceed the shape size.
627-
if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
635+
if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize)
628636
return rewriter.notifyMatchFailure(
629637
sliceOp, "unsupported: slice cannot be extracted as a contiguous "
630638
"slice of the src of the collapse_shape");
631639

632-
groupExpandedSizes.push_back(rewriter.getIndexAttr(collapsedSizeValue));
640+
groupExpandedSizes.push_back(
641+
rewriter.getIndexAttr(currentCollapsedsize));
633642
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
634643

635-
collapsedOffsetValue /= expandedShapeSize;
644+
currentCollapsedOffset /= expandedShapeSize;
636645
}
637646

638647
// Now handle the leading dimensions where the slice size is equal to 1
639648
// (k-1...0).
649+
// The size for these dimensions must be 1 because of how we constructed
650+
// the slice size of the expanded shape. We spread the original collapsed
651+
// size over the expanded shape sizes until we reached dimension k where
652+
// the remaining size was smaller than the expanded shape size, and spread
653+
// the remaining size on it. So, now we are left with only 1s.
640654
for (idx++; idx < reassocGroupSize; ++idx) {
641655
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
642-
int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
656+
int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
643657
groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
644658
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
645-
collapsedOffsetValue /= expandedShapeSize;
659+
currentCollapsedOffset /= expandedShapeSize;
646660
}
647661

648662
expandedSizes.append(groupExpandedSizes.rbegin(),

0 commit comments

Comments
 (0)