@@ -1414,27 +1414,39 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
 /// create an empty destination tensor and create a TransferWriteOp from the
 /// input to the empty tensor. If the destination shape is not the same as the
 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
-/// mask for the write.
+/// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
+/// inBounds attribute of the transfer write op instead of masking.
 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
                                            Value input,
                                            SmallVector<OpFoldResult> destSizes,
-                                           ArrayRef<int64_t> inputVectorSizes) {
+                                           ArrayRef<int64_t> inputVectorSizes,
+                                           bool useInBoundsInsteadOfMasking) {
+
   auto inputType = cast<VectorType>(input.getType());
   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                                inputType.getElementType());
   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  SmallVector<bool> inBoundsVal(rank, true);
+  if (useInBoundsInsteadOfMasking) {
+    // Update the inBounds attribute.
+    for (unsigned i = 0; i < rank; i++)
+      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
+                       !ShapedType::isDynamic(destShape[i]);
+  }
   Operation *write = builder.create<vector::TransferWriteOp>(
       loc,
       /*vector=*/input,
       /*source=*/dest,
       /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+      /*inBounds=*/inBoundsVal);
   assert(llvm::none_of(
              destShape.drop_front(inputVectorSizes.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
          "Only dims aligned with inputVectorSizes may be dynamic");
+  if (useInBoundsInsteadOfMasking)
+    return write;
   bool needMaskForWrite = !llvm::equal(
       inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
   if (needMaskForWrite) {
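For illustration, here is a minimal standalone sketch of the per-dimension in-bounds computation introduced above, using plain std::vector in place of MLIR's SmallVector/ArrayRef; kDynamic and the shapes are hypothetical stand-ins, not values taken from the patch:

// Standalone sketch (not MLIR code): mirrors the inBoundsVal computation above.
// kDynamic is a stand-in for ShapedType::kDynamic; shapes are hypothetical.
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN;

int main() {
  std::vector<int64_t> destShape = {4, 16, 8};        // static destination dims
  std::vector<int64_t> inputVectorSizes = {4, 32, 8}; // vector (write) sizes

  // A dim is "in bounds" only if it is static and exactly matches the vector
  // size being written; otherwise the write could run past the tensor end.
  std::vector<bool> inBoundsVal(destShape.size(), true);
  for (size_t i = 0; i < destShape.size(); ++i)
    inBoundsVal[i] =
        destShape[i] == inputVectorSizes[i] && destShape[i] != kDynamic;

  for (size_t i = 0; i < inBoundsVal.size(); ++i)
    std::cout << "dim " << i << ": in_bounds = " << std::boolalpha
              << inBoundsVal[i] << "\n"; // dim 1 prints false (16 != 32)
  return 0;
}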
@@ -1535,9 +1547,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
       loc, shapeCastOp.getResult(), destPermutation);

   // Create TransferWriteOp.
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
-                               reifiedReturnShapes[0], inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
+      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1547,7 +1559,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
 /// vector::TransposeOp - Transpose the Source tensor
 /// ShapeCastOp - Reshape the data based on the target.
 /// vector::TransferWriteOp. - Write the result vector back to the destination
-/// tensor
+/// tensor. If the vector sizes are not provided:
+///   * the vector sizes are determined by the input operand and attributes,
+///   * update the inBounds attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
@@ -1560,11 +1574,33 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+  bool useInBoundsInsteadOfMasking = false;
+  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+
+  auto destSize = unpackOp.getDestRank();
+
+  // initVectorShape is the shape of the vector that will be used to do final
+  // write on the destination tensor. It is set like this: Let's say the
+  // sourceShape is 'M' and the vectorSize (VS) array is size 'N' where N <= M.
+  // Thus:
+  //   - initVectorShape = sourceShape.take_front(N)
+  //   - if outer_dims_perms is present: do that permutation on initVectorShape.
+  //   - Multiply all the locations pointed by innerDimPos by the innerTileSize
+  //     attribute value.
+  SmallVector<int64_t> initVectorShape(sourceShape.take_front(destSize));
+  if (inputVectorSizes.empty()) {
+    if (!outerDimsPerm.empty())
+      applyPermutationToVector(initVectorShape, outerDimsPerm);
+    for (auto [i, pos] : llvm::enumerate(innerDimPos))
+      initVectorShape[pos] *= innerTiles[i];
+
+    inputVectorSizes = initVectorShape;
+    useInBoundsInsteadOfMasking = true;
+  }
 
   SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
                                      inputVectorSizes.end());
-  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
-  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
 
   // ReadMask is the size of tensor used to read and apply mask. It is
   // set like this: Let's say the vectorSize (VS) array is size 'N' and
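For illustration, a standalone sketch of the vector-size inference added above, worked through for a hypothetical unpack configuration (source tensor<2x16x16xf32>, dest tensor<16x32xf32>, outer_dims_perm = [1, 0], inner_dims_pos = [1], inner_tiles = [16]); plain std::vector stands in for the MLIR containers and the permutation helper:

// Standalone sketch (not MLIR code): mirrors the initVectorShape inference above.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> sourceShape = {2, 16, 16}; // outer dims [2, 16] + tile [16]
  std::vector<int64_t> outerDimsPerm = {1, 0};
  std::vector<int64_t> innerDimPos = {1};
  std::vector<int64_t> innerTiles = {16};
  int64_t destRank = 2;

  // Step 1: take the leading destRank dims of the source.
  std::vector<int64_t> initVectorShape(sourceShape.begin(),
                                       sourceShape.begin() + destRank);

  // Step 2: apply outer_dims_perm (result[i] = input[perm[i]]), as
  // applyPermutationToVector does.
  std::vector<int64_t> permuted(initVectorShape.size());
  for (size_t i = 0; i < outerDimsPerm.size(); ++i)
    permuted[i] = initVectorShape[outerDimsPerm[i]];
  initVectorShape = permuted;

  // Step 3: scale the dims listed in inner_dims_pos by their tile sizes.
  for (size_t i = 0; i < innerDimPos.size(); ++i)
    initVectorShape[innerDimPos[i]] *= innerTiles[i];

  // Prints: 16 32 -- i.e. the (static) destination shape, so the final
  // transfer_write is exactly in bounds and no mask is needed.
  for (int64_t d : initVectorShape)
    std::cout << d << " ";
  std::cout << "\n";
  return 0;
}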
@@ -1642,9 +1678,9 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
       unpackOp.getDestType().hasStaticShape()
           ? inputVectorSizes
           : shapeCastOp.getResultVectorType().getShape());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
-                               reifiedRetShapes[0], writeMaskShape);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
+      writeMaskShape, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1673,7 +1709,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
+      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1755,8 +1792,11 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
     return failure();
   }
-  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
-  if (!inputVectorSizes.empty() &&
+  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  bool satisfyEmptyCond = inputVectorSizes.empty() &&
+                          unpackOp.getDestType().hasStaticShape() &&
+                          unpackOp.getSourceType().hasStaticShape();
+  if (!satisfyEmptyCond &&
       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
     return failure();
 
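For illustration, a standalone sketch of the relaxed precondition above (vector sizes may be omitted only when both source and destination shapes are fully static); kDynamic and the shapes are hypothetical stand-ins:

// Standalone sketch (not MLIR code): the satisfyEmptyCond check in plain C++.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

static bool hasStaticShape(const std::vector<int64_t> &shape) {
  return std::none_of(shape.begin(), shape.end(),
                      [](int64_t d) { return d == kDynamic; });
}

int main() {
  std::vector<int64_t> sourceShape = {2, 16, 16}; // fully static
  std::vector<int64_t> destShape = {16, 32};      // fully static
  std::vector<int64_t> inputVectorSizes;          // user provided none

  // If this holds, the vector sizes are inferred; otherwise caller-provided
  // sizes still go through the usual isValidMaskedInputVector check.
  bool satisfyEmptyCond = inputVectorSizes.empty() &&
                          hasStaticShape(destShape) && hasStaticShape(sourceShape);
  std::cout << std::boolalpha << satisfyEmptyCond << "\n"; // prints true
  return 0;
}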