@@ -1408,15 +1408,16 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1408
1408
/// dimension.
1409
1409
/// i.e. for the following tensor.pack:
1410
1410
/// ```mlir
1411
- // / %pack = tensor.pack %0 padding_value(%1)
1412
- // / outer_dims_perm = [0, 2, 1]
1413
- // / inner_dims_pos = [2, 1]
1414
- // / inner_tiles = [16, 2]
1411
+ // / %pack = tensor.pack %0 padding_value(%1)
1412
+ // / outer_dims_perm = [0, 2, 1]
1413
+ // / inner_dims_pos = [2, 1]
1414
+ // / inner_tiles = [16, 2]
1415
1415
/// into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
1416
1416
/// ```
1417
1417
/// The "packed" shape is `32x1x4x16x2`
1418
1418
/// The "tiled" shape is `32x(4x2)x(1x16)`
1419
- static SmallVector<int64_t > getTiledShapeToPackedShapePerm (tensor::PackOp packOp) {
1419
+ static SmallVector<int64_t >
1420
+ getTiledShapeToPackedShapePerm (tensor::PackOp packOp) {
1420
1421
auto innerTiles = packOp.getInnerTiles ();
1421
1422
int64_t srcRank = packOp.getSourceRank ();
1422
1423
auto innerDimsPos = packOp.getInnerDimsPos ();
@@ -1425,7 +1426,7 @@ static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp
1425
1426
auto outerDimsPerm = packOp.getOuterDimsPerm ();
1426
1427
if (outerDimsPerm.empty ())
1427
1428
outerDimsPerm = to_vector (llvm::seq<int64_t >(srcRank));
1428
- auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
1429
+ auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
1429
1430
int64_t srcIdx;
1430
1431
if (idx >= srcRank)
1431
1432
srcIdx = innerDimsPos[idx - srcRank];
@@ -1440,7 +1441,7 @@ static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp
1440
1441
return tiledIdx;
1441
1442
};
1442
1443
SmallVector<int64_t > perm;
1443
- for (int i = 0 ; i < packOp.getDestRank (); i++)
1444
+ for (int i = 0 ; i < packOp.getDestRank (); i++)
1444
1445
perm.push_back (packedIdxToTiledIdx (i));
1445
1446
return perm;
1446
1447
}
@@ -1453,11 +1454,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
1453
1454
return applyPermutation (destShape, invertPermutationVector (perm));
1454
1455
}
1455
1456
1456
- // /
1457
+ // /
1457
1458
static LogicalResult
1458
1459
vectorizeAsTensorPackOp (RewriterBase &rewriter, tensor::PackOp packOp,
1459
- ArrayRef<int64_t > inputVectorSizes,
1460
- SmallVectorImpl<Value> &newResults) {
1460
+ ArrayRef<int64_t > inputVectorSizes,
1461
+ SmallVectorImpl<Value> &newResults) {
1461
1462
OpBuilder::InsertionGuard g (rewriter);
1462
1463
rewriter.setInsertionPoint (packOp);
1463
1464
@@ -1496,10 +1497,13 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1496
1497
mlir::vector::maskOperation (rewriter, transferReadOp, mask));
1497
1498
// ShapeCast
1498
1499
auto tiledPackShape = getTiledPackShape (packOp);
1499
- auto tiledPackType = VectorType::get (tiledPackShape, packOp.getDestType ().getElementType ());
1500
- auto shapeCastOp = rewriter.create <vector::ShapeCastOp>(loc, tiledPackType, maskedOp->getResult (0 ));
1500
+ auto tiledPackType =
1501
+ VectorType::get (tiledPackShape, packOp.getDestType ().getElementType ());
1502
+ auto shapeCastOp = rewriter.create <vector::ShapeCastOp>(
1503
+ loc, tiledPackType, maskedOp->getResult (0 ));
1501
1504
auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm (packOp);
1502
- auto transposeOp = rewriter.create <vector::TransposeOp>(loc, shapeCastOp->getResult (0 ), tiledShapeToPackedShapePerm);
1505
+ auto transposeOp = rewriter.create <vector::TransposeOp>(
1506
+ loc, shapeCastOp->getResult (0 ), tiledShapeToPackedShapePerm);
1503
1507
Operation *write = rewriter.create <vector::TransferWriteOp>(
1504
1508
loc,
1505
1509
/* vector=*/ transposeOp->getResult (0 ),
@@ -1704,7 +1708,7 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
1704
1708
1705
1709
static LogicalResult
1706
1710
vectorizePackOpPrecondition (tensor::PackOp packOp,
1707
- ArrayRef<int64_t > inputVectorSizes) {
1711
+ ArrayRef<int64_t > inputVectorSizes) {
1708
1712
auto padValue = packOp.getPaddingValue ();
1709
1713
if (padValue && getConstantIntValue (padValue) != std::nullopt) {
1710
1714
LDBG (" pad value is not constant: " << packOp << " \n " );
@@ -1877,7 +1881,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
1877
1881
results);
1878
1882
})
1879
1883
.Case <tensor::PackOp>([&](auto packOp) {
1880
- return vectorizeAsTensorPackOp (rewriter, packOp, inputVectorSizes, results);
1884
+ return vectorizeAsTensorPackOp (rewriter, packOp, inputVectorSizes,
1885
+ results);
1881
1886
})
1882
1887
.Default ([](auto ) { return failure (); });
1883
1888
0 commit comments