@@ -1688,16 +1688,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1688
1688
broadcastVecType.getShape ().take_back (extractResultRank))
1689
1689
return Value ();
1690
1690
1691
- // The dim-1 broadcast -> ExtractOp folder requires in-place operation
1692
- // modifications. For dynamic position, this means we have to change the
1693
- // number of operands. This cannot be done in place since it changes the
1694
- // operation storage. For dynamic dimensions, the dim-1 broadcasting should
1695
- // be implemented as a canonicalization pattern.
1696
- // TODO: Implement canonicalization pattern for dim-1 broadcasting +
1697
- // extractop.
1698
- if (extractOp.hasDynamicPosition ())
1699
- return Value ();
1700
-
1701
1691
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1702
1692
int64_t broadcastDstRank = broadcastOp.getResultVectorType ().getRank ();
1703
1693
@@ -1706,20 +1696,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1706
1696
// extract position to `0` when extracting from the source operand.
1707
1697
llvm::SetVector<int64_t > broadcastedUnitDims =
1708
1698
broadcastOp.computeBroadcastedUnitDims ();
1709
- SmallVector<int64_t > extractPos (extractOp.getStaticPosition ());
1699
+ SmallVector<OpFoldResult> extractPos (extractOp.getMixedPosition ());
1700
+ OpBuilder b (extractOp.getContext ());
1710
1701
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1711
1702
for (int64_t i = broadcastRankDiff, e = extractPos.size (); i < e; ++i)
1712
1703
if (broadcastedUnitDims.contains (i))
1713
- extractPos[i] = 0 ;
1704
+ extractPos[i] = b. getIndexAttr ( 0 ) ;
1714
1705
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1715
1706
// matching extract position when extracting from the source operand.
1716
1707
int64_t rankDiff = broadcastSrcRank - extractResultRank;
1717
1708
extractPos.erase (extractPos.begin (),
1718
1709
std::next (extractPos.begin (), extractPos.size () - rankDiff));
1719
1710
// OpBuilder is only used as a helper to build an I64ArrayAttr.
1720
- OpBuilder b (extractOp.getContext ());
1721
- extractOp.setOperand (0 , source);
1722
- extractOp.setStaticPosition (extractPos);
1711
+ auto [staticPos, dynPos] = decomposeMixedValues (extractPos);
1712
+ extractOp->setOperands (
1713
+ llvm::to_vector (llvm::concat<Value>(ValueRange (source), dynPos)));
1714
+ extractOp.setStaticPosition (staticPos);
1723
1715
return extractOp.getResult ();
1724
1716
}
1725
1717
0 commit comments