@@ -1412,10 +1412,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1412
1412
1413
1413
// / Create a TransferReadOp from `source` with static shape `readShape`. If the
1414
1414
// / vector type for the read is not the same as the type of `source`, then a
1415
- // / mask is created on the read.
1415
+ // / mask is created on the read. If `doMasking` parameter is set to false we
1416
+ // / update the `inBounds` attribute instead of masking.
1416
1417
static Value createReadOrMaskedRead (OpBuilder &builder, Location loc,
1417
1418
Value source, ArrayRef<int64_t > readShape,
1418
- Value padValue) {
1419
+ Value padValue, bool doMasking = true ) {
1419
1420
assert (llvm::none_of (readShape,
1420
1421
[](int64_t s) { return s == ShapedType::kDynamic ; }));
1421
1422
auto sourceShape = dyn_cast<ShapedType>(source.getType ()).getShape ();
@@ -1424,14 +1425,21 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
1424
1425
auto vectorType = VectorType::get (readShape, padValue.getType ());
1425
1426
int64_t readRank = readShape.size ();
1426
1427
auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1428
+ SmallVector<bool > inBoundsVal (readRank, true );
1429
+ if (!doMasking) {
1430
+ // Update the inBounds attribute.
1431
+ for (unsigned i = 0 ; i < readRank; i++)
1432
+ inBoundsVal[i] = sourceShape[i] == readShape[i];
1433
+ }
1427
1434
auto transferReadOp = builder.create <vector::TransferReadOp>(
1428
1435
loc,
1429
1436
/* vectorType=*/ vectorType,
1430
1437
/* source=*/ source,
1431
1438
/* indices=*/ SmallVector<Value>(readRank, zero),
1432
1439
/* padding=*/ padValue,
1433
- /* inBounds=*/ SmallVector<bool >(readRank, true ));
1434
- if (llvm::equal (readShape, sourceShape)) {
1440
+ /* inBounds=*/ inBoundsVal);
1441
+
1442
+ if (llvm::equal (readShape, sourceShape) || !doMasking) {
1435
1443
return transferReadOp;
1436
1444
}
1437
1445
SmallVector<OpFoldResult> mixedSourceDims =
@@ -1482,11 +1490,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1482
1490
return write;
1483
1491
}
1484
1492
1485
- // / Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
1486
- // / padding value into:
1493
+ // / Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1494
+ // / padding value and (3) input vector sizes into:
1487
1495
// / masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1488
1496
// / As in the following example:
1489
- // /
1490
1497
// / %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1491
1498
// / into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1492
1499
// /
@@ -1505,6 +1512,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1505
1512
// / %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1506
1513
// / {in_bounds = [true, true, true, true, true]}
1507
1514
// / : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1515
+ // /
1516
+ // / If the (3) input vector sizes are not provided, the vector sizes are
1517
+ // / determined by the result tensor shape. Also, we update the inBounds
1518
+ // / attribute instead of masking.
1508
1519
static LogicalResult
1509
1520
vectorizeAsTensorPackOp (RewriterBase &rewriter, tensor::PackOp packOp,
1510
1521
ArrayRef<int64_t > inputVectorSizes,
@@ -1525,6 +1536,16 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1525
1536
(void )status; // prevent unused variable warning on non-assert builds.
1526
1537
assert (succeeded (status) && " failed to reify result shapes" );
1527
1538
1539
+ // If the input vector sizes are not provided, then the vector sizes are
1540
+ // determined by the result tensor shape. In case the vector sizes aren't
1541
+ // provided, we update the inBounds attribute instead of masking.
1542
+ bool doMasking = true ;
1543
+ if (inputVectorSizes.empty ()) {
1544
+ ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
1545
+ inputVectorSizes = resultTensorShape.take_front (packOp.getSourceRank ());
1546
+ doMasking = false ;
1547
+ }
1548
+
1528
1549
// Create masked TransferReadOp.
1529
1550
SmallVector<int64_t > inputShape (inputVectorSizes);
1530
1551
auto innerTiles = packOp.getStaticInnerTiles ();
@@ -1536,7 +1557,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1536
1557
for (auto [idx, size] : enumerate(innerTiles))
1537
1558
inputShape[innerDimsPos[idx]] *= size;
1538
1559
auto maskedRead = createReadOrMaskedRead (rewriter, loc, packOp.getSource (),
1539
- inputShape, padValue);
1560
+ inputShape, padValue, doMasking );
1540
1561
1541
1562
// Create ShapeCastOp.
1542
1563
SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1763,7 +1784,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1763
1784
// / Returns success if `inputVectorSizes` is a valid masking configuration for
1764
1785
// / given `shape`, i.e., it meets:
1765
1786
// / 1. The numbers of elements in both array are equal.
1766
- // / 2. `inputVectorSizes` does nos have dynamic dimensions.
1787
+ // / 2. `inputVectorSizes` does not have dynamic dimensions.
1767
1788
// / 3. All the values in `inputVectorSizes` are greater than or equal to
1768
1789
// / static sizes in `shape`.
1769
1790
static LogicalResult
@@ -1881,18 +1902,25 @@ static LogicalResult vectorizeLinalgOpPrecondition(
1881
1902
return success ();
1882
1903
}
1883
1904
1884
- // / TODO: Use a matcher to check for a constant padding value.
1885
1905
static LogicalResult
1886
1906
vectorizePackOpPrecondition (tensor::PackOp packOp,
1887
1907
ArrayRef<int64_t > inputVectorSizes) {
1888
1908
auto padValue = packOp.getPaddingValue ();
1889
- if (padValue && !padValue.getDefiningOp <arith::ConstantOp>()) {
1909
+ Attribute cstAttr;
1910
+ if (padValue && !matchPattern (padValue, m_Constant (&cstAttr))) {
1890
1911
LDBG (" pad value is not constant: " << packOp << " \n " );
1891
1912
return failure ();
1892
1913
}
1893
-
1894
1914
ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
1895
- if (failed (isValidMaskedInputVector (
1915
+ bool satisfyEmptyCond = true ;
1916
+ if (inputVectorSizes.empty ()) {
1917
+ if (!packOp.getDestType ().hasStaticShape () ||
1918
+ !packOp.getSourceType ().hasStaticShape ())
1919
+ satisfyEmptyCond = false ;
1920
+ }
1921
+
1922
+ if (!satisfyEmptyCond &&
1923
+ failed (isValidMaskedInputVector (
1896
1924
resultTensorShape.take_front (packOp.getSourceRank ()),
1897
1925
inputVectorSizes)))
1898
1926
return failure ();
0 commit comments