@@ -1597,6 +1597,16 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1597
1597
1598
1598
RankedTensorType unpackTensorType = unpackOp.getSourceType ();
1599
1599
1600
+ // If the input vector sizes are not provided, then the vector sizes are
1601
+ // determined by the result tensor shape. In case the vector sizes aren't
1602
+ // provided, we update the inBounds attribute instead of masking.
1603
+ bool doMasking = true ;
1604
+ if (inputVectorSizes.empty ()) {
1605
+ ArrayRef<int64_t > resultTensorShape = unpackOp.getDestType ().getShape ();
1606
+ inputVectorSizes = resultTensorShape.take_front (unpackOp.getSourceRank ());
1607
+ doMasking = false ;
1608
+ }
1609
+
1600
1610
ArrayRef<int64_t > innerDimPos = unpackOp.getInnerDimsPos ();
1601
1611
ArrayRef<int64_t > innerTiles = unpackOp.getStaticInnerTiles ();
1602
1612
@@ -1651,7 +1661,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1651
1661
// to shape of source, then a mask is necessary.
1652
1662
Value readResult = createReadOrMaskedRead (
1653
1663
rewriter, loc, unpackOp.getSource (),
1654
- ArrayRef<int64_t >(readMaskShape.begin (), readMaskShape.end ()), padValue);
1664
+ ArrayRef<int64_t >(readMaskShape.begin (), readMaskShape.end ()), padValue,
1665
+ doMasking);
1655
1666
1656
1667
PackingMetadata packMetadata;
1657
1668
SmallVector<int64_t > lastDimToInsertPosPerm =
@@ -1827,8 +1838,14 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1827
1838
LDBG (" Inner-tiles must be constant: " << unpackOp << " \n " );
1828
1839
return failure ();
1829
1840
}
1830
- llvm::ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1831
- if (!inputVectorSizes.empty () &&
1841
+ ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1842
+ bool satisfyEmptyCond = true ;
1843
+ if (inputVectorSizes.empty ()) {
1844
+ if (!unpackOp.getDestType ().hasStaticShape () ||
1845
+ !unpackOp.getSourceType ().hasStaticShape ())
1846
+ satisfyEmptyCond = false ;
1847
+ }
1848
+ if (!satisfyEmptyCond &&
1832
1849
failed (isValidMaskedInputVector (resultShape, inputVectorSizes)))
1833
1850
return failure ();
1834
1851
0 commit comments