Skip to content

Commit e08bb0c

Browse files
committed
clang
1 parent 06f86da commit e08bb0c

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,15 +1408,16 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14081408
/// dimension.
14091409
/// i.e. for the following tensor.pack:
14101410
/// ```mlir
1411-
/// %pack = tensor.pack %0 padding_value(%1)
1412-
/// outer_dims_perm = [0, 2, 1]
1413-
/// inner_dims_pos = [2, 1]
1414-
/// inner_tiles = [16, 2]
1411+
/// %pack = tensor.pack %0 padding_value(%1)
1412+
/// outer_dims_perm = [0, 2, 1]
1413+
/// inner_dims_pos = [2, 1]
1414+
/// inner_tiles = [16, 2]
14151415
/// into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
14161416
/// ```
14171417
/// The "packed" shape is `32x1x4x16x2`
14181418
/// The "tiled" shape is `32x(4x2)x(1x16)`
1419-
static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
1419+
static SmallVector<int64_t>
1420+
getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
14201421
auto innerTiles = packOp.getInnerTiles();
14211422
int64_t srcRank = packOp.getSourceRank();
14221423
auto innerDimsPos = packOp.getInnerDimsPos();
@@ -1425,7 +1426,7 @@ static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp
14251426
auto outerDimsPerm = packOp.getOuterDimsPerm();
14261427
if (outerDimsPerm.empty())
14271428
outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
1428-
auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
1429+
auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
14291430
int64_t srcIdx;
14301431
if (idx >= srcRank)
14311432
srcIdx = innerDimsPos[idx - srcRank];
@@ -1440,7 +1441,7 @@ static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp
14401441
return tiledIdx;
14411442
};
14421443
SmallVector<int64_t> perm;
1443-
for (int i = 0; i < packOp.getDestRank(); i++)
1444+
for (int i = 0; i < packOp.getDestRank(); i++)
14441445
perm.push_back(packedIdxToTiledIdx(i));
14451446
return perm;
14461447
}
@@ -1453,11 +1454,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
14531454
return applyPermutation(destShape, invertPermutationVector(perm));
14541455
}
14551456

1456-
///
1457+
///
14571458
static LogicalResult
14581459
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1459-
ArrayRef<int64_t> inputVectorSizes,
1460-
SmallVectorImpl<Value> &newResults) {
1460+
ArrayRef<int64_t> inputVectorSizes,
1461+
SmallVectorImpl<Value> &newResults) {
14611462
OpBuilder::InsertionGuard g(rewriter);
14621463
rewriter.setInsertionPoint(packOp);
14631464

@@ -1496,10 +1497,13 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
14961497
mlir::vector::maskOperation(rewriter, transferReadOp, mask));
14971498
// ShapeCast
14981499
auto tiledPackShape = getTiledPackShape(packOp);
1499-
auto tiledPackType = VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
1500-
auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedOp->getResult(0));
1500+
auto tiledPackType =
1501+
VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
1502+
auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1503+
loc, tiledPackType, maskedOp->getResult(0));
15011504
auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
1502-
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
1505+
auto transposeOp = rewriter.create<vector::TransposeOp>(
1506+
loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
15031507
Operation *write = rewriter.create<vector::TransferWriteOp>(
15041508
loc,
15051509
/*vector=*/transposeOp->getResult(0),
@@ -1704,7 +1708,7 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
17041708

17051709
static LogicalResult
17061710
vectorizePackOpPrecondition(tensor::PackOp packOp,
1707-
ArrayRef<int64_t> inputVectorSizes) {
1711+
ArrayRef<int64_t> inputVectorSizes) {
17081712
auto padValue = packOp.getPaddingValue();
17091713
if (padValue && getConstantIntValue(padValue) != std::nullopt) {
17101714
LDBG("pad value is not constant: " << packOp << "\n");
@@ -1877,7 +1881,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
18771881
results);
18781882
})
18791883
.Case<tensor::PackOp>([&](auto packOp) {
1880-
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, results);
1884+
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
1885+
results);
18811886
})
18821887
.Default([](auto) { return failure(); });
18831888

0 commit comments

Comments
 (0)