@@ -1731,33 +1731,33 @@ struct DropUnitDimFromBroadcastOp final
1731
1731
1732
1732
LogicalResult matchAndRewrite (vector::BroadcastOp broadcastOp,
1733
1733
PatternRewriter &rewriter) const override {
1734
- auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType ());
1735
- if (!srcVT )
1734
+ auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType ());
1735
+ if (!srcVecTy )
1736
1736
return failure ();
1737
- auto resVT = broadcastOp.getResultVectorType ();
1738
- VectorType newSrcVT = srcVT ;
1739
- VectorType newResVT = resVT ;
1737
+ auto resVecTy = broadcastOp.getResultVectorType ();
1738
+ auto srcVecTyBuilder = VectorType::Builder (srcVecTy) ;
1739
+ auto resVecTyBuilder = VectorType::Builder (resVecTy) ;
1740
1740
auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims ();
1741
1741
// Reversing allows us to remove dims from the back without keeping track of
1742
1742
// removed dimensions.
1743
- for (const auto &dim : llvm::enumerate (llvm::reverse (srcVT.getShape ()))) {
1743
+ for (const auto &dim :
1744
+ llvm::enumerate (llvm::reverse (srcVecTy.getShape ()))) {
1744
1745
if (dim.value () == 1 &&
1745
- !srcVT.getScalableDims ()[srcVT.getRank () - dim.index () - 1 ] &&
1746
- !broadcastedUnitDims.contains (srcVT.getRank () - dim.index () - 1 )) {
1747
- newSrcVT = VectorType::Builder (newSrcVT).dropDim (srcVT.getRank () -
1748
- dim.index () - 1 );
1749
- newResVT = VectorType::Builder (newResVT).dropDim (resVT.getRank () -
1750
- dim.index () - 1 );
1746
+ !srcVecTy.getScalableDims ()[srcVecTy.getRank () - dim.index () - 1 ] &&
1747
+ !broadcastedUnitDims.contains (srcVecTy.getRank () - dim.index () - 1 )) {
1748
+ srcVecTyBuilder.dropDim (srcVecTy.getRank () - dim.index () - 1 );
1749
+ resVecTyBuilder.dropDim (resVecTy.getRank () - dim.index () - 1 );
1751
1750
}
1752
1751
}
1753
1752
1754
- if (newSrcVT == srcVT )
1753
+ if (VectorType (srcVecTyBuilder) == srcVecTy )
1755
1754
return failure ();
1756
1755
auto loc = broadcastOp->getLoc ();
1757
1756
auto newSource = rewriter.create <vector::ShapeCastOp>(
1758
- loc, newSrcVT, broadcastOp.getSource ());
1759
- auto newOp = rewriter.create <vector::BroadcastOp>(loc, newResVT, newSource);
1760
- rewriter.replaceOpWithNewOp <ShapeCastOp>(broadcastOp, resVT,
1757
+ loc, VectorType (srcVecTyBuilder), broadcastOp.getSource ());
1758
+ auto newOp = rewriter.create <vector::BroadcastOp>(
1759
+ loc, VectorType (resVecTyBuilder), newSource);
1760
+ rewriter.replaceOpWithNewOp <ShapeCastOp>(broadcastOp, resVecTy,
1761
1761
newOp.getResult ());
1762
1762
return success ();
1763
1763
}
0 commit comments