@@ -1021,8 +1021,11 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If padding value is set, returns a tensor.pad Op for the source tensor,
+/// with the output shape matching the output of `packOp`. Otherwise, returns
+/// the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
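To make the new contract concrete, here is a minimal MLIR sketch (value names are hypothetical) of a pack whose outer dims are all 1 and whose padding value is set, together with the tensor.pad this helper builds for its source:

    %0 = tensor.pack %src padding_value(%pad : f32)
        inner_dims_pos = [0, 1] inner_tiles = [8, 2]
        into %dest : tensor<5x1xf32> -> tensor<1x1x8x2xf32>

    // The helper pads the 5x1 source up to one full 8x2 tile:
    %padded = tensor.pad %src low[0, 0] high[3, 1] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<5x1xf32> to tensor<8x2xf32>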
@@ -1038,26 +1041,48 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
 
-  SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
-  for (int64_t dim = 0; dim < inputRank; ++dim) {
-    int64_t size = inputType.getDimSize(dim);
-    if (!tileAndPosMapping.count(dim)) {
-      paddedShape.push_back(size);
+
+  // The sizes of dynamic tiles
+  SmallVector<Value> dynamicTileSizes;
+
+  // Collect dims for the padded shape.
+  SmallVector<int64_t> paddedShape;
+  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+    // 1. Non-tiled outer dims.
+    // These dims should be 1 and we simply preserve them.
+    if (!tileAndPosMapping.count(dimIdx)) {
+      int64_t inputDimSize = inputType.getDimSize(dimIdx);
+      assert(inputDimSize == 1 &&
+             "with all outer dims == 1, this non-tiled input dim should be 1!");
+      paddedShape.push_back(inputDimSize);
+      continue;
+    }
+
+    // 2. Tiled outer dims
+    // As all outer dims == 1, it is safe to use the tile size for the padded
+    // shape.
+    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+    // 2.1 Static tile sizes
+    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+    if (cstTileSize.has_value()) {
+      paddedShape.push_back(cstTileSize.value());
       continue;
     }
 
-    // The size is less than or equal to tileSize because outer dims are all 1s.
-    std::optional<int64_t> tileSize =
-        getConstantIntValue(tileAndPosMapping.lookup(dim));
-    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
-    paddedShape.push_back(tileSize.value());
+    // 2.2 Dynamic tile sizes
+    paddedShape.push_back(ShapedType::kDynamic);
+
+    // Get the value that holds the dynamic size.
+    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder);
+                                 /*nofold=*/false, loc, builder,
+                                 dynamicTileSizes);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
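With a dynamic inner tile, the new 2.2 branch records ShapedType::kDynamic in the padded shape and forwards the tile-size Value to tensor::createPadHighOp via the new dynamicTileSizes argument, so the dynamic dimension of the pad can be sized. A sketch (hypothetical values) of the input this enables:

    // One dynamic inner tile size (%tile); all outer dims are still 1.
    %0 = tensor.pack %src padding_value(%pad : f32)
        inner_dims_pos = [0, 1] inner_tiles = [%tile, 2]
        into %dest : tensor<5x1xf32> -> tensor<1x1x?x2xf32>

Here the padded source type becomes tensor<?x2xf32>, with %tile supplying the dynamic extent.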
@@ -1120,10 +1145,10 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(),
-                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
-    return rewriter.notifyMatchFailure(packOp,
-                                       "require inner tile sizes being static");
+  if (llvm::count_if(packOp.getMixedTiles(),
+                     [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+    return rewriter.notifyMatchFailure(
+        packOp, "at most one dynamic tile size is supported");
   }
 
   // TODO: support the case that outer dimensions are not all 1s. A
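The relaxed guard admits at most one dynamic tile size instead of requiring all tiles to be static. For reference, a sketch (hypothetical values) of what now matches versus what is still rejected:

    // Matches: a single dynamic inner tile size.
    %a = tensor.pack %src padding_value(%pad : f32)
        inner_dims_pos = [0, 1] inner_tiles = [%t, 2]
        into %d0 : tensor<5x1xf32> -> tensor<1x1x?x2xf32>

    // Rejected ("at most one dynamic tile size is supported").
    %b = tensor.pack %src padding_value(%pad : f32)
        inner_dims_pos = [0, 1] inner_tiles = [%t0, %t1]
        into %d1 : tensor<5x1xf32> -> tensor<1x1x?x?xf32>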
@@ -1147,12 +1172,15 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
+  SmallVector<OpFoldResult> transShapeForEmpty;
+  SmallVector<int64_t> readShapeForExtractSlice;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
-                              .value_or(ShapedType::kDynamic));
+      readShapeForExtractSlice.push_back(
+          getConstantIntValue(dimAndTileMapping[i])
+              .value_or(ShapedType::kDynamic));
       readSizes.push_back(dimAndTileMapping[i]);
+      transShapeForEmpty.push_back(dimAndTileMapping[i]);
       continue;
     }
     if (ShapedType::isDynamic(inputShape[i])) {
@@ -1161,12 +1189,14 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     } else {
       readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
     }
-    if (inputShape[i] != 1)
-      readShape.push_back(inputShape[i]);
+    if (inputShape[i] != 1) {
+      readShapeForExtractSlice.push_back(inputShape[i]);
+      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+    }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
       loc, readType, input, readOffsets, readSizes, readStrides);
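readShapeForExtractSlice stays a SmallVector<int64_t> because it only feeds the result type of the tensor.extract_slice, where a dynamic tile simply becomes a '?' (the actual size is already carried by readSizes). A sketch, continuing the dynamic-tile example above:

    %tile_slice = tensor.extract_slice %padded[0, 0] [%tile, 2] [1, 1]
        : tensor<?x2xf32> to tensor<?x2xf32>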
@@ -1178,10 +1208,10 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  SmallVector<int64_t> transpShape = readShape;
-  applyPermutationToVector<int64_t>(transpShape, perm);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  Value empty =
+      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
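By contrast, transShapeForEmpty keeps OpFoldResults so that, after being permuted, it can be handed to the tensor::EmptyOp builder that accepts mixed static/dynamic sizes and turns the dynamic ones into operands. Continuing the sketch:

    %init = tensor.empty(%tile) : tensor<?x2xf32>
    %transposed = linalg.transpose ins(%tile_slice : tensor<?x2xf32>)
        outs(%init : tensor<?x2xf32>) permutation = [0, 1]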