
Commit 225b960

[mlir][linalg] Support low padding in subtensor(pad_tensor) lowering
Differential Revision: https://reviews.llvm.org/D104591
Parent: 808ac8d
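
In short, for each dimension the pattern now derives the rewritten subtensor and pad parameters from the low padding amount `low`, the SubTensorOp's `offset` and `length`, and the source size `srcSize` (this restates the arithmetic in the diff below):

    newLow    = max(low - offset, 0)
    newOffset = min(max(offset - low, 0), srcSize)
    endLoc    = min(max(offset - low + length, 0), srcSize)
    newLength = endLoc - newOffset
    newHigh   = length - newLength - newLow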

2 files changed: 111 insertions(+), 23 deletions(-)

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 36 additions & 23 deletions
@@ -731,9 +731,6 @@ LogicalResult SubTensorOfPadTensorSwapPattern::matchAndRewrite(
   Value padValue = padOp.getConstantPaddingValue();
   if (!padValue)
     return failure();
-  // Only zero low padding supported at the moment.
-  if (!padOp.hasZeroLowPad())
-    return failure();
 
   // Helper variables and functions for various arithmetic operations. These are
   // used extensively for computing new offset/length and padding values.
@@ -788,33 +785,53 @@ LogicalResult SubTensorOfPadTensorSwapPattern::matchAndRewrite(
 
   int64_t rank = padOp.getSourceType().getRank();
   for (unsigned dim = 0; dim < rank; ++dim) {
+    auto low = asValue(rewriter, loc, padOp.getMixedLowPad()[dim]);
     auto offset = asValue(rewriter, loc, subTensorOp.getMixedOffsets()[dim]);
     auto length = asValue(rewriter, loc, subTensorOp.getMixedSizes()[dim]);
     auto srcSize = rewriter.createOrFold<memref::DimOp>(
         loc, padOp.source(), dim);
 
-    // Existing low padding is zero, so new low padding is also zero.
-    Value newLow = zero;
+    // The new amount of low padding is `low - offset`, except when none of
+    // the low padding is read; in that case, the new amount of low padding
+    // is zero.
+    Value newLow = max(zero, sub(low, offset));
     appendIndex(newLow, newLows, staticNewLows);
 
-    // There is no low padding, so the offset remains unchanged. Except for the
-    // case where the SubTensorOp starts reading from a position within the high
-    // padding. In that case, set the offset to the end of source tensor. The
-    // new SubTensorOp length will be zero in that case. (Effectively reading no
+    // Start reading the data from position `offset - low`. Since the original
+    // read may have started in the low padding zone, this value could be
+    // negative. Therefore, start reading from:
+    //
+    //   max(offset - low, 0)
+    //
+    // The original read could also have started in the high padding zone.
+    // In that case, set the offset to the end of the source tensor. The new
+    // SubTensorOp length will be zero in that case. (Effectively reading no
     // data from the source.)
-    Value newOffset = min(offset, srcSize);
+    Value newOffset = min(max(sub(offset, low), zero), srcSize);
     newOffsets.push_back(asOpFoldResult(rewriter, newOffset));
 
-    // The new SubTensorOp starts reading at `newOffset` and reads until
-    // `offset + length`. This position may be outside of the source (i.e.,
-    // within the high padding). In that case, read only until the end of the
-    // source. In mathematical terms:
+    // The original SubTensorOp was reading until position `offset + length`.
+    // Therefore, the corresponding position within the source tensor is:
+    //
+    //   offset + length - low
     //
-    //   endLoc = min(offset + length, srcSize)
+    // In case the original SubTensorOp stopped reading within the low padding
+    // zone, this value can be negative. In that case, the end position of the
+    // read should be zero. (Similar to newOffset.)
+    //
+    // The original read could also have stopped in the high padding zone. In
+    // that case, the end position of the read should be the end of the source
+    // tensor. (Similar to newOffset.)
+    //
+    //   endLoc = min(max(offset - low + length, 0), srcSize)
     //
     // The new SubTensorOp length is `endLoc - newOffset`.
-    Value newLength = sub(min(add(offset, length), srcSize), newOffset);
+    Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
+    Value newLength = sub(endLoc, newOffset);
     newLengths.push_back(asOpFoldResult(rewriter, newLength));
+
+    // Check if newLength is zero. In that case, no SubTensorOp should be
+    // executed.
     if (auto newLengthInt = getConstantIntValue(newLength)) {
       hasZeroLen |= *newLengthInt == 0;
     } else {
@@ -824,13 +841,9 @@ LogicalResult SubTensorOfPadTensorSwapPattern::matchAndRewrite(
           ? rewriter.create<AndOp>(loc, check, dynHasZeroLenCond) : check;
     }
 
-    // The number of elements available to read from the source (starting from
-    // the new offset) is `maxRead = srcSize - newOffset`. The original
-    // SubTensorOp may have read a larger number of elements `length > maxRead`.
-    // In that case, the missing number of elements `length - maxRead` must be
-    // padded. (If `maxRead > length`, more than enough data is available to
-    // read and no high padding is needed.)
-    Value newHigh = max(zero, add(sub(newOffset, srcSize), length));
+    // The amount of high padding is simply the number of elements remaining,
+    // so that the result has the same length as the original SubTensorOp.
+    Value newHigh = sub(sub(length, newLength), newLow);
     appendIndex(newHigh, newHighs, staticNewHighs);
 
     // Only unit stride supported.
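
As a sanity check, the arithmetic above can be replayed with plain integers against the static test cases below. The following standalone C++ sketch is illustrative only: it is not part of the patch, and `computeNewParams` is a hypothetical stand-in that mirrors the pattern's `min`/`max`/`add`/`sub` helpers with ordinary `int64_t` math instead of MLIR Values.

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // Per-dimension result of the rewrite: new low pad, new subtensor offset,
    // new subtensor length, and new high pad.
    struct NewParams {
      int64_t low, offset, length, high;
    };

    // Hypothetical stand-in for the per-dimension arithmetic in the pattern.
    NewParams computeNewParams(int64_t low, int64_t offset, int64_t length,
                               int64_t srcSize) {
      // Portion of the low padding zone covered by the read.
      int64_t newLow = std::max<int64_t>(low - offset, 0);
      // Start of the read within the source, clamped to [0, srcSize].
      int64_t newOffset = std::min(std::max<int64_t>(offset - low, 0), srcSize);
      // End of the read within the source, clamped the same way.
      int64_t endLoc =
          std::min(std::max<int64_t>(offset - low + length, 0), srcSize);
      int64_t newLength = endLoc - newOffset;
      // High padding fills whatever remains of the requested length.
      int64_t newHigh = length - newLength - newLow;
      return {newLow, newOffset, newLength, newHigh};
    }

    int main() {
      // @static_mixed_data_low_pad, dim 0: low=3, offset=2, length=3, srcSize=4.
      NewParams d0 = computeNewParams(3, 2, 3, 4);
      assert(d0.low == 1 && d0.offset == 0 && d0.length == 2 && d0.high == 0);
      // Dim 1: low=7, offset=4, length=4, srcSize=5.
      NewParams d1 = computeNewParams(7, 4, 4, 5);
      assert(d1.low == 3 && d1.offset == 0 && d1.length == 1 && d1.high == 0);
      // Together: subtensor [0, 0] [2, 1] + pad low[1, 3] high[0, 0], matching
      // the CHECK lines of that test case.

      // @static_low_pad_only, dim 0: low=3, offset=1, length=2, srcSize=4.
      // newLength == 0, so the pattern emits tensor.generate instead.
      NewParams z = computeNewParams(3, 1, 2, 4);
      assert(z.length == 0);
      return 0;
    }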

mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir

Lines changed: 75 additions & 0 deletions
@@ -35,6 +35,44 @@ func @static_high_pad_only(%arg0 : tensor<4x5xf32>, %pad : f32)
 
 // -----
 
+// CHECK-LABEL: @static_low_pad_only
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT:   linalg.pad_tensor
+// CHECK-NOT:   subtensor
+// CHECK:       %[[RESULT:.*]] = tensor.generate
+// CHECK:       tensor.yield %[[PAD]]
+// CHECK:       return %[[RESULT]] : tensor<2x3xf32>
+func @static_low_pad_only(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<2x3xf32> {
+  %0 = linalg.pad_tensor %arg0 low[3, 7] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+  } : tensor<4x5xf32> to tensor<14x20xf32>
+  %1 = subtensor %0[1, 3] [2, 3] [1, 1] : tensor<14x20xf32> to tensor<2x3xf32>
+  return %1 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @static_low_pad_only_2
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT:   linalg.pad_tensor
+// CHECK-NOT:   subtensor
+// CHECK:       %[[RESULT:.*]] = tensor.generate
+// CHECK:       tensor.yield %[[PAD]]
+// CHECK:       return %[[RESULT]] : tensor<1x3xf32>
+func @static_low_pad_only_2(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<1x3xf32> {
+  %0 = linalg.pad_tensor %arg0 low[3, 7] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+  } : tensor<4x5xf32> to tensor<14x20xf32>
+  %1 = subtensor %0[1, 3] [1, 3] [1, 1] : tensor<14x20xf32> to tensor<1x3xf32>
+  return %1 : tensor<1x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @static_mixed_data_high_pad
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
 // CHECK-NOT:   linalg.pad_tensor
@@ -54,6 +92,43 @@ func @static_mixed_data_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
 
 // -----
 
+// CHECK-LABEL: @static_mixed_data_low_pad
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT:   linalg.pad_tensor
+// CHECK:       %[[SUBTENSOR:.*]] = subtensor %[[ARG0]][0, 0] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
+// CHECK:       %[[RESULT:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[1, 3] high[0, 0]
+// CHECK:       linalg.yield %[[PAD]]
+// CHECK:       return %[[RESULT]] : tensor<3x4xf32>
+func @static_mixed_data_low_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<3x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[3, 7] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+  } : tensor<4x5xf32> to tensor<14x20xf32>
+  %1 = subtensor %0[2, 4] [3, 4] [1, 1] : tensor<14x20xf32> to tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @static_mixed_data_low_high_pad
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT:   linalg.pad_tensor
+// CHECK:       %[[RESULT:.*]] = linalg.pad_tensor %[[ARG0]] low[1, 1] high[2, 3]
+// CHECK:       linalg.yield %[[PAD]]
+// CHECK:       return %[[RESULT]] : tensor<7x9xf32>
+func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<7x9xf32> {
+  %0 = linalg.pad_tensor %arg0 low[2, 3] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+  } : tensor<4x5xf32> to tensor<13x16xf32>
+  %1 = subtensor %0[1, 2] [7, 9] [1, 1] : tensor<13x16xf32> to tensor<7x9xf32>
+  return %1 : tensor<7x9xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @dynamic_high_pad
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<?x5xf32>
 // CHECK-NOT:   linalg.pad_tensor
