@@ -1254,64 +1254,98 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1254
1254
" require the tiled outer dimensions of the result are all 1s" );
1255
1255
}
1256
1256
1257
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
1257
+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1258
+ // %extracted_tile = tensor.extract_slice(%unpack_op_input)
1258
1259
Location loc = unpackOp.getLoc ();
1259
1260
Value source = unpackOp.getSource ();
1260
1261
DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
1261
1262
unpackOp.getDimAndTileMapping ();
1262
1263
Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
1263
1264
Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
1264
- SmallVector<OpFoldResult> readOffsets (srcRank, zeroIdxAttr);
1265
- SmallVector<OpFoldResult> readStrides (srcRank, oneIdxAttr);
1266
- SmallVector<OpFoldResult> readSizes;
1267
- SmallVector<int64_t > readShape;
1268
- SmallVector<Value> dynamicDims;
1265
+
1266
+ // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1267
+ // dims:
1268
+ // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1269
+ SmallVector<int64_t > readShapeForExtractSlice;
1270
+ // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1271
+ // outer-tiled-dims being all 1), this will be
1272
+ // [ outer-untiled-dims, tile-sizes ]
1273
+ SmallVector<OpFoldResult> extractSliceSizes;
1274
+ // The offset and strides attributes for ExtractSliceOp.
1275
+ SmallVector<OpFoldResult> extractSliceOffsets (srcRank, zeroIdxAttr);
1276
+ SmallVector<OpFoldResult> extractSliceStrides (srcRank, oneIdxAttr);
1277
+
1278
+ // Shape for EmptyOp that's used as the init value for TransposeOp below.
1279
+ // This should be:
1280
+ // [ outer-untiled-dims, tile-sizes ]
1281
+ // However, skip unit dims - TransposeOp (below) applies rank-reduced
1282
+ // permutation.
1283
+ SmallVector<OpFoldResult> shapeForEmptyOp;
1284
+
1269
1285
for (auto i : llvm::seq<unsigned >(0 , destRank)) {
1286
+ // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1287
+ //
1288
+ // As all outer tiled dims are 1, the corresponding
1289
+ // slice size to read will also be 1. As this will be a rank-reducing "extract
1290
+ // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1291
+ // update:
1292
+ // * the output shape for ExtractSliceOp, nor
1293
+ // * the shape for EmptyOp.
1270
1294
if (dimAndTileMapping.count (i)) {
1271
- readSizes .push_back (oneIdxAttr);
1295
+ extractSliceSizes .push_back (oneIdxAttr);
1272
1296
continue ;
1273
1297
}
1274
1298
1299
+ // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1300
+ // outer-untiled-dims
1275
1301
if (ShapedType::isDynamic (srcShape[i])) {
1276
- Value dynamicDim =
1302
+ OpFoldResult dynamicDim =
1277
1303
rewriter.create <tensor::DimOp>(loc, source, i).getResult ();
1278
- readSizes .push_back (dynamicDim);
1279
- dynamicDims .push_back (dynamicDim);
1304
+ extractSliceSizes .push_back (dynamicDim);
1305
+ shapeForEmptyOp .push_back (dynamicDim);
1280
1306
} else {
1281
- readSizes.push_back (rewriter.getIndexAttr (srcShape[i]));
1307
+ extractSliceSizes.push_back (rewriter.getIndexAttr (srcShape[i]));
1308
+ if (srcShape[i] != 1 )
1309
+ shapeForEmptyOp.push_back (rewriter.getIndexAttr (srcShape[i]));
1310
+ }
1311
+ // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1312
+ // into account rank-reducing)
1313
+ if (srcShape[i] != 1 ) {
1314
+ readShapeForExtractSlice.push_back (srcShape[i]);
1282
1315
}
1283
- if (srcShape[i] != 1 )
1284
- readShape.push_back (srcShape[i]);
1285
1316
}
1317
+ // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1318
+ // shape for EmptyOp.
1286
1319
auto mixedTiles = unpackOp.getMixedTiles ();
1287
- readSizes.append (mixedTiles.begin (), mixedTiles.end ());
1320
+ extractSliceSizes.append (mixedTiles.begin (), mixedTiles.end ());
1321
+ shapeForEmptyOp.append (mixedTiles.begin (), mixedTiles.end ());
1288
1322
1289
1323
// Explicitly create the type for extract_slice op because the inner tile
1290
1324
// size could be 1. We want to represent the whole inner tile in this case.
1291
1325
auto tileShape = srcShape.drop_front (destRank);
1292
1326
// Append the inner tile shape to the permuted and rank-reduced outer shape.
1293
- readShape .append (tileShape.begin (), tileShape.end ());
1327
+ readShapeForExtractSlice .append (tileShape.begin (), tileShape.end ());
1294
1328
Type elemType = unpackOp.getSourceType ().getElementType ();
1295
- auto readType = RankedTensorType::get (readShape , elemType);
1329
+ auto readType = RankedTensorType::get (readShapeForExtractSlice , elemType);
1296
1330
Value innerTile = rewriter.create <tensor::ExtractSliceOp>(
1297
- loc, readType, unpackOp.getSource (), readOffsets, readSizes, readStrides);
1331
+ loc, readType, unpackOp.getSource (), extractSliceOffsets,
1332
+ extractSliceSizes, extractSliceStrides);
1298
1333
1299
1334
// 2. Transpose the tile to match the outer corresponding tile order.
1300
1335
SmallVector<int64_t > perm = getPackUnpackRankReducedPerm (
1301
1336
srcShape.take_front (destRank), innerDimsPos, unpackOp.getOuterDimsPerm ());
1302
1337
// Unpack is a transition out of packed space so we invert the permutation.
1303
1338
perm = invertPermutationVector (perm);
1304
- SmallVector<int64_t > transpShape (readShape);
1305
- applyPermutationToVector<int64_t >(transpShape, perm);
1339
+ applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1306
1340
1307
1341
Value empty =
1308
- rewriter.create <tensor::EmptyOp>(loc, transpShape , elemType, dynamicDims );
1342
+ rewriter.create <tensor::EmptyOp>(loc, shapeForEmptyOp , elemType);
1309
1343
auto transposedOp =
1310
1344
rewriter.create <linalg::TransposeOp>(loc, innerTile, empty, perm);
1311
1345
1312
1346
// 3. Handle incomplete tiles if needed. It truncates trailing data from the
1313
1347
// transposed tile.
1314
- int numLoops = transpShape .size ();
1348
+ int numLoops = shapeForEmptyOp .size ();
1315
1349
SmallVector<OpFoldResult> tileStrides (numLoops, oneIdxAttr);
1316
1350
SmallVector<OpFoldResult> tileOffsets (numLoops, zeroIdxAttr);
1317
1351
SmallVector<OpFoldResult> tileSizes;
0 commit comments