@@ -5769,10 +5769,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
5769
5769
}
5770
5770
};
5771
5771
5772
- // / Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5773
- // / i) Y = ShapeCast(X), or
5774
- // / ii) Y = Broadcast(X)
5775
- // / If both (i) and (ii) are possible, (i) is chosen.
5772
+ // / Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
5776
5773
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5777
5774
public:
5778
5775
using OpRewritePattern::OpRewritePattern;
@@ -5787,22 +5784,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5787
5784
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType ());
5788
5785
bool srcIsScalar = !srcVectorType;
5789
5786
5790
- // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
5791
- // Example:
5792
- // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
5793
- // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
5794
- // to
5795
- // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
5796
- if (srcVectorType) {
5797
- if (srcVectorType.getNumElements () ==
5798
- shapeCastOp.getResultVectorType ().getNumElements ()) {
5799
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
5800
- shapeCastOp, shapeCastOp.getResultVectorType (),
5801
- broadcastOp.getSource ());
5802
- return success ();
5803
- }
5804
- }
5805
-
5806
5787
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
5807
5788
// Example
5808
5789
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
0 commit comments