16
16
#include < cstdint>
17
17
#include < functional>
18
18
#include < optional>
19
- #include < type_traits>
20
19
21
- #include " mlir/Dialect/Affine/IR/AffineOps.h"
22
20
#include " mlir/Dialect/Arith/IR/Arith.h"
23
21
#include " mlir/Dialect/Arith/Utils/Utils.h"
24
- #include " mlir/Dialect/Linalg/IR/Linalg.h"
25
22
#include " mlir/Dialect/MemRef/IR/MemRef.h"
26
23
#include " mlir/Dialect/SCF/IR/SCF.h"
27
- #include " mlir/Dialect/Tensor/IR/Tensor.h"
28
24
#include " mlir/Dialect/Utils/IndexingUtils.h"
29
25
#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
30
26
#include " mlir/Dialect/Vector/IR/VectorOps.h"
31
27
#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
32
28
#include " mlir/Dialect/Vector/Utils/VectorUtils.h"
33
- #include " mlir/IR/BuiltinAttributeInterfaces.h"
34
29
#include " mlir/IR/BuiltinTypes.h"
35
- #include " mlir/IR/ImplicitLocOpBuilder.h"
36
30
#include " mlir/IR/Location.h"
37
31
#include " mlir/IR/Matchers.h"
38
32
#include " mlir/IR/PatternMatch.h"
39
33
#include " mlir/IR/TypeUtilities.h"
40
- #include " mlir/Interfaces/VectorInterfaces.h"
41
34
42
- #include " llvm/ADT/DenseSet.h"
43
- #include " llvm/ADT/MapVector.h"
44
35
#include " llvm/ADT/STLExtras.h"
45
- #include " llvm/Support/CommandLine.h"
46
- #include " llvm/Support/Debug.h"
47
36
#include " llvm/Support/FormatVariadic.h"
48
- #include " llvm/Support/raw_ostream.h"
49
37
50
38
#define DEBUG_TYPE " vector-to-vector"
51
39
@@ -71,54 +59,6 @@ static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
71
59
72
60
namespace {
73
61
74
- // / ShapeCastOpFolder folds cancelling ShapeCastOps away.
75
- //
76
- // Example:
77
- //
78
- // The following MLIR with cancelling ShapeCastOps:
79
- //
80
- // %0 = source : vector<5x4x2xf32>
81
- // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
82
- // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
83
- // %3 = user %2 : vector<5x4x2xf32>
84
- //
85
- // Should canonicalize to the following:
86
- //
87
- // %0 = source : vector<5x4x2xf32>
88
- // %1 = user %0 : vector<5x4x2xf32>
89
- //
90
- struct ShapeCastOpFolder : public OpRewritePattern <vector::ShapeCastOp> {
91
- using OpRewritePattern::OpRewritePattern;
92
-
93
- LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
94
- PatternRewriter &rewriter) const override {
95
- // Check if 'shapeCastOp' has vector source/result type.
96
- auto sourceVectorType =
97
- dyn_cast_or_null<VectorType>(shapeCastOp.getSource ().getType ());
98
- auto resultVectorType =
99
- dyn_cast_or_null<VectorType>(shapeCastOp.getResult ().getType ());
100
- if (!sourceVectorType || !resultVectorType)
101
- return failure ();
102
-
103
- // Check if shape cast op source operand is also a shape cast op.
104
- auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
105
- shapeCastOp.getSource ().getDefiningOp ());
106
- if (!sourceShapeCastOp)
107
- return failure ();
108
- auto operandSourceVectorType =
109
- cast<VectorType>(sourceShapeCastOp.getSource ().getType ());
110
- auto operandResultVectorType = sourceShapeCastOp.getType ();
111
-
112
- // Check if shape cast operations invert each other.
113
- if (operandSourceVectorType != resultVectorType ||
114
- operandResultVectorType != sourceVectorType)
115
- return failure ();
116
-
117
- rewriter.replaceOp (shapeCastOp, sourceShapeCastOp.getSource ());
118
- return success ();
119
- }
120
- };
121
-
122
62
// / Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
123
63
// / Ex:
124
64
// / ```
@@ -2113,11 +2053,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
2113
2053
patterns.add <FoldI1Select>(patterns.getContext (), benefit);
2114
2054
}
2115
2055
2116
- void mlir::vector::populateShapeCastFoldingPatterns (RewritePatternSet &patterns,
2117
- PatternBenefit benefit) {
2118
- patterns.add <ShapeCastOpFolder>(patterns.getContext (), benefit);
2119
- }
2120
-
2121
2056
void mlir::vector::populateDropUnitDimWithShapeCastPatterns (
2122
2057
RewritePatternSet &patterns, PatternBenefit benefit) {
2123
2058
// TODO: Consider either:
@@ -2126,8 +2061,7 @@ void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2126
2061
// * better naming to distinguish this and
2127
2062
// populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
2128
2063
patterns.add <DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2129
- DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
2130
- patterns.getContext (), benefit);
2064
+ DropUnitDimsFromTransposeOp>(patterns.getContext (), benefit);
2131
2065
}
2132
2066
2133
2067
void mlir::vector::populateBubbleVectorBitCastOpPatterns (
0 commit comments