@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///     in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///     {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
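
For reference, the doc comment above only sketches the masked form. Below is a minimal illustrative sketch of the IR the simplified helper emits when the vector shape does not statically match the trailing destination sizes; the shapes, the value names (%vec, %dest), and the tensor.dim computations are assumptions for illustration, not part of this patch:

  // Mask sizes come from the (possibly dynamic) destination sizes; the mask
  // shape follows the vector being stored, with element type i1.
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
  %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
  %res = vector.mask %mask {
    vector.transfer_write %vec, %dest[%c0, %c0]
        : vector<4x8xf32>, tensor<?x?xf32>
  } : vector<4x8xi1> -> tensor<?x?xf32>

When the trailing destination sizes are static and equal to the vector shape, or when useInBoundsInsteadOfMasking is set, the helper instead returns a plain vector.transfer_write (with in_bounds flags) and no mask is created.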