19
19
#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
20
20
#include " mlir/Dialect/Linalg/Utils/Utils.h"
21
21
#include " mlir/Dialect/Tensor/IR/Tensor.h"
22
+ #include " mlir/Dialect/Tensor/Utils/Utils.h"
23
+ #include " mlir/Dialect/Utils/IndexingUtils.h"
22
24
#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
23
25
#include " mlir/Dialect/Vector/IR/VectorOps.h"
24
26
#include " mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
25
27
#include " mlir/IR/AffineExpr.h"
28
+ #include " mlir/IR/Builders.h"
29
+ #include " mlir/IR/BuiltinTypeInterfaces.h"
30
+ #include " mlir/IR/BuiltinTypes.h"
31
+ #include " mlir/IR/OpDefinition.h"
26
32
#include " mlir/IR/PatternMatch.h"
27
33
#include " mlir/Support/LLVM.h"
28
34
#include " mlir/Transforms/RegionUtils.h"
29
35
#include " llvm/ADT/STLExtras.h"
30
36
#include " llvm/ADT/Sequence.h"
31
37
#include " llvm/ADT/SmallVector.h"
32
38
#include " llvm/ADT/TypeSwitch.h"
39
+ #include " llvm/ADT/iterator_range.h"
33
40
#include " llvm/Support/Debug.h"
41
+ #include " llvm/Support/MathExtras.h"
34
42
#include " llvm/Support/raw_ostream.h"
35
43
#include < optional>
36
44
#include < type_traits>
@@ -1393,6 +1401,164 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1393
1401
return success ();
1394
1402
}
1395
1403
1404
+ // / Given a tensor::PackOp, return the `dest` shape before any packing
1405
+ // / permutations.
1406
+ static SmallVector<int64_t > getTiledPackShape (tensor::PackOp packOp,
1407
+ ArrayRef<int64_t > destShape) {
1408
+ return applyPermutation (destShape,
1409
+ tensor::getPackInverseDestPermutation (packOp));
1410
+ }
1411
+
1412
+ // / Create a TransferReadOp from `source` with static shape `readShape`. If the
1413
+ // / vector type for the read is not the same as the type of `source`, then a
1414
+ // / mask is created on the read.
1415
+ static Value createReadOrMaskedRead (OpBuilder &builder, Location loc,
1416
+ Value source, ArrayRef<int64_t > readShape,
1417
+ Value padValue) {
1418
+ assert (llvm::none_of (readShape,
1419
+ [](int64_t s) { return s == ShapedType::kDynamic ; }));
1420
+ auto sourceShape = dyn_cast<ShapedType>(source.getType ()).getShape ();
1421
+ assert (sourceShape.size () == readShape.size ());
1422
+ auto maskType = VectorType::get (readShape, builder.getI1Type ());
1423
+ auto vectorType = VectorType::get (readShape, padValue.getType ());
1424
+ int64_t readRank = readShape.size ();
1425
+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1426
+ auto transferReadOp = builder.create <vector::TransferReadOp>(
1427
+ loc,
1428
+ /* vectorType=*/ vectorType,
1429
+ /* source=*/ source,
1430
+ /* indices=*/ SmallVector<Value>(readRank, zero),
1431
+ /* padding=*/ padValue,
1432
+ /* inBounds=*/ SmallVector<bool >(readRank, true ));
1433
+ if (llvm::equal (readShape, sourceShape)) {
1434
+ return transferReadOp;
1435
+ }
1436
+ SmallVector<OpFoldResult> mixedSourceDims =
1437
+ tensor::getMixedSizes (builder, loc, source);
1438
+ Value mask =
1439
+ builder.create <vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
1440
+ return mlir::vector::maskOperation (builder, transferReadOp, mask)
1441
+ ->getResult (0 );
1442
+ }
1443
+
1444
+ // / Given an input, the mixed destSizes, and the vector sizes for vectorization,
1445
+ // / create an empty destination tensor and create a TransferWriteOp from the
1446
+ // / input to the empty tensor. If the destination shape is not the same as the
1447
+ // / inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1448
+ // / mask for the write.
1449
+ static Operation *createWriteOrMaskedWrite (OpBuilder &builder, Location loc,
1450
+ Value input,
1451
+ SmallVector<OpFoldResult> destSizes,
1452
+ ArrayRef<int64_t > inputVectorSizes) {
1453
+ auto inputType = cast<VectorType>(input.getType ());
1454
+ Value dest = builder.create <tensor::EmptyOp>(loc, destSizes,
1455
+ inputType.getElementType ());
1456
+ int64_t rank = cast<ShapedType>(dest.getType ()).getRank ();
1457
+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1458
+ Operation *write = builder.create <vector::TransferWriteOp>(
1459
+ loc,
1460
+ /* vector=*/ input,
1461
+ /* source=*/ dest,
1462
+ /* indices=*/ SmallVector<Value>(rank, zero),
1463
+ /* inBounds=*/ SmallVector<bool >(rank, true ));
1464
+ auto destShape = cast<ShapedType>(dest.getType ()).getShape ();
1465
+ assert (llvm::none_of (
1466
+ destShape.drop_front (inputVectorSizes.size ()),
1467
+ [](int64_t size) { return size == ShapedType::kDynamic ; }) &&
1468
+ " Only dims aligned with inputVectorSizes may be dynamic" );
1469
+ bool needMaskForWrite = !llvm::equal (
1470
+ inputVectorSizes, destShape.take_front (inputVectorSizes.size ()));
1471
+ if (needMaskForWrite) {
1472
+ SmallVector<int64_t > writeMaskShape;
1473
+ writeMaskShape.append (inputVectorSizes.begin (), inputVectorSizes.end ());
1474
+ writeMaskShape.append (destShape.begin () + inputVectorSizes.size (),
1475
+ destShape.end ());
1476
+ auto writeMaskType = VectorType::get (writeMaskShape, builder.getI1Type ());
1477
+ Value maskForWrite =
1478
+ builder.create <vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1479
+ write = mlir::vector::maskOperation (builder, write, maskForWrite);
1480
+ }
1481
+ return write;
1482
+ }
1483
+
1484
+ // / Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
1485
+ // / padding value into:
1486
+ // / masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1487
+ // / As in the following example:
1488
+ // /
1489
+ // / %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1490
+ // / into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1491
+ // /
1492
+ // / This pack would be vectorized to:
1493
+ // /
1494
+ // / %load = vector.mask %mask {
1495
+ // / vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1496
+ // / {in_bounds = [true, true, true]} :
1497
+ // / tensor<32x7x16xf32>, vector<32x8x16xf32>
1498
+ // / } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1499
+ // / %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1500
+ // / to vector<32x4x2x1x16xf32>
1501
+ // / %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1502
+ // / : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1503
+ // / %write = vector.transfer_write %transpose,
1504
+ // / %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1505
+ // / {in_bounds = [true, true, true, true, true]}
1506
+ // / : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1507
+ static LogicalResult
1508
+ vectorizeAsTensorPackOp (RewriterBase &rewriter, tensor::PackOp packOp,
1509
+ ArrayRef<int64_t > inputVectorSizes,
1510
+ SmallVectorImpl<Value> &newResults) {
1511
+ OpBuilder::InsertionGuard g (rewriter);
1512
+ rewriter.setInsertionPoint (packOp);
1513
+
1514
+ Location loc = packOp.getLoc ();
1515
+ auto padValue = packOp.getPaddingValue ();
1516
+ if (!padValue) {
1517
+ padValue = rewriter.create <arith::ConstantOp>(
1518
+ loc, rewriter.getZeroAttr (packOp.getSourceType ().getElementType ()));
1519
+ }
1520
+ ReifiedRankedShapedTypeDims reifiedReturnShapes;
1521
+ LogicalResult status =
1522
+ cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation ())
1523
+ .reifyResultShapes (rewriter, reifiedReturnShapes);
1524
+ (void )status; // prevent unused variable warning on non-assert builds.
1525
+ assert (succeeded (status) && " failed to reify result shapes" );
1526
+
1527
+ // Create masked TransferReadOp.
1528
+ SmallVector<int64_t > inputShape (inputVectorSizes);
1529
+ auto innerTiles = packOp.getStaticInnerTiles ();
1530
+ auto innerDimsPos = packOp.getInnerDimsPos ();
1531
+ auto outerDimsPerm = packOp.getOuterDimsPerm ();
1532
+ if (!outerDimsPerm.empty ())
1533
+ applyPermutationToVector (inputShape,
1534
+ invertPermutationVector (outerDimsPerm));
1535
+ for (auto [idx, size] : enumerate(innerTiles))
1536
+ inputShape[innerDimsPos[idx]] *= size;
1537
+ auto maskedRead = createReadOrMaskedRead (rewriter, loc, packOp.getSource (),
1538
+ inputShape, padValue);
1539
+
1540
+ // Create ShapeCastOp.
1541
+ SmallVector<int64_t > destShape (inputVectorSizes);
1542
+ destShape.append (innerTiles.begin (), innerTiles.end ());
1543
+ auto tiledPackType = VectorType::get (getTiledPackShape (packOp, destShape),
1544
+ packOp.getDestType ().getElementType ());
1545
+ auto shapeCastOp =
1546
+ rewriter.create <vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1547
+
1548
+ // Create TransposeOp.
1549
+ auto destPermutation =
1550
+ invertPermutationVector (tensor::getPackInverseDestPermutation (packOp));
1551
+ auto transposeOp = rewriter.create <vector::TransposeOp>(
1552
+ loc, shapeCastOp.getResult (), destPermutation);
1553
+
1554
+ // Create TransferWriteOp.
1555
+ Operation *write =
1556
+ createWriteOrMaskedWrite (rewriter, loc, transposeOp.getResult (),
1557
+ reifiedReturnShapes[0 ], inputVectorSizes);
1558
+ newResults.push_back (write->getResult (0 ));
1559
+ return success ();
1560
+ }
1561
+
1396
1562
// / Vectorize a `padOp` with (1) static result type, (2) constant padding value
1397
1563
// / and (3) all-zero lowPad to
1398
1564
// / `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1402,9 +1568,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1402
1568
SmallVectorImpl<Value> &newResults) {
1403
1569
auto padValue = padOp.getConstantPaddingValue ();
1404
1570
Location loc = padOp.getLoc ();
1405
- int64_t rank = inputVectorSizes.size ();
1406
- auto maskType = VectorType::get (inputVectorSizes, rewriter.getI1Type ());
1407
- auto vectorType = VectorType::get (inputVectorSizes, padValue.getType ());
1408
1571
1409
1572
// transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1410
1573
OpBuilder::InsertionGuard g (rewriter);
@@ -1416,36 +1579,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1416
1579
.reifyResultShapes (rewriter, reifiedReturnShapes);
1417
1580
(void )status; // prevent unused variable warning on non-assert builds
1418
1581
assert (succeeded (status) && " failed to reify result shapes" );
1419
- auto emptyOp = rewriter.create <tensor::EmptyOp>(loc, reifiedReturnShapes[0 ],
1420
- padValue.getType ());
1421
- SmallVector<OpFoldResult> mixedSourceDims =
1422
- tensor::getMixedSizes (rewriter, loc, padOp.getSource ());
1423
- Value mask =
1424
- rewriter.create <vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
1425
- auto zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
1426
- auto transferReadOp = rewriter.create <vector::TransferReadOp>(
1427
- loc,
1428
- /* vectorType=*/ vectorType,
1429
- /* source=*/ padOp.getSource (),
1430
- /* indices=*/ SmallVector<Value>(rank, zero),
1431
- /* padding=*/ padValue,
1432
- /* inBounds=*/ SmallVector<bool >(rank, true ));
1433
- auto maskedOp = cast<vector::MaskOp>(
1434
- mlir::vector::maskOperation (rewriter, transferReadOp, mask));
1435
- Operation *write = rewriter.create <vector::TransferWriteOp>(
1436
- loc,
1437
- /* vector=*/ maskedOp->getResult (0 ),
1438
- /* source=*/ emptyOp,
1439
- /* indices=*/ SmallVector<Value>(rank, zero),
1440
- /* inBounds=*/ SmallVector<bool >(rank, true ));
1441
- bool needMaskForWrite = llvm::any_of (
1442
- llvm::zip_equal (inputVectorSizes, padOp.getResultType ().getShape ()),
1443
- [](auto it) { return std::get<0 >(it) != std::get<1 >(it); });
1444
- if (needMaskForWrite) {
1445
- Value maskForWrite = rewriter.create <vector::CreateMaskOp>(
1446
- loc, maskType, reifiedReturnShapes[0 ]);
1447
- write = mlir::vector::maskOperation (rewriter, write, maskForWrite);
1448
- }
1582
+ auto maskedRead = createReadOrMaskedRead (rewriter, loc, padOp.getSource (),
1583
+ inputVectorSizes, padValue);
1584
+ Operation *write = createWriteOrMaskedWrite (
1585
+ rewriter, loc, maskedRead, reifiedReturnShapes[0 ], inputVectorSizes);
1449
1586
newResults.push_back (write->getResult (0 ));
1450
1587
return success ();
1451
1588
}
@@ -1585,6 +1722,32 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
1585
1722
return success ();
1586
1723
}
1587
1724
1725
+ // / TODO: Use a matcher to check for a constant padding value.
1726
+ static LogicalResult
1727
+ vectorizePackOpPrecondition (tensor::PackOp packOp,
1728
+ ArrayRef<int64_t > inputVectorSizes) {
1729
+ auto padValue = packOp.getPaddingValue ();
1730
+ if (padValue && !padValue.getDefiningOp <arith::ConstantOp>()) {
1731
+ LDBG (" pad value is not constant: " << packOp << " \n " );
1732
+ return failure ();
1733
+ }
1734
+
1735
+ ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
1736
+ if (failed (isValidMaskedInputVector (
1737
+ resultTensorShape.take_front (packOp.getSourceRank ()),
1738
+ inputVectorSizes)))
1739
+ return failure ();
1740
+
1741
+ if (llvm::any_of (packOp.getInnerTiles (), [](OpFoldResult v) {
1742
+ return !getConstantIntValue (v).has_value ();
1743
+ })) {
1744
+ LDBG (" inner_tiles must be constant: " << packOp << " \n " );
1745
+ return failure ();
1746
+ }
1747
+
1748
+ return success ();
1749
+ }
1750
+
1588
1751
static LogicalResult
1589
1752
vectorizePadOpPrecondition (tensor::PadOp padOp,
1590
1753
ArrayRef<int64_t > inputVectorSizes) {
@@ -1644,6 +1807,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
1644
1807
.Case <tensor::PadOp>([&](auto padOp) {
1645
1808
return vectorizePadOpPrecondition (padOp, inputVectorSizes);
1646
1809
})
1810
+ .Case <tensor::PackOp>([&](auto packOp) {
1811
+ return vectorizePackOpPrecondition (packOp, inputVectorSizes);
1812
+ })
1647
1813
.Default ([](auto ) { return failure (); });
1648
1814
}
1649
1815
@@ -1732,6 +1898,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
1732
1898
return vectorizeAsTensorPadOp (rewriter, padOp, inputVectorSizes,
1733
1899
results);
1734
1900
})
1901
+ .Case <tensor::PackOp>([&](auto packOp) {
1902
+ return vectorizeAsTensorPackOp (rewriter, packOp, inputVectorSizes,
1903
+ results);
1904
+ })
1735
1905
.Default ([](auto ) { return failure (); });
1736
1906
1737
1907
if (failed (vectorizeResult)) {
0 commit comments