Skip to content

Commit d88a3a3

Browse files
authored
[mlir][vector] Remove redundant shape_cast(shape_cast(x)) pattern (#135447)
This PR removes one OpRewritePattern `shape_cast(shape_cast(x)) -> x` that is already handled by `ShapeCastOp::fold`. Note that this might affect downstream users who indirectly call `populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit)` and then use `patterns` with a `GreedyRewriteConfig config` that has `config.fold = false`. (only user I've checked is IREE, that never uses config.fold = false).
1 parent 0daf20b commit d88a3a3

File tree

4 files changed

+1
-75
lines changed

4 files changed

+1
-75
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,10 +306,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
306306
const UnrollVectorOptions &options,
307307
PatternBenefit benefit = 1);
308308

309-
/// Collect a set of vector.shape_cast folding patterns.
310-
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
311-
PatternBenefit benefit = 1);
312-
313309
/// Collect a set of leading one dimension removal patterns.
314310
///
315311
/// These patterns insert vector.shape_cast to remove leading one dimensions

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#include <numeric>
1010

11-
#include "mlir/Dialect/Arith/IR/Arith.h"
1211
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1312
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1413
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -577,5 +576,4 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
577576
CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
578577
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
579578
CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
580-
populateShapeCastFoldingPatterns(patterns, benefit);
581579
}

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,6 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
976976
patterns
977977
.add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
978978
patterns.getContext(), benefit);
979-
populateShapeCastFoldingPatterns(patterns);
980979
}
981980

982981
void mlir::vector::populateFlattenVectorTransferPatterns(
@@ -985,6 +984,5 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
985984
patterns.add<FlattenContiguousRowMajorTransferReadPattern,
986985
FlattenContiguousRowMajorTransferWritePattern>(
987986
patterns.getContext(), targetVectorBitwidth, benefit);
988-
populateShapeCastFoldingPatterns(patterns, benefit);
989987
populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
990988
}

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,24 @@
1616
#include <cstdint>
1717
#include <functional>
1818
#include <optional>
19-
#include <type_traits>
2019

21-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2220
#include "mlir/Dialect/Arith/IR/Arith.h"
2321
#include "mlir/Dialect/Arith/Utils/Utils.h"
24-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2522
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2623
#include "mlir/Dialect/SCF/IR/SCF.h"
27-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2824
#include "mlir/Dialect/Utils/IndexingUtils.h"
2925
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
3026
#include "mlir/Dialect/Vector/IR/VectorOps.h"
3127
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
3228
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
33-
#include "mlir/IR/BuiltinAttributeInterfaces.h"
3429
#include "mlir/IR/BuiltinTypes.h"
35-
#include "mlir/IR/ImplicitLocOpBuilder.h"
3630
#include "mlir/IR/Location.h"
3731
#include "mlir/IR/Matchers.h"
3832
#include "mlir/IR/PatternMatch.h"
3933
#include "mlir/IR/TypeUtilities.h"
40-
#include "mlir/Interfaces/VectorInterfaces.h"
4134

42-
#include "llvm/ADT/DenseSet.h"
43-
#include "llvm/ADT/MapVector.h"
4435
#include "llvm/ADT/STLExtras.h"
45-
#include "llvm/Support/CommandLine.h"
46-
#include "llvm/Support/Debug.h"
4736
#include "llvm/Support/FormatVariadic.h"
48-
#include "llvm/Support/raw_ostream.h"
4937

5038
#define DEBUG_TYPE "vector-to-vector"
5139

@@ -71,54 +59,6 @@ static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
7159

7260
namespace {
7361

74-
/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
75-
//
76-
// Example:
77-
//
78-
// The following MLIR with cancelling ShapeCastOps:
79-
//
80-
// %0 = source : vector<5x4x2xf32>
81-
// %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
82-
// %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
83-
// %3 = user %2 : vector<5x4x2xf32>
84-
//
85-
// Should canonicalize to the following:
86-
//
87-
// %0 = source : vector<5x4x2xf32>
88-
// %1 = user %0 : vector<5x4x2xf32>
89-
//
90-
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
91-
using OpRewritePattern::OpRewritePattern;
92-
93-
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
94-
PatternRewriter &rewriter) const override {
95-
// Check if 'shapeCastOp' has vector source/result type.
96-
auto sourceVectorType =
97-
dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
98-
auto resultVectorType =
99-
dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
100-
if (!sourceVectorType || !resultVectorType)
101-
return failure();
102-
103-
// Check if shape cast op source operand is also a shape cast op.
104-
auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
105-
shapeCastOp.getSource().getDefiningOp());
106-
if (!sourceShapeCastOp)
107-
return failure();
108-
auto operandSourceVectorType =
109-
cast<VectorType>(sourceShapeCastOp.getSource().getType());
110-
auto operandResultVectorType = sourceShapeCastOp.getType();
111-
112-
// Check if shape cast operations invert each other.
113-
if (operandSourceVectorType != resultVectorType ||
114-
operandResultVectorType != sourceVectorType)
115-
return failure();
116-
117-
rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
118-
return success();
119-
}
120-
};
121-
12262
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
12363
/// Ex:
12464
/// ```
@@ -2113,11 +2053,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
21132053
patterns.add<FoldI1Select>(patterns.getContext(), benefit);
21142054
}
21152055

2116-
void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
2117-
PatternBenefit benefit) {
2118-
patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
2119-
}
2120-
21212056
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
21222057
RewritePatternSet &patterns, PatternBenefit benefit) {
21232058
// TODO: Consider either:
@@ -2126,8 +2061,7 @@ void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
21262061
// * better naming to distinguish this and
21272062
// populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
21282063
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2129-
DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
2130-
patterns.getContext(), benefit);
2064+
DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
21312065
}
21322066

21332067
void mlir::vector::populateBubbleVectorBitCastOpPatterns(

0 commit comments

Comments
 (0)