@@ -1412,10 +1412,12 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,

 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
 /// vector type for the read is not the same as the type of `source`, then a
-/// mask is created on the read.
+/// mask is created on the read. `doMasking` specifies whether masking is
+/// required or not. If the `doMasking` parameter is set to false, we update
+/// the `inBounds` attribute instead.
 static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value source, ArrayRef<int64_t> readShape,
-                                    Value padValue) {
+                                    Value padValue, bool doMasking = true) {
   assert(llvm::none_of(readShape,
                        [](int64_t s) { return s == ShapedType::kDynamic; }));
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
@@ -1424,14 +1426,21 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto vectorType = VectorType::get(readShape, padValue.getType());
   int64_t readRank = readShape.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<bool> inBoundsVal(readRank, true);
+  if (!doMasking) {
+    // Update the inBounds attribute.
+    for (unsigned i = 0; i < readRank; i++)
+      inBoundsVal[i] = sourceShape[i] == readShape[i];
+  }
   auto transferReadOp = builder.create<vector::TransferReadOp>(
       loc,
       /*vectorType=*/vectorType,
       /*source=*/source,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(readRank, true));
-  if (llvm::equal(readShape, sourceShape)) {
+      /*inBounds=*/inBoundsVal);
+
+  if (llvm::equal(readShape, sourceShape) || !doMasking) {
     return transferReadOp;
   }
   SmallVector<OpFoldResult> mixedSourceDims =
@@ -1525,6 +1534,20 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");

+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+  bool doMasking = true;
+
+  // If the input vector sizes are not provided, then the vector sizes are
+  // determined by the result tensor shape. In case the vector sizes aren't
+  // provided, we update the inBounds attribute instead of masking.
+  if (inputVectorSizes.empty()) {
+    // Make sure that the result tensor shape is static.
+    if (ShapedType::isDynamicShape(resultTensorShape))
+      return failure();
+    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+    doMasking = false;
+  }
+
   // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
@@ -1536,7 +1559,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
-                                           inputShape, padValue);
+                                           inputShape, padValue, doMasking);

   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1763,7 +1786,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
 /// 1. The numbers of elements in both array are equal.
-/// 2. `inputVectorSizes` does nos have dynamic dimensions.
+/// 2. `inputVectorSizes` does not have dynamic dimensions.
 /// 3. All the values in `inputVectorSizes` are greater than or equal to
 ///    static sizes in `shape`.
 static LogicalResult
@@ -1881,18 +1904,19 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   return success();
 }

-/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+  Attribute cstAttr;
+  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }

   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-  if (failed(isValidMaskedInputVector(
+  if (!inputVectorSizes.empty() &&
+      failed(isValidMaskedInputVector(
           resultTensorShape.take_front(packOp.getSourceRank()),
           inputVectorSizes)))
     return failure();
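
A rough sketch of the kind of read the unmasked path could produce. The shapes, value names, and function name below are illustrative assumptions, not taken from the patch: with `doMasking = false`, a source dimension that matches the read shape gets `in_bounds = true`, while a smaller source dimension is read out of bounds and filled with the padding value, instead of wrapping the read in `vector.mask`.

func.func @illustrative_read(%src: tensor<4x7xf32>) -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  // doMasking == false: per-dim in_bounds flags replace the mask. Dim 0
  // matches the source size (4 == 4); dim 1 does not (7 != 8), so it is
  // read out of bounds and padded with %pad.
  %read = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, false]}
      : tensor<4x7xf32>, vector<4x8xf32>
  return %read : vector<4x8xf32>
}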