@@ -1412,10 +1412,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1412
1412
1413
1413
// / Create a TransferReadOp from `source` with static shape `readShape`. If the
1414
1414
// / vector type for the read is not the same as the type of `source`, then a
1415
- // / mask is created on the read.
1415
+ // / mask is created on the read. If `doMasking` parameter is set to false we
1416
+ // / update the `inBounds` attribute instead of masking.
1416
1417
static Value createReadOrMaskedRead (OpBuilder &builder, Location loc,
1417
1418
Value source, ArrayRef<int64_t > readShape,
1418
- Value padValue) {
1419
+ Value padValue, bool doMasking = true ) {
1419
1420
assert (llvm::none_of (readShape,
1420
1421
[](int64_t s) { return s == ShapedType::kDynamic ; }));
1421
1422
auto sourceShape = dyn_cast<ShapedType>(source.getType ()).getShape ();
@@ -1424,14 +1425,21 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
1424
1425
auto vectorType = VectorType::get (readShape, padValue.getType ());
1425
1426
int64_t readRank = readShape.size ();
1426
1427
auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1428
+ SmallVector<bool > inBoundsVal (readRank, true );
1429
+ if (!doMasking) {
1430
+ // Update the inBounds attribute.
1431
+ for (unsigned i = 0 ; i < readRank; i++)
1432
+ inBoundsVal[i] = sourceShape[i] == readShape[i];
1433
+ }
1427
1434
auto transferReadOp = builder.create <vector::TransferReadOp>(
1428
1435
loc,
1429
1436
/* vectorType=*/ vectorType,
1430
1437
/* source=*/ source,
1431
1438
/* indices=*/ SmallVector<Value>(readRank, zero),
1432
1439
/* padding=*/ padValue,
1433
- /* inBounds=*/ SmallVector<bool >(readRank, true ));
1434
- if (llvm::equal (readShape, sourceShape)) {
1440
+ /* inBounds=*/ inBoundsVal);
1441
+
1442
+ if (llvm::equal (readShape, sourceShape) || !doMasking) {
1435
1443
return transferReadOp;
1436
1444
}
1437
1445
SmallVector<OpFoldResult> mixedSourceDims =
@@ -1525,6 +1533,17 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1525
1533
(void )status; // prevent unused variable warning on non-assert builds.
1526
1534
assert (succeeded (status) && " failed to reify result shapes" );
1527
1535
1536
+ ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
1537
+ bool doMasking = true ;
1538
+
1539
+ // If the input vector sizes are not provided, then the vector sizes are
1540
+ // determined by the result tensor shape. In case the vector sizes aren't
1541
+ // provided, we update the inBounds attribute instead of masking.
1542
+ if (inputVectorSizes.empty ()) {
1543
+ inputVectorSizes = resultTensorShape.take_front (packOp.getSourceRank ());
1544
+ doMasking = false ;
1545
+ }
1546
+
1528
1547
// Create the TransferReadOp; it is masked only when `doMasking` is true.
1529
1548
SmallVector<int64_t > inputShape (inputVectorSizes);
1530
1549
auto innerTiles = packOp.getStaticInnerTiles ();
@@ -1536,7 +1555,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1536
1555
for (auto [idx, size] : enumerate(innerTiles))
1537
1556
inputShape[innerDimsPos[idx]] *= size;
1538
1557
auto maskedRead = createReadOrMaskedRead (rewriter, loc, packOp.getSource (),
1539
- inputShape, padValue);
1558
+ inputShape, padValue, doMasking );
1540
1559
1541
1560
// Create ShapeCastOp.
1542
1561
SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1763,7 +1782,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1763
1782
// / Returns success if `inputVectorSizes` is a valid masking configuration for
1764
1783
// / given `shape`, i.e., it meets:
1765
1784
// / 1. The numbers of elements in both array are equal.
1766
- // / 2. `inputVectorSizes` does nos have dynamic dimensions.
1785
+ // / 2. `inputVectorSizes` does not have dynamic dimensions.
1767
1786
// / 3. All the values in `inputVectorSizes` are greater than or equal to
1768
1787
// / static sizes in `shape`.
1769
1788
static LogicalResult
@@ -1881,18 +1900,26 @@ static LogicalResult vectorizeLinalgOpPrecondition(
1881
1900
return success ();
1882
1901
}
1883
1902
1884
- // / TODO: Use a matcher to check for a constant padding value.
1885
1903
static LogicalResult
1886
1904
vectorizePackOpPrecondition (tensor::PackOp packOp,
1887
1905
ArrayRef<int64_t > inputVectorSizes) {
1888
1906
auto padValue = packOp.getPaddingValue ();
1889
- if (padValue && !padValue.getDefiningOp <arith::ConstantOp>()) {
1907
+ Attribute cstAttr;
1908
+ if (padValue && !matchPattern (padValue, m_Constant (&cstAttr))) {
1890
1909
LDBG (" pad value is not constant: " << packOp << " \n " );
1891
1910
return failure ();
1892
1911
}
1893
-
1894
1912
ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
1895
- if (failed (isValidMaskedInputVector (
1913
+ ArrayRef<int64_t > inputTensorShape = packOp.getSourceType ().getShape ();
1914
+ bool satisfyEmptyCond = true ;
1915
+ if (inputVectorSizes.empty ()) {
1916
+ if (ShapedType::isDynamicShape (resultTensorShape) ||
1917
+ ShapedType::isDynamicShape (inputTensorShape))
1918
+ satisfyEmptyCond = false ;
1919
+ }
1920
+
1921
+ if (!satisfyEmptyCond &&
1922
+ failed (isValidMaskedInputVector (
1896
1923
resultTensorShape.take_front (packOp.getSourceRank ()),
1897
1924
inputVectorSizes)))
1898
1925
return failure ();
0 commit comments