@@ -1142,75 +1142,100 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1142
1142
tensor::PackOp packOp, PatternRewriter &rewriter) const {
1143
1143
// TODO: support the case that outer dimensions are not all 1s. A
1144
1144
// tensor.expand_shape will be generated in this case.
1145
- if (llvm::any_of (packOp.getTiledOuterDims (),
1145
+ if (llvm::any_of (packOp.getAllOuterDims (),
1146
1146
[](int64_t dim) { return dim != 1 ; })) {
1147
1147
return rewriter.notifyMatchFailure (
1148
- packOp, " require the tiled outer dimensions of the result are all 1s" );
1148
+ packOp, " not all outer dimensions of the result are 1s" );
1149
1149
}
1150
1150
1151
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
1152
- // outer dims.
1151
+ Attribute zeroIdxAttr = rewriter. getIndexAttr ( 0 );
1152
+ Attribute oneIdxAttr = rewriter. getIndexAttr ( 1 );
1153
1153
Location loc = packOp.getLoc ();
1154
+
1154
1155
Value input = getPackOpSourceOrPaddedSource (rewriter, packOp);
1155
1156
auto inputShape = packOp.getSourceType ().getShape ();
1156
1157
DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
1157
1158
packOp.getDimAndTileMapping ();
1158
- Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
1159
- Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
1160
1159
int64_t srcRank = packOp.getSourceRank ();
1160
+
1161
+ int64_t destRank = packOp.getDestRank ();
1162
+ size_t numTiles = destRank - srcRank;
1163
+
1164
+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1165
+ // %extracted_tile = tensor.extract_slice(%pack_op_input)
1161
1166
SmallVector<OpFoldResult> readOffsets (srcRank, zeroIdxAttr);
1162
1167
SmallVector<OpFoldResult> readStrides (srcRank, oneIdxAttr);
1163
- SmallVector<OpFoldResult> readSizes;
1164
- SmallVector<OpFoldResult> transShapeForEmpty;
1165
- SmallVector<int64_t > readShapeForExtractSlice;
1168
+
1169
+ // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
1170
+ // all outer dims are 1.
1171
+ SmallVector<OpFoldResult> extractSliceSizes (srcRank - numTiles, oneIdxAttr);
1172
+ // The shape of the output for ExtractSliceOp. All leading unit dims are
1173
+ // effectively rank-reduced, hence skipped.
1174
+ SmallVector<int64_t > outputShapeForExtractSlice;
1175
+
1176
+ // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
1177
+ // be equal to the inner tile sizes.
1166
1178
for (auto i : llvm::seq<unsigned >(0 , srcRank)) {
1167
1179
if (dimAndTileMapping.count (i)) {
1168
- readShapeForExtractSlice.push_back (
1169
- getConstantIntValue (dimAndTileMapping[i])
1170
- .value_or (ShapedType::kDynamic ));
1171
- readSizes.push_back (dimAndTileMapping[i]);
1172
- transShapeForEmpty.push_back (dimAndTileMapping[i]);
1173
- continue ;
1174
- }
1175
- if (ShapedType::isDynamic (inputShape[i])) {
1176
- readSizes.push_back (
1177
- rewriter.create <tensor::DimOp>(loc, input, i).getResult ());
1178
- } else {
1179
- readSizes.push_back (rewriter.getIndexAttr (inputShape[i]));
1180
- }
1181
- if (inputShape[i] != 1 ) {
1182
- readShapeForExtractSlice.push_back (inputShape[i]);
1183
- transShapeForEmpty.push_back (rewriter.getIndexAttr (inputShape[i]));
1180
+ auto [tileSize, tileSizeOfr] =
1181
+ getSimplifiedOfrAndStaticSizePair (dimAndTileMapping[i], rewriter);
1182
+ extractSliceSizes.push_back (tileSizeOfr);
1183
+ outputShapeForExtractSlice.push_back (tileSize);
1184
1184
}
1185
1185
}
1186
1186
1187
1187
Type elemType = packOp.getSourceType ().getElementType ();
1188
- auto readType = RankedTensorType::get (readShapeForExtractSlice , elemType);
1188
+ auto readType = RankedTensorType::get (outputShapeForExtractSlice , elemType);
1189
1189
1190
1190
Value tile = rewriter.create <tensor::ExtractSliceOp>(
1191
- loc, readType, input, readOffsets, readSizes , readStrides);
1191
+ loc, readType, input, readOffsets, extractSliceSizes , readStrides);
1192
1192
1193
- // 2. Transpose the tile to match the inner tile order.
1193
+ // 2. Transpose the tile to match the inner tile order:
1194
+ // %init = tensor.empty()
1195
+ // %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
1196
+ // NOTE: Outer dims are 1 and hence effectively ignored.
1194
1197
SmallVector<int64_t > perm = getPackUnpackRankReducedPerm (
1195
1198
inputShape, packOp.getInnerDimsPos (), packOp.getOuterDimsPerm ());
1196
1199
1197
1200
LLVM_DEBUG (DBGS () << " Pack permutation: " << packOp << " \n " ;
1198
1201
llvm::interleaveComma (perm, DBGS () << " perm: " ); DBGSNL (););
1199
1202
1200
- applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
1203
+ // 2.1 Create tensor.empty (init value for TransposeOp)
1204
+ SmallVector<OpFoldResult> transShapeForEmptyOp;
1201
1205
1206
+ // Acquire tensor shape required to create EmptyOp. This will match the inner
1207
+ // tile sizes.
1208
+ size_t idx = numTiles;
1209
+ while (idx != 0 ) {
1210
+ transShapeForEmptyOp.push_back (extractSliceSizes[srcRank - idx]);
1211
+ idx--;
1212
+ }
1213
+
1214
+ applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
1202
1215
Value empty =
1203
- rewriter.create <tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
1216
+ rewriter.create <tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
1217
+
1218
+ // 2.2 Create linalg.transpose
1204
1219
auto transposedOp =
1205
1220
rewriter.create <linalg::TransposeOp>(loc, tile, empty, perm);
1206
1221
1207
- // 3. Insert the inner tile to the destination.
1208
- int64_t destRank = packOp. getDestRank ();
1222
+ // 3. Insert the inner tile to the destination:
1223
+ // %inserted_tile = tensor.insert_slice(%transposed_tile)
1209
1224
SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
1210
1225
SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
1211
- SmallVector<OpFoldResult> writeSizes =
1212
- tensor::getMixedSizes (rewriter, loc, packOp.getDest ());
1226
+ // Outer dims are all 1s!
1227
+ SmallVector<OpFoldResult> writeSizes (destRank - dimAndTileMapping.size (),
1228
+ oneIdxAttr);
1229
+ SmallVector<int64_t > writeShape;
1230
+
1231
+ for (auto tileSize : packOp.getMixedTiles ()) {
1232
+ auto [tileSizeStatic, tileSizeOfr] =
1233
+ getSimplifiedOfrAndStaticSizePair (tileSize, rewriter);
1234
+ writeSizes.push_back (tileSizeOfr);
1235
+ writeShape.push_back (tileSizeStatic);
1236
+ }
1213
1237
1238
+ // 4. Replace tensor.packOp with tensor.insert_slice created above
1214
1239
auto insert = rewriter.create <tensor::InsertSliceOp>(
1215
1240
loc, transposedOp.getResult ()[0 ], packOp.getDest (), writeOffsets,
1216
1241
writeSizes, writeStrides);
0 commit comments