Skip to content

Commit cebfd74

Browse files
committed
Hoist out vector builder
1 parent d94eb3d commit cebfd74

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,33 +1731,33 @@ struct DropUnitDimFromBroadcastOp final
17311731

17321732
LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
17331733
PatternRewriter &rewriter) const override {
1734-
auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
1735-
if (!srcVT)
1734+
auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
1735+
if (!srcVecTy)
17361736
return failure();
1737-
auto resVT = broadcastOp.getResultVectorType();
1738-
VectorType newSrcVT = srcVT;
1739-
VectorType newResVT = resVT;
1737+
auto resVecTy = broadcastOp.getResultVectorType();
1738+
auto srcVecTyBuilder = VectorType::Builder(srcVecTy);
1739+
auto resVecTyBuilder = VectorType::Builder(resVecTy);
17401740
auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
17411741
// Reversing allows us to remove dims from the back without keeping track of
17421742
// removed dimensions.
1743-
for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
1743+
for (const auto &dim :
1744+
llvm::enumerate(llvm::reverse(srcVecTy.getShape()))) {
17441745
if (dim.value() == 1 &&
1745-
!srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
1746-
!broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
1747-
newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
1748-
dim.index() - 1);
1749-
newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
1750-
dim.index() - 1);
1746+
!srcVecTy.getScalableDims()[srcVecTy.getRank() - dim.index() - 1] &&
1747+
!broadcastedUnitDims.contains(srcVecTy.getRank() - dim.index() - 1)) {
1748+
srcVecTyBuilder.dropDim(srcVecTy.getRank() - dim.index() - 1);
1749+
resVecTyBuilder.dropDim(resVecTy.getRank() - dim.index() - 1);
17511750
}
17521751
}
17531752

1754-
if (newSrcVT == srcVT)
1753+
if (VectorType(srcVecTyBuilder) == srcVecTy)
17551754
return failure();
17561755
auto loc = broadcastOp->getLoc();
17571756
auto newSource = rewriter.create<vector::ShapeCastOp>(
1758-
loc, newSrcVT, broadcastOp.getSource());
1759-
auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
1760-
rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
1757+
loc, VectorType(srcVecTyBuilder), broadcastOp.getSource());
1758+
auto newOp = rewriter.create<vector::BroadcastOp>(
1759+
loc, VectorType(resVecTyBuilder), newSource);
1760+
rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVecTy,
17611761
newOp.getResult());
17621762
return success();
17631763
}

0 commit comments

Comments
 (0)