@@ -1558,6 +1558,16 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1558
1558
1559
1559
RankedTensorType unpackTensorType = unpackOp.getSourceType ();
1560
1560
1561
+ // If the input vector sizes are not provided, then the vector sizes are
1562
+ // determined by the result tensor shape. In case the vector sizes aren't
1563
+ // provided, we update the inBounds attribute instead of masking.
1564
+ bool useInBoundsInsteadOfMasking = true ;
1565
+ if (inputVectorSizes.empty ()) {
1566
+ ArrayRef<int64_t > resultTensorShape = unpackOp.getDestType ().getShape ();
1567
+ inputVectorSizes = resultTensorShape.take_front (unpackOp.getSourceRank ());
1568
+ useInBoundsInsteadOfMasking = false ;
1569
+ }
1570
+
1561
1571
ArrayRef<int64_t > innerDimPos = unpackOp.getInnerDimsPos ();
1562
1572
ArrayRef<int64_t > innerTiles = unpackOp.getStaticInnerTiles ();
1563
1573
@@ -1612,7 +1622,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1612
1622
// to shape of source, then a mask is necessary.
1613
1623
Value readResult = vector::createReadOrMaskedRead (
1614
1624
rewriter, loc, unpackOp.getSource (),
1615
- ArrayRef<int64_t >(readMaskShape.begin (), readMaskShape.end ()), padValue);
1625
+ ArrayRef<int64_t >(readMaskShape.begin (), readMaskShape.end ()), padValue,
1626
+ doMasking);
1616
1627
1617
1628
PackingMetadata packMetadata;
1618
1629
SmallVector<int64_t > lastDimToInsertPosPerm =
@@ -1753,8 +1764,14 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1753
1764
LDBG (" Inner-tiles must be constant: " << unpackOp << " \n " );
1754
1765
return failure ();
1755
1766
}
1756
- llvm::ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1757
- if (!inputVectorSizes.empty () &&
1767
+ ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1768
+ bool satisfyEmptyCond = true ;
1769
+ if (inputVectorSizes.empty ()) {
1770
+ if (!unpackOp.getDestType ().hasStaticShape () ||
1771
+ !unpackOp.getSourceType ().hasStaticShape ())
1772
+ satisfyEmptyCond = false ;
1773
+ }
1774
+ if (!satisfyEmptyCond &&
1758
1775
failed (vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
1759
1776
return failure ();
1760
1777
0 commit comments