@@ -57,53 +57,9 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
57
57
}
58
58
};
59
59
60
- struct ConcatFolding : public OpRewritePattern <tosa::ConcatOp> {
61
- using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
62
-
63
- LogicalResult matchAndRewrite (tosa::ConcatOp op,
64
- PatternRewriter &rewriter) const override {
65
- // Fold consecutive concats on the same axis into a single op.
66
- uint64_t axis = op.getAxis ();
67
-
68
- // Keep track of the operands so we are able to construct a new concat
69
- // later. Conservatively assume that we double the number of operands when
70
- // folding
71
- SmallVector<Value, 8 > concatOperands;
72
- concatOperands.reserve (2 * op->getNumOperands ());
73
-
74
- // Find all operands that are foldable concats
75
- bool canFold = false ;
76
- for (Value operand : op->getOperands ()) {
77
- concatOperands.emplace_back (operand);
78
-
79
- auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp ());
80
- if (!producer)
81
- continue ;
82
-
83
- // Foldable if axis is the same
84
- if (axis != producer.getAxis ())
85
- continue ;
86
-
87
- // Replace the original operand with all incoming operands
88
- canFold = true ;
89
- concatOperands.pop_back ();
90
- llvm::append_range (concatOperands, producer->getOperands ());
91
- }
92
-
93
- if (!canFold)
94
- return rewriter.notifyMatchFailure (op, " No foldable concats found" );
95
-
96
- // Replace the original concat with a new one that contains the original and
97
- // folded operands
98
- rewriter.replaceOpWithNewOp <tosa::ConcatOp>(op, op->getResultTypes (),
99
- concatOperands, axis);
100
- return success ();
101
- }
102
- };
103
-
104
60
void ConcatOp::getCanonicalizationPatterns (RewritePatternSet &results,
105
61
MLIRContext *context) {
106
- results.add <ConcatOptimization, ConcatFolding >(context);
62
+ results.add <ConcatOptimization>(context);
107
63
}
108
64
109
65
struct ReshapeReshapeOptimization : public OpRewritePattern <tosa::ReshapeOp> {
@@ -1039,3 +995,37 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
1039
995
return getInput1 ();
1040
996
return {};
1041
997
}
998
+
999
+ OpFoldResult ConcatOp::fold (ArrayRef<Attribute> operands) {
1000
+ // Fold consecutive concats on the same axis into a single op.
1001
+ // Keep track of the operands so we are able to construct a new concat
1002
+ // later. Conservatively assume that we double the number of operands when
1003
+ // folding
1004
+ SmallVector<Value, 8 > concatOperands;
1005
+ concatOperands.reserve (2 * getNumOperands ());
1006
+
1007
+ // Find all operands that are foldable concats
1008
+ bool canFold = false ;
1009
+ for (Value operand : getOperands ()) {
1010
+ concatOperands.emplace_back (operand);
1011
+
1012
+ auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp ());
1013
+ if (!producer)
1014
+ continue ;
1015
+
1016
+ // Foldable if axis is the same
1017
+ if (getAxis () != producer.getAxis ())
1018
+ continue ;
1019
+
1020
+ // Replace the original operand with all incoming operands
1021
+ canFold = true ;
1022
+ concatOperands.pop_back ();
1023
+ llvm::append_range (concatOperands, producer->getOperands ());
1024
+ }
1025
+
1026
+ if (!canFold)
1027
+ return {};
1028
+
1029
+ getOperation ()->setOperands (concatOperands);
1030
+ return getResult ();
1031
+ }
0 commit comments