Skip to content

Commit 8e461e2

Browse files
committed
refactor(TosaOps): FXML-1981 use hasFolder = 1 for concat folding
1 parent 5b9793e commit 8e461e2

File tree

3 files changed

+37
-46
lines changed

3 files changed

+37
-46
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
13831383
);
13841384

13851385
let hasCanonicalizer = 1;
1386+
let hasFolder = 1;
13861387
}
13871388

13881389
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -57,53 +57,9 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
5757
}
5858
};
5959

60-
struct ConcatFolding : public OpRewritePattern<tosa::ConcatOp> {
61-
using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
62-
63-
LogicalResult matchAndRewrite(tosa::ConcatOp op,
64-
PatternRewriter &rewriter) const override {
65-
// Fold consecutive concats on the same axis into a single op.
66-
uint64_t axis = op.getAxis();
67-
68-
// Keep track of the operands so we are able to construct a new concat
69-
// later. Conservatively assume that we double the number of operands when
70-
// folding
71-
SmallVector<Value, 8> concatOperands;
72-
concatOperands.reserve(2 * op->getNumOperands());
73-
74-
// Find all operands that are foldable concats
75-
bool canFold = false;
76-
for (Value operand : op->getOperands()) {
77-
concatOperands.emplace_back(operand);
78-
79-
auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
80-
if (!producer)
81-
continue;
82-
83-
// Foldable if axis is the same
84-
if (axis != producer.getAxis())
85-
continue;
86-
87-
// Replace the original operand with all incoming operands
88-
canFold = true;
89-
concatOperands.pop_back();
90-
llvm::append_range(concatOperands, producer->getOperands());
91-
}
92-
93-
if (!canFold)
94-
return rewriter.notifyMatchFailure(op, "No foldable concats found");
95-
96-
// Replace the original concat with a new one that contains the original and
97-
// folded operands
98-
rewriter.replaceOpWithNewOp<tosa::ConcatOp>(op, op->getResultTypes(),
99-
concatOperands, axis);
100-
return success();
101-
}
102-
};
103-
10460
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
10561
MLIRContext *context) {
106-
results.add<ConcatOptimization, ConcatFolding>(context);
62+
results.add<ConcatOptimization>(context);
10763
}
10864

10965
struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
@@ -1039,3 +995,37 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
1039995
return getInput1();
1040996
return {};
1041997
}
998+
999+
OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
1000+
// Fold consecutive concats on the same axis into a single op.
1001+
// Keep track of the operands so we are able to construct a new concat
1002+
// later. Conservatively assume that we double the number of operands when
1003+
// folding
1004+
SmallVector<Value, 8> concatOperands;
1005+
concatOperands.reserve(2 * getNumOperands());
1006+
1007+
// Find all operands that are foldable concats
1008+
bool canFold = false;
1009+
for (Value operand : getOperands()) {
1010+
concatOperands.emplace_back(operand);
1011+
1012+
auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
1013+
if (!producer)
1014+
continue;
1015+
1016+
// Foldable if axis is the same
1017+
if (getAxis() != producer.getAxis())
1018+
continue;
1019+
1020+
// Replace the original operand with all incoming operands
1021+
canFold = true;
1022+
concatOperands.pop_back();
1023+
llvm::append_range(concatOperands, producer->getOperands());
1024+
}
1025+
1026+
if (!canFold)
1027+
return {};
1028+
1029+
getOperation()->setOperands(concatOperands);
1030+
return getResult();
1031+
}

mlir/test/Dialect/Tosa/fold_concats.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
1+
// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
22

33
func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
44
%0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>

0 commit comments

Comments
 (0)