#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+ #include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -1400,83 +1401,47 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
  return success();
}

- /// Given a tensor::PackOp, return the permutation from the "tiled"
- /// shape to the "packed" shape, defined as the following:
- /// The "packed" shape is the same as the `dest` shape of the pack op.
- /// The "tiled" shape is a permutation of the `dest` shape such that
- /// each outer dimension is in the original `source` order, and the
- /// inner_tile dimensions immediately follow their corresponding outer
- /// dimension.
- /// i.e. for the following tensor.pack:
- /// ```mlir
- /// %pack = tensor.pack %0 padding_value(%1)
- ///   outer_dims_perm = [0, 2, 1]
- ///   inner_dims_pos = [2, 1]
- ///   inner_tiles = [16, 2]
- ///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
- /// ```
- /// The "packed" shape is `32x1x4x16x2`
- /// The "tiled" shape is `32x(4x2)x(1x16)`
- static SmallVector<int64_t>
- getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
-   auto innerTiles = packOp.getInnerTiles();
-   int64_t srcRank = packOp.getSourceRank();
-   auto innerDimsPos = packOp.getInnerDimsPos();
-   if (innerDimsPos.empty())
-     innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
-   auto outerDimsPerm = packOp.getOuterDimsPerm();
-   if (outerDimsPerm.empty())
-     outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
-   auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
-     int64_t srcIdx;
-     if (idx >= srcRank)
-       srcIdx = innerDimsPos[idx - srcRank];
-     else
-       srcIdx = outerDimsPerm[idx];
-     int64_t tiledIdx = srcIdx;
-     for (int64_t pos : innerDimsPos)
-       if (pos < srcIdx)
-         tiledIdx++;
-     if (idx >= srcRank)
-       tiledIdx++;
-     return tiledIdx;
-   };
-   SmallVector<int64_t> perm;
-   for (size_t i = 0; i < packOp.getDestRank(); i++)
-     perm.push_back(packedIdxToTiledIdx(i));
-   return perm;
- }
-
- /// Given a tensor::PackOp, return the "tiled" `dest` shape as described
- /// above in `getTiledShapeToPackedShapePerm`.
+ /// Given a tensor::PackOp, return the `dest` shape before any packing
+ /// permutations.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
                                              ArrayRef<int64_t> destShape) {
-   auto perm = getTiledShapeToPackedShapePerm(packOp);
-   return applyPermutation(destShape, invertPermutationVector(perm));
+   return applyPermutation(destShape,
+                           tensor::getPackInverseDestPermutation(packOp));
}

- /// Create a masked TransferReadOp from `source` with shape `readShape`.
- static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
-                                                Value source,
-                                                ArrayRef<int64_t> readShape,
-                                                Value padValue) {
+ /// Create a TransferReadOp from `source` with static shape `readShape`. If the
+ /// vector type for the read is not the same as the type of `source`, then a
+ /// mask is created on the read.
+ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                     Value source, ArrayRef<int64_t> readShape,
+                                     Value padValue) {
+   assert(llvm::none_of(readShape,
+                        [](int64_t s) { return s == ShapedType::kDynamic; }));
  auto maskType = VectorType::get(readShape, builder.getI1Type());
  auto vectorType = VectorType::get(readShape, padValue.getType());
-   SmallVector<OpFoldResult> mixedSourceDims =
-       tensor::getMixedSizes(builder, loc, source);
-   Value mask =
-       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  int64_t readRank = readShape.size();
+   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  auto transferReadOp = builder.create<vector::TransferReadOp>(
      loc,
      /*vectorType=*/vectorType,
      /*source=*/source,
      /*indices=*/SmallVector<Value>(readRank, zero),
      /*padding=*/padValue,
      /*inBounds=*/SmallVector<bool>(readRank, true));
-   return cast<vector::MaskOp>(
-       mlir::vector::maskOperation(builder, transferReadOp, mask));
+   auto sourceShape = llvm::dyn_cast<ShapedType>(source.getType()).getShape();
+   if (sourceShape.size() == readShape.size() &&
+       llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
+         return std::get<0>(it) != ShapedType::kDynamic &&
+                std::get<0>(it) == std::get<1>(it);
+       })) {
+     return transferReadOp;
+   }
+   SmallVector<OpFoldResult> mixedSourceDims =
+       tensor::getMixedSizes(builder, loc, source);
+   Value mask =
+       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+   return mlir::vector::maskOperation(builder, transferReadOp, mask)
+       ->getResult(0);
}

/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
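For reference, the permutation that the removed helper computed, and that `tensor::getPackInverseDestPermutation` now supplies in inverted form, can be reproduced with a small standalone sketch. This only mirrors the removed `getTiledShapeToPackedShapePerm` logic on the example from its doc comment; the driver below is hypothetical and uses only the standard library, not MLIR.

```cpp
// Standalone illustration (not MLIR code): mirrors the removed
// getTiledShapeToPackedShapePerm logic on the doc comment's example:
//   source 32x8x16, outer_dims_perm = [0, 2, 1], inner_dims_pos = [2, 1],
//   inner_tiles = [16, 2], "packed" dest shape 32x1x4x16x2,
//   "tiled" shape 32x(4x2)x(1x16) = 32x4x2x1x16.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t srcRank = 3;
  const std::vector<int64_t> innerDimsPos = {2, 1};
  const std::vector<int64_t> outerDimsPerm = {0, 2, 1};
  const std::vector<int64_t> tiledShape = {32, 4, 2, 1, 16};

  // For each packed dimension, find the corresponding tiled dimension.
  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
    int64_t srcIdx =
        idx >= srcRank ? innerDimsPos[idx - srcRank] : outerDimsPerm[idx];
    int64_t tiledIdx = srcIdx;
    for (int64_t pos : innerDimsPos)
      if (pos < srcIdx)
        ++tiledIdx; // Account for tile dims inserted before this dim.
    if (idx >= srcRank)
      ++tiledIdx; // A tile dim sits right after its outer dim.
    return tiledIdx;
  };

  std::vector<int64_t> perm, packedShape;
  for (int64_t i = 0, e = tiledShape.size(); i < e; ++i)
    perm.push_back(packedIdxToTiledIdx(i));
  // applyPermutation semantics: result[i] = input[perm[i]].
  for (int64_t p : perm)
    packedShape.push_back(tiledShape[p]);

  for (int64_t p : perm)
    std::cout << p << ' '; // Prints: 0 3 1 4 2
  std::cout << '\n';
  for (int64_t d : packedShape)
    std::cout << d << ' '; // Prints: 32 1 4 16 2 (the packed dest shape)
  std::cout << '\n';
}
```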
@@ -1500,9 +1465,14 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
      /*indices=*/SmallVector<Value>(rank, zero),
      /*inBounds=*/SmallVector<bool>(rank, true));
  auto destShape = cast<ShapedType>(dest.getType()).getShape();
-   bool needMaskForWrite =
-       llvm::any_of(llvm::zip(inputVectorSizes, destShape),
-                    [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+   assert(llvm::none_of(
+              destShape.drop_front(inputVectorSizes.size()),
+              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+          "Only dims aligned with inputVectorSizes may be dynamic");
+   bool needMaskForWrite = llvm::any_of(
+       llvm::zip_equal(inputVectorSizes,
+                       destShape.take_front(inputVectorSizes.size())),
+       [](auto it) { return std::get<0>(it) != std::get<1>(it); });
  if (needMaskForWrite) {
    SmallVector<int64_t> writeMaskShape;
    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
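The masking decision above only compares the leading `inputVectorSizes.size()` dims of the destination against the requested vector sizes; the trailing (inner-tile) dims are asserted to be static and are always written in full. A minimal standalone sketch of that predicate, with hypothetical shapes and a stand-in for `ShapedType::kDynamic`:

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic (illustration only).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// True when any leading dest dim differs from the requested vector size,
// i.e. when the transfer_write needs a mask.
bool needMaskForWrite(const std::vector<int64_t> &inputVectorSizes,
                      const std::vector<int64_t> &destShape) {
  for (size_t i = 0; i < inputVectorSizes.size(); ++i)
    if (inputVectorSizes[i] != destShape[i])
      return true;
  return false;
}

int main() {
  // Hypothetical example: vector sizes [32, 8] into a ?x8x16x2 destination.
  std::cout << needMaskForWrite({32, 8}, {kDynamic, 8, 16, 2}) << '\n'; // 1
  std::cout << needMaskForWrite({32, 8}, {32, 8, 16, 2}) << '\n';       // 0
}
```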
@@ -1517,11 +1487,28 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
}

/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
- /// padding value into
- /// transfer_write_in_bounds(
- ///     transpose(
- ///         shape_cast(
- ///             transfer_read_masked(pack_source, pad_value))))
+ /// padding value into:
+ /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+ /// As in the following example:
+ /// ```mlir
+ /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
+ ///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+ /// ```
+ /// This pack would be vectorized to:
+ /// ```mlir
+ /// %load = vector.mask %mask {
+ ///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
+ ///         {in_bounds = [true, true, true]} :
+ ///         tensor<32x7x16xf32>, vector<32x8x16xf32>
+ /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+ /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
+ ///     to vector<32x4x2x1x16xf32>
+ /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
+ ///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+ /// %write = vector.transfer_write %transpose,
+ ///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
+ ///     {in_bounds = [true, true, true, true, true]}
+ ///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
@@ -1539,10 +1526,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
-   (void)status; // prevent unused variable warning on non-assert builds
+   (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");

-   // Create masked TransferReadOp
+   // Create masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  auto innerDimsPos = packOp.getInnerDimsPos();
@@ -1552,23 +1539,24 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                           invertPermutationVector(outerDimsPerm));
  for (auto [idx, size] : enumerate(innerTiles))
    inputShape[innerDimsPos[idx]] *= size;
-   auto maskedOp = createMaskedTransferRead(rewriter, loc, packOp.getSource(),
+   auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
                                           inputShape, padValue);

-   // Create ShapeCastOp
+   // Create ShapeCastOp.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
-   auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
-       loc, tiledPackType, maskedOp->getResult(0));
+   auto shapeCastOp =
+       rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

-   // Create TransposeOp
-   auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+   // Create TransposeOp.
+   auto destPermutation =
+       invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
-       loc, shapeCastOp.getResult(), tiledShapeToPackedShapePerm);
+       loc, shapeCastOp.getResult(), destPermutation);

-   // Create TransferWriteOp
+   // Create TransferWriteOp.
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
                               reifiedReturnShapes[0], inputVectorSizes);
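To make the shape computation above concrete, the following standalone sketch (hypothetical driver, standard library only) reproduces the read shape for the example in the `vectorizeAsTensorPackOp` doc comment: scaling each tiled source dim of the vector sizes by its tile size recovers the full padded source shape that the masked read covers.

```cpp
// Read-shape computation for: source tensor<32x7x16xf32>,
// inner_dims_pos = [2, 1], inner_tiles = [16, 2],
// outer vector sizes (inputVectorSizes) = [32, 4, 1].
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> inputShape = {32, 4, 1}; // Starts as inputVectorSizes.
  const std::vector<int64_t> innerDimsPos = {2, 1};
  const std::vector<int64_t> innerTiles = {16, 2};
  // outer_dims_perm is empty in this example, so no un-permutation is needed.

  // Scale each tiled source dim by its tile size to get the read shape.
  for (size_t idx = 0; idx < innerTiles.size(); ++idx)
    inputShape[innerDimsPos[idx]] *= innerTiles[idx];

  for (int64_t d : inputShape)
    std::cout << d << ' '; // Prints: 32 8 16 -- the masked read covers the
  std::cout << '\n';       // full (padded) source shape.
}
```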
@@ -1596,11 +1584,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");
-   auto maskedOp = createMaskedTransferRead(rewriter, loc, padOp.getSource(),
+   auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
                                           inputVectorSizes, padValue);
-   Operation *write =
-       createWriteOrMaskedWrite(rewriter, loc, maskedOp->getResult(0),
-                                reifiedReturnShapes[0], inputVectorSizes);
+   Operation *write = createWriteOrMaskedWrite(
+       rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
  newResults.push_back(write->getResult(0));
  return success();
}
@@ -1740,11 +1727,12 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
  return success();
}

+ /// TODO: Use a matcher to check for a constant padding value.
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
-   if (padValue && !getConstantIntValue(padValue).has_value()) {
+   if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }
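One possible way to address the TODO above, sketched here under the assumption that the generic constant matcher from `mlir/IR/Matchers.h` is acceptable, would be to accept any ConstantLike producer for the padding value rather than only `arith::ConstantOp`. The helper name below is hypothetical and not part of the patch.

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

// Hypothetical helper sketching the TODO: treat any ConstantLike producer
// (not just arith::ConstantOp) as an acceptable padding value.
static bool hasConstantPadValue(mlir::tensor::PackOp packOp) {
  mlir::Value padValue = packOp.getPaddingValue();
  // No padding value at all is fine; otherwise require a constant producer.
  return !padValue || mlir::matchPattern(padValue, mlir::m_Constant());
}
```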