24
24
#include " mlir/Dialect/Vector/IR/VectorOps.h"
25
25
#include " mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
26
26
#include " mlir/IR/AffineExpr.h"
27
+ #include " mlir/IR/Builders.h"
27
28
#include " mlir/IR/BuiltinTypeInterfaces.h"
28
29
#include " mlir/IR/BuiltinTypes.h"
29
30
#include " mlir/IR/OpDefinition.h"
@@ -1454,7 +1455,73 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
1454
1455
return applyPermutation (destShape, invertPermutationVector (perm));
1455
1456
}
1456
1457
1457
- // /
1458
+ // / Create a masked TransferReadOp from `source` with shape `readShape`.
1459
+ static vector::MaskOp createMaskedTransferRead (OpBuilder &builder, Location loc,
1460
+ Value source,
1461
+ ArrayRef<int64_t > readShape,
1462
+ Value padValue) {
1463
+ auto maskType = VectorType::get (readShape, builder.getI1Type ());
1464
+ auto vectorType = VectorType::get (readShape, padValue.getType ());
1465
+ SmallVector<OpFoldResult> mixedSourceDims =
1466
+ tensor::getMixedSizes (builder, loc, source);
1467
+ Value mask =
1468
+ builder.create <vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
1469
+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1470
+ int64_t readRank = readShape.size ();
1471
+ auto transferReadOp = builder.create <vector::TransferReadOp>(
1472
+ loc,
1473
+ /* vectorType=*/ vectorType,
1474
+ /* source=*/ source,
1475
+ /* indices=*/ SmallVector<Value>(readRank, zero),
1476
+ /* padding=*/ padValue,
1477
+ /* inBounds=*/ SmallVector<bool >(readRank, true ));
1478
+ return cast<vector::MaskOp>(
1479
+ mlir::vector::maskOperation (builder, transferReadOp, mask));
1480
+ }
1481
+
1482
+ // / Given an input, the mixed destSizes, and the vector sizes for vectorization,
1483
+ // / create an empty destination tensor and create a TransferWriteOp from the
1484
+ // / input to the empty tensor. If the destination shape is not the same as the
1485
+ // / inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1486
+ // / mask for the write.
1487
+ static Operation *createWriteOrMaskedWrite (OpBuilder &builder, Location loc,
1488
+ Value input,
1489
+ SmallVector<OpFoldResult> destSizes,
1490
+ ArrayRef<int64_t > inputVectorSizes) {
1491
+ auto inputType = cast<VectorType>(input.getType ());
1492
+ Value dest = builder.create <tensor::EmptyOp>(loc, destSizes,
1493
+ inputType.getElementType ());
1494
+ int64_t rank = cast<ShapedType>(dest.getType ()).getRank ();
1495
+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1496
+ Operation *write = builder.create <vector::TransferWriteOp>(
1497
+ loc,
1498
+ /* vector=*/ input,
1499
+ /* source=*/ dest,
1500
+ /* indices=*/ SmallVector<Value>(rank, zero),
1501
+ /* inBounds=*/ SmallVector<bool >(rank, true ));
1502
+ auto destShape = cast<ShapedType>(dest.getType ()).getShape ();
1503
+ bool needMaskForWrite =
1504
+ llvm::any_of (llvm::zip (inputVectorSizes, destShape),
1505
+ [](auto it) { return std::get<0 >(it) != std::get<1 >(it); });
1506
+ if (needMaskForWrite) {
1507
+ SmallVector<int64_t > writeMaskShape;
1508
+ writeMaskShape.append (inputVectorSizes.begin (), inputVectorSizes.end ());
1509
+ writeMaskShape.append (destShape.begin () + inputVectorSizes.size (),
1510
+ destShape.end ());
1511
+ auto writeMaskType = VectorType::get (writeMaskShape, builder.getI1Type ());
1512
+ Value maskForWrite =
1513
+ builder.create <vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1514
+ write = mlir::vector::maskOperation (builder, write, maskForWrite);
1515
+ }
1516
+ return write;
1517
+ }
1518
+
1519
+ // / Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
1520
+ // / padding value into
1521
+ // / transfer_write_in_bounds(
1522
+ // / transpose(
1523
+ // / shape_cast(
1524
+ // / transfer_read_masked(pack_source, pad_value))))
1458
1525
static LogicalResult
1459
1526
vectorizeAsTensorPackOp (RewriterBase &rewriter, tensor::PackOp packOp,
1460
1527
ArrayRef<int64_t > inputVectorSizes,
@@ -1468,48 +1535,41 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1468
1535
padValue = rewriter.create <arith::ConstantOp>(
1469
1536
loc, rewriter.getZeroAttr (packOp.getSourceType ().getElementType ()));
1470
1537
}
1471
- int64_t inputRank = inputVectorSizes.size ();
1472
- int64_t outputRank = packOp.getDestRank ();
1473
- auto maskType = VectorType::get (inputVectorSizes, rewriter.getI1Type ());
1474
- auto vectorType = VectorType::get (inputVectorSizes, padValue.getType ());
1475
-
1476
1538
ReifiedRankedShapedTypeDims reifiedReturnShapes;
1477
1539
LogicalResult status =
1478
1540
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation ())
1479
1541
.reifyResultShapes (rewriter, reifiedReturnShapes);
1480
1542
(void )status; // prevent unused variable warning on non-assert builds
1481
1543
assert (succeeded (status) && " failed to reify result shapes" );
1482
- auto emptyOp = rewriter.create <tensor::EmptyOp>(loc, reifiedReturnShapes[0 ],
1483
- padValue.getType ());
1484
- SmallVector<OpFoldResult> mixedSourceDims =
1485
- tensor::getMixedSizes (rewriter, loc, packOp.getSource ());
1486
- Value mask =
1487
- rewriter.create <vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
1488
- auto zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
1489
- auto transferReadOp = rewriter.create <vector::TransferReadOp>(
1490
- loc,
1491
- /* vectorType=*/ vectorType,
1492
- /* source=*/ packOp.getSource (),
1493
- /* indices=*/ SmallVector<Value>(inputRank, zero),
1494
- /* padding=*/ padValue,
1495
- /* inBounds=*/ SmallVector<bool >(inputRank, true ));
1496
- auto maskedOp = cast<vector::MaskOp>(
1497
- mlir::vector::maskOperation (rewriter, transferReadOp, mask));
1498
- // ShapeCast
1499
- auto tiledPackShape = getTiledPackShape (packOp);
1500
- auto tiledPackType =
1501
- VectorType::get (tiledPackShape, packOp.getDestType ().getElementType ());
1544
+
1545
+ // Create masked TransferReadOp
1546
+ SmallVector<int64_t > inputShape (inputVectorSizes);
1547
+ auto innerTiles = packOp.getStaticInnerTiles ();
1548
+ auto innerDimsPos = packOp.getInnerDimsPos ();
1549
+ auto outerDimsPerm = packOp.getOuterDimsPerm ();
1550
+ if (!outerDimsPerm.empty ())
1551
+ applyPermutationToVector (inputShape,
1552
+ invertPermutationVector (outerDimsPerm));
1553
+ for (auto [idx, size] : enumerate(innerTiles))
1554
+ inputShape[innerDimsPos[idx]] *= size;
1555
+ auto maskedOp = createMaskedTransferRead (rewriter, loc, packOp.getSource (),
1556
+ inputShape, padValue);
1557
+
1558
+ // Create ShapeCastOp
1559
+ auto tiledPackType = VectorType::get (getTiledPackShape (packOp),
1560
+ packOp.getDestType ().getElementType ());
1502
1561
auto shapeCastOp = rewriter.create <vector::ShapeCastOp>(
1503
1562
loc, tiledPackType, maskedOp->getResult (0 ));
1563
+
1564
+ // Create TransposeOp
1504
1565
auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm (packOp);
1505
1566
auto transposeOp = rewriter.create <vector::TransposeOp>(
1506
- loc, shapeCastOp->getResult (0 ), tiledShapeToPackedShapePerm);
1507
- Operation *write = rewriter.create <vector::TransferWriteOp>(
1508
- loc,
1509
- /* vector=*/ transposeOp->getResult (0 ),
1510
- /* source=*/ emptyOp,
1511
- /* indices=*/ SmallVector<Value>(outputRank, zero),
1512
- /* inBounds=*/ SmallVector<bool >(outputRank, true ));
1567
+ loc, shapeCastOp.getResult (), tiledShapeToPackedShapePerm);
1568
+
1569
+ // Create TransferWriteOp
1570
+ Operation *write =
1571
+ createWriteOrMaskedWrite (rewriter, loc, transposeOp.getResult (),
1572
+ reifiedReturnShapes[0 ], inputVectorSizes);
1513
1573
newResults.push_back (write->getResult (0 ));
1514
1574
return success ();
1515
1575
}
@@ -1523,9 +1583,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1523
1583
SmallVectorImpl<Value> &newResults) {
1524
1584
auto padValue = padOp.getConstantPaddingValue ();
1525
1585
Location loc = padOp.getLoc ();
1526
- int64_t rank = inputVectorSizes.size ();
1527
- auto maskType = VectorType::get (inputVectorSizes, rewriter.getI1Type ());
1528
- auto vectorType = VectorType::get (inputVectorSizes, padValue.getType ());
1529
1586
1530
1587
// transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1531
1588
OpBuilder::InsertionGuard g (rewriter);
@@ -1537,36 +1594,11 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1537
1594
.reifyResultShapes (rewriter, reifiedReturnShapes);
1538
1595
(void )status; // prevent unused variable warning on non-assert builds
1539
1596
assert (succeeded (status) && " failed to reify result shapes" );
1540
- auto emptyOp = rewriter.create <tensor::EmptyOp>(loc, reifiedReturnShapes[0 ],
1541
- padValue.getType ());
1542
- SmallVector<OpFoldResult> mixedSourceDims =
1543
- tensor::getMixedSizes (rewriter, loc, padOp.getSource ());
1544
- Value mask =
1545
- rewriter.create <vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
1546
- auto zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
1547
- auto transferReadOp = rewriter.create <vector::TransferReadOp>(
1548
- loc,
1549
- /* vectorType=*/ vectorType,
1550
- /* source=*/ padOp.getSource (),
1551
- /* indices=*/ SmallVector<Value>(rank, zero),
1552
- /* padding=*/ padValue,
1553
- /* inBounds=*/ SmallVector<bool >(rank, true ));
1554
- auto maskedOp = cast<vector::MaskOp>(
1555
- mlir::vector::maskOperation (rewriter, transferReadOp, mask));
1556
- Operation *write = rewriter.create <vector::TransferWriteOp>(
1557
- loc,
1558
- /* vector=*/ maskedOp->getResult (0 ),
1559
- /* source=*/ emptyOp,
1560
- /* indices=*/ SmallVector<Value>(rank, zero),
1561
- /* inBounds=*/ SmallVector<bool >(rank, true ));
1562
- bool needMaskForWrite = llvm::any_of (
1563
- llvm::zip_equal (inputVectorSizes, padOp.getResultType ().getShape ()),
1564
- [](auto it) { return std::get<0 >(it) != std::get<1 >(it); });
1565
- if (needMaskForWrite) {
1566
- Value maskForWrite = rewriter.create <vector::CreateMaskOp>(
1567
- loc, maskType, reifiedReturnShapes[0 ]);
1568
- write = mlir::vector::maskOperation (rewriter, write, maskForWrite);
1569
- }
1597
+ auto maskedOp = createMaskedTransferRead (rewriter, loc, padOp.getSource (),
1598
+ inputVectorSizes, padValue);
1599
+ Operation *write =
1600
+ createWriteOrMaskedWrite (rewriter, loc, maskedOp->getResult (0 ),
1601
+ reifiedReturnShapes[0 ], inputVectorSizes);
1570
1602
newResults.push_back (write->getResult (0 ));
1571
1603
return success ();
1572
1604
}
@@ -1710,18 +1742,19 @@ static LogicalResult
1710
1742
vectorizePackOpPrecondition (tensor::PackOp packOp,
1711
1743
ArrayRef<int64_t > inputVectorSizes) {
1712
1744
auto padValue = packOp.getPaddingValue ();
1713
- if (padValue && getConstantIntValue (padValue) != std::nullopt ) {
1745
+ if (padValue && ! getConstantIntValue (padValue). has_value () ) {
1714
1746
LDBG (" pad value is not constant: " << packOp << " \n " );
1715
1747
return failure ();
1716
1748
}
1717
1749
1718
- ArrayRef<int64_t > resultTensorShape = packOp.getSourceType ().getShape ();
1719
- if (failed (isValidMaskedInputVector (resultTensorShape, inputVectorSizes)))
1750
+ ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
1751
+ if (failed (isValidMaskedInputVector (
1752
+ resultTensorShape.take_front (packOp.getSourceRank ()),
1753
+ inputVectorSizes)))
1720
1754
return failure ();
1721
1755
1722
1756
if (llvm::any_of (packOp.getInnerTiles (), [](OpFoldResult v) {
1723
- std::optional<int64_t > res = getConstantIntValue (v);
1724
- return !res.has_value ();
1757
+ return !getConstantIntValue (v).has_value ();
1725
1758
})) {
1726
1759
LDBG (" inner_tiles must be constant: " << packOp << " \n " );
1727
1760
return failure ();
0 commit comments