@@ -731,9 +731,6 @@ LogicalResult SubTensorOfPadTensorSwapPattern::matchAndRewrite(
731
731
Value padValue = padOp.getConstantPaddingValue ();
732
732
if (!padValue)
733
733
return failure ();
734
- // Only zero low padding supported at the moment.
735
- if (!padOp.hasZeroLowPad ())
736
- return failure ();
737
734
738
735
// Helper variables and functions for various arithmetic operations. These are
739
736
// used extensively for computing new offset/length and padding values.
@@ -788,33 +785,53 @@ LogicalResult SubTensorOfPadTensorSwapPattern::matchAndRewrite(
788
785
789
786
int64_t rank = padOp.getSourceType ().getRank ();
790
787
for (unsigned dim = 0 ; dim < rank; ++dim) {
788
+ auto low = asValue (rewriter, loc, padOp.getMixedLowPad ()[dim]);
791
789
auto offset = asValue (rewriter, loc, subTensorOp.getMixedOffsets ()[dim]);
792
790
auto length = asValue (rewriter, loc, subTensorOp.getMixedSizes ()[dim]);
793
791
auto srcSize = rewriter.createOrFold <memref::DimOp>(
794
792
loc, padOp.source (), dim);
795
793
796
- // Existing low padding is zero, so new low padding is also zero.
797
- Value newLow = zero;
794
+ // The new amount of low padding is `low - offset`, except for the case
795
+ // where none of the low padding is read. In that case, the new amount of
796
+ // low padding is zero.
797
+ Value newLow = max (zero, sub (low, offset));
798
798
appendIndex (newLow, newLows, staticNewLows);
799
799
800
- // There is no low padding, so the offset remains unchanged. Except for the
801
- // case where the SubTensorOp starts reading from a position within the high
802
- // padding. In that case, set the offset to the end of source tensor. The
803
- // new SubTensorOp length will be zero in that case. (Effectively reading no
800
+ // Start reading the data from position `offset - low`. Since the original
801
+ // read may have started in the low padding zone, this value could be
802
+ // negative. Therefore, start reading from:
803
+ //
804
+ // max(offset - low, 0)
805
+ //
806
+ // The original read could also have started in the high padding zone.
807
+ // In that case, set the offset to the end of source tensor. The new
808
+ // SubTensorOp length will be zero in that case. (Effectively reading no
804
809
// data from the source.)
805
- Value newOffset = min (offset, srcSize);
810
+ Value newOffset = min (max ( sub ( offset, low), zero) , srcSize);
806
811
newOffsets.push_back (asOpFoldResult (rewriter, newOffset));
807
812
808
- // The new SubTensorOp starts reading at `newOffset` and reads until
809
- // `offset + length`. This position may be outside of the source (i.e.,
810
- // within the high padding). In that case, read only until the end of the
811
- // source. In mathematical terms:
813
+ // The original SubTensorOp was reading until position `offset + length`.
814
+ // Therefore, the corresponding position within the source tensor is:
815
+ //
816
+ // offset + length - low
812
817
//
813
- // endLoc = min(offset + length, srcSize)
818
+ // In case the original SubTensorOp stopped reading within the low padding
819
+ // zone, this value can be negative. In that case, the end position of the
820
+ // read should be zero. (Similar to newOffset.)
821
+ //
822
+ // The original read could also have stopped in the high padding zone.
823
+ // In that case, the end position of the read should be the end of the
824
+ // source tensor. (Similar to newOffset.)
825
+ //
826
+ // endLoc = min(max(offset - low + length, 0), srcSize)
814
827
//
815
828
// The new SubTensorOp length is `endLoc - newOffset`.
816
- Value newLength = sub (min (add (offset, length), srcSize), newOffset);
829
+ Value endLoc = min (max (add (sub (offset, low), length), zero), srcSize);
830
+ Value newLength = sub (endLoc, newOffset);
817
831
newLengths.push_back (asOpFoldResult (rewriter, newLength));
832
+
833
+ // Check if newLength is zero. In that case, no SubTensorOp should be
834
+ // executed.
818
835
if (auto newLengthInt = getConstantIntValue (newLength)) {
819
836
hasZeroLen |= *newLengthInt == 0 ;
820
837
} else {
@@ -824,13 +841,9 @@ LogicalResult SubTensorOfPadTensorSwapPattern::matchAndRewrite(
824
841
? rewriter.create <AndOp>(loc, check, dynHasZeroLenCond) : check;
825
842
}
826
843
827
- // The number of elements available to read from the source (starting from
828
- // the new offset) is `maxRead = srcSize - newOffset`. The original
829
- // SubTensorOp may have read a larger number of elements `length > maxRead`.
830
- // In that case, the missing number of elements `length - maxRead` must be
831
- // paddded. (If `maxRead > length`, more than enough data is available to
832
- // read and no high padding is needed.)
833
- Value newHigh = max (zero, add (sub (newOffset, srcSize), length));
844
+ // The amount of high padding is simply the number of elements remaining,
845
+ // so that the result has the same length as the original SubTensorOp.
846
+ Value newHigh = sub (sub (length, newLength), newLow);
834
847
appendIndex (newHigh, newHighs, staticNewHighs);
835
848
836
849
// Only unit stride supported.
0 commit comments