@@ -1414,27 +1414,39 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1414
1414
// / create an empty destination tensor and create a TransferWriteOp from the
1415
1415
// / input to the empty tensor. If the destination shape is not the same as the
1416
1416
// / inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1417
- // / mask for the write.
1417
+ // / mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1418
+ // / inBounds attribute of the transfer write op instead of masking.
1418
1419
static Operation *createWriteOrMaskedWrite (OpBuilder &builder, Location loc,
1419
1420
Value input,
1420
1421
SmallVector<OpFoldResult> destSizes,
1421
- ArrayRef<int64_t > inputVectorSizes) {
1422
+ ArrayRef<int64_t > inputVectorSizes,
1423
+ bool useInBoundsInsteadOfMasking) {
1424
+
1422
1425
auto inputType = cast<VectorType>(input.getType ());
1423
1426
Value dest = builder.create <tensor::EmptyOp>(loc, destSizes,
1424
1427
inputType.getElementType ());
1425
1428
int64_t rank = cast<ShapedType>(dest.getType ()).getRank ();
1426
1429
auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1430
+ auto destShape = cast<ShapedType>(dest.getType ()).getShape ();
1431
+ SmallVector<bool > inBoundsVal (rank, true );
1432
+ if (useInBoundsInsteadOfMasking) {
1433
+ // Update the inBounds attribute.
1434
+ for (unsigned i = 0 ; i < rank; i++)
1435
+ inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1436
+ !ShapedType::isDynamic (destShape[i]);
1437
+ }
1427
1438
Operation *write = builder.create <vector::TransferWriteOp>(
1428
1439
loc,
1429
1440
/* vector=*/ input,
1430
1441
/* source=*/ dest,
1431
1442
/* indices=*/ SmallVector<Value>(rank, zero),
1432
- /* inBounds=*/ SmallVector<bool >(rank, true ));
1433
- auto destShape = cast<ShapedType>(dest.getType ()).getShape ();
1443
+ /* inBounds=*/ inBoundsVal);
1434
1444
assert (llvm::none_of (
1435
1445
destShape.drop_front (inputVectorSizes.size ()),
1436
1446
[](int64_t size) { return size == ShapedType::kDynamic ; }) &&
1437
1447
" Only dims aligned with inputVectorSizes may be dynamic" );
1448
+ if (useInBoundsInsteadOfMasking)
1449
+ return write;
1438
1450
bool needMaskForWrite = !llvm::equal (
1439
1451
inputVectorSizes, destShape.take_front (inputVectorSizes.size ()));
1440
1452
if (needMaskForWrite) {
@@ -1535,9 +1547,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1535
1547
loc, shapeCastOp.getResult (), destPermutation);
1536
1548
1537
1549
// Create TransferWriteOp.
1538
- Operation *write =
1539
- createWriteOrMaskedWrite ( rewriter, loc, transposeOp.getResult (),
1540
- reifiedReturnShapes[ 0 ], inputVectorSizes );
1550
+ Operation *write = createWriteOrMaskedWrite (
1551
+ rewriter, loc, transposeOp.getResult (), reifiedReturnShapes[ 0 ] ,
1552
+ inputVectorSizes, /* useInBoundsInsteadOfMasking= */ false );
1541
1553
newResults.push_back (write->getResult (0 ));
1542
1554
return success ();
1543
1555
}
@@ -1547,7 +1559,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1547
1559
// / vector::TransposeOp - Transpose the Source tensor
1548
1560
// / ShapeCastOp - Reshape the data based on the target.
1549
1561
// / vector::TransferWriteOp. - Write the result vector back to the destination
1550
- // / tensor
1562
+ // / tensor.
1563
+ // / If the vector sizes are not provided:
1564
+ // / * the vector sizes are determined by the input operand and attributes,
1565
+ // / * update the inBounds attribute instead of masking.
1551
1566
static LogicalResult
1552
1567
vectorizeAsTensorUnpackOp (RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1553
1568
ArrayRef<int64_t > inputVectorSizes,
@@ -1560,40 +1575,65 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1560
1575
1561
1576
ArrayRef<int64_t > innerDimPos = unpackOp.getInnerDimsPos ();
1562
1577
ArrayRef<int64_t > innerTiles = unpackOp.getStaticInnerTiles ();
1563
-
1564
- SmallVector<int64_t > readMaskShape (inputVectorSizes.begin (),
1565
- inputVectorSizes.end ());
1566
- ArrayRef<int64_t > outerDimsPerm = unpackOp.getOuterDimsPerm ();
1567
1578
ArrayRef<int64_t > sourceShape = unpackTensorType.getShape ();
1579
+ bool useInBoundsInsteadOfMasking = false ;
1580
+ ArrayRef<int64_t > outerDimsPerm = unpackOp.getOuterDimsPerm ();
1581
+
1582
+ auto destSize = unpackOp.getDestRank ();
1583
+
1584
+ if (!inputVectorSizes.empty ())
1585
+ assert (inputVectorSizes.size () == destSize &&
1586
+ " Incorrect number of input vector sizes" );
1568
1587
1569
- // ReadMask is the size of tensor used to read and apply mask. It is
1588
+ // vectorSizes is the shape of the vector that will be used to do final
1589
+ // write on the destination tensor. It is set like this: Let's say the
1590
+ // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1591
+ // Thus:
1592
+ // 1. vectorSizes = sourceShape.take_front(N)
1593
+ // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1594
+ // 3. multiply all the locations in vectorSizes pointed to by innerDimPos by the
1595
+ // innerTiles attribute value.
1596
+ SmallVector<int64_t > vectorSizes (inputVectorSizes);
1597
+ if (vectorSizes.empty ()) {
1598
+ llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1599
+ if (!outerDimsPerm.empty ())
1600
+ applyPermutationToVector (vectorSizes, outerDimsPerm);
1601
+ for (auto [i, pos] : llvm::enumerate (innerDimPos))
1602
+ vectorSizes[pos] *= innerTiles[i];
1603
+
1604
+ useInBoundsInsteadOfMasking = true ;
1605
+ }
1606
+
1607
+ // readVectorSizes is the size of tensor used to read and apply mask. It is
1570
1608
// set like this: Let's say the vectorSize (VS) array is size 'N' and
1571
1609
// the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1572
1610
// size M-N
1573
1611
// Thus:
1574
- // - initially: ReadMaskShape = vectorInputSizes
1612
+ // - initially: readVectorSizes = vectorSizes
1575
1613
// - Divide all the readVectorSizes locations pointed by innerDimPos
1576
1614
// by the innerTileSize attribute value.
1577
- // - if outer_dims_perms is present: do that permutation on readMaskShape .
1615
+ // - if outer_dims_perm is present: apply that permutation to readVectorSizes.
1578
1616
// - Append the remaining shape from SS
1579
1617
// E.g. let's say unpackTensorType.getShape() = <8x8x32x16>
1580
1618
// inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1581
1619
// 128] and outer_dims_perm is [1, 0] then read shape is:
1582
- // ReadMaskShape (initial): [512, 128]
1620
+ // ReadVectorSizes (initial): [512, 128]
1583
1621
// Final Value(after innerDim Adjustment): [512/32, 128/16]
1584
1622
// = [16, 8]
1585
1623
// After applying outer_dims_perm: [8, 16]
1586
1624
// After appending the rest of the sourceShape: [8, 16, 32, 16]
1587
1625
1626
+ SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1627
+
1588
1628
for (auto [index, size] : enumerate(innerTiles)) {
1589
- readMaskShape [innerDimPos[index]] =
1590
- llvm::divideCeil (readMaskShape [innerDimPos[index]], size);
1629
+ readVectorSizes [innerDimPos[index]] =
1630
+ llvm::divideCeil (readVectorSizes [innerDimPos[index]], size);
1591
1631
}
1592
1632
if (!outerDimsPerm.empty ()) {
1593
- applyPermutationToVector (readMaskShape , outerDimsPerm);
1633
+ applyPermutationToVector (readVectorSizes , outerDimsPerm);
1594
1634
}
1595
- readMaskShape .append (sourceShape.begin () + inputVectorSizes .size (),
1596
- sourceShape.end ());
1635
+ readVectorSizes .append (sourceShape.begin () + vectorSizes .size (),
1636
+ sourceShape.end ());
1597
1637
1598
1638
ReifiedRankedShapedTypeDims reifiedRetShapes;
1599
1639
LogicalResult status =
@@ -1611,8 +1651,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1611
1651
// Read result, mask if necessary. If transferReadOp shape is not equal
1612
1652
// to shape of source, then a mask is necessary.
1613
1653
Value readResult = vector::createReadOrMaskedRead (
1614
- rewriter, loc, unpackOp.getSource (),
1615
- ArrayRef<int64_t >(readMaskShape.begin (), readMaskShape.end ()), padValue,
1654
+ rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1616
1655
/* useInBoundsInsteadOfMasking=*/ false );
1617
1656
1618
1657
PackingMetadata packMetadata;
@@ -1636,15 +1675,15 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1636
1675
vector::ShapeCastOp shapeCastOp = rewriter.create <vector::ShapeCastOp>(
1637
1676
loc, vecCollapsedType, transposeOp->getResult (0 ));
1638
1677
1639
- // WriteMaskShape had to match the shapecast shape for dynamic sizes,
1678
+ // writeVectorSizes has to match the shapecast shape for dynamic sizes,
1640
1679
// otherwise the validator complains that the mask size is invalid.
1641
- SmallVector<int64_t > writeMaskShape (
1680
+ SmallVector<int64_t > writeVectorSizes (
1642
1681
unpackOp.getDestType ().hasStaticShape ()
1643
- ? inputVectorSizes
1682
+ ? vectorSizes
1644
1683
: shapeCastOp.getResultVectorType ().getShape ());
1645
- Operation *write =
1646
- createWriteOrMaskedWrite ( rewriter, loc, shapeCastOp.getResult (),
1647
- reifiedRetShapes[ 0 ], writeMaskShape );
1684
+ Operation *write = createWriteOrMaskedWrite (
1685
+ rewriter, loc, shapeCastOp.getResult (), reifiedRetShapes[ 0 ] ,
1686
+ writeVectorSizes, useInBoundsInsteadOfMasking );
1648
1687
newResults.push_back (write->getResult (0 ));
1649
1688
return success ();
1650
1689
}
@@ -1673,7 +1712,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1673
1712
rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
1674
1713
/* useInBoundsInsteadOfMasking=*/ false );
1675
1714
Operation *write = createWriteOrMaskedWrite (
1676
- rewriter, loc, maskedRead, reifiedReturnShapes[0 ], inputVectorSizes);
1715
+ rewriter, loc, maskedRead, reifiedReturnShapes[0 ], inputVectorSizes,
1716
+ /* useInBoundsInsteadOfMasking=*/ false );
1677
1717
newResults.push_back (write->getResult (0 ));
1678
1718
return success ();
1679
1719
}
@@ -1755,8 +1795,11 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1755
1795
LDBG (" Inner-tiles must be constant: " << unpackOp << " \n " );
1756
1796
return failure ();
1757
1797
}
1758
- llvm::ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1759
- if (!inputVectorSizes.empty () &&
1798
+ ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1799
+ bool satisfyEmptyCond = inputVectorSizes.empty () &&
1800
+ unpackOp.getDestType ().hasStaticShape () &&
1801
+ unpackOp.getSourceType ().hasStaticShape ();
1802
+ if (!satisfyEmptyCond &&
1760
1803
failed (vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
1761
1804
return failure ();
1762
1805
0 commit comments