Skip to content

Commit e62b72a

Browse files
authored
Merge pull request #34 from Xilinx/dominik.FXML-1981.fold_concats
feat(TosaCanonicalizations): FXML-1981 fold consecutive concats on same axis
2 parents 37665ac + 5a529a4 commit e62b72a

File tree

3 files changed

+128
-0
lines changed

3 files changed

+128
-0
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
13831383
);
13841384

13851385
let hasCanonicalizer = 1;
1386+
let hasFolder = 1;
13861387
}
13871388

13881389
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,3 +995,37 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
995995
return getInput1();
996996
return {};
997997
}
998+
999+
/// Flattens one level of nested concatenation. Every operand that is itself
/// produced by a ConcatOp with the same axis is replaced, at its position, by
/// that producer's operands; all other operands are kept as they are.
///
/// When at least one operand was expanded, the operand list of this op is
/// rewritten in place and the op's own result is returned. Otherwise a null
/// OpFoldResult is returned and the op is left untouched. The `operands`
/// attribute array is not consulted.
OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
  // Twice the current operand count is only a capacity guess; the vector
  // still grows on demand if the producers contribute more than that.
  SmallVector<Value, 8> flattened;
  flattened.reserve(2 * getNumOperands());

  bool mergedAny = false;
  for (Value input : getOperands()) {
    auto inner = dyn_cast_or_null<ConcatOp>(input.getDefiningOp());

    // Block arguments, non-concat producers and concats along a different
    // axis are forwarded unchanged.
    const bool sameAxisConcat = inner && getAxis() == inner.getAxis();
    if (!sameAxisConcat) {
      flattened.push_back(input);
      continue;
    }

    // Splice the producer's operands in where its result used to be, which
    // keeps the overall concatenation order intact.
    llvm::append_range(flattened, inner->getOperands());
    mergedAny = true;
  }

  if (!mergedAny)
    return {};

  getOperation()->setOperands(flattened);
  return getResult();
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
2+
3+
// Baseline case: a lone concat whose operands are both the function argument,
// so no operand is produced by another concat. The expected output that
// follows lists the same op unchanged.
func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
4+
%0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
5+
return %0 : tensor<1x2x7x7xf32>
6+
}
7+
8+
// CHECK-LABEL: func.func @single_concat(
9+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
10+
// CHECK: %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
11+
// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32>
12+
// CHECK: }
13+
14+
// -----
15+
16+
// Negative case: the inner concat joins along axis 1 while the outer one
// joins along axis 0. The expected output that follows keeps both ops.
func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
17+
%0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
18+
%1 = "tosa.concat"(%0, %0) {axis = 0} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
19+
return %1 : tensor<2x2x7x7xf32>
20+
}
21+
22+
// CHECK-LABEL: func.func @concat_different_axis(
23+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
24+
// CHECK: %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
25+
// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) {axis = 0 : i64} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
26+
// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32>
27+
// CHECK: }
28+
29+
// -----
30+
31+
// Basic positive case: an axis-1 concat is the middle operand of another
// axis-1 concat. The expected output that follows has a single four-operand
// concat with the inner operands spliced in between the two %tmp uses.
func.func @fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
32+
%tmp = tensor.empty() : tensor<1x1x7x7xf32>
33+
%0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
34+
%1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
35+
return %1 : tensor<1x4x7x7xf32>
36+
}
37+
38+
// CHECK-LABEL: func.func @fold_concats(
39+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
40+
// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
41+
// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
42+
// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
43+
// CHECK: }
44+
45+
// -----
46+
47+
// Three-level chain of axis-1 concats, where the outermost one uses the
// middle result twice. The expected output that follows has a single
// eight-operand concat, i.e. the chain is flattened across all levels.
func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
48+
%tmp = tensor.empty() : tensor<1x1x7x7xf32>
49+
%0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
50+
%1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
51+
%2 = "tosa.concat"(%1, %1) {axis = 1} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
52+
return %2 : tensor<1x8x7x7xf32>
53+
}
54+
55+
// CHECK-LABEL: func.func @nested_fold(
56+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
57+
// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
58+
// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
59+
// CHECK: return %[[VAL_2]] : tensor<1x8x7x7xf32>
60+
// CHECK: }
61+
62+
// -----
63+
64+
// Both operands of the outer concat are themselves axis-1 concats. The
// expected output that follows has one concat over (%arg0, %arg0, %arg1,
// %arg1), so both producers are inlined in operand order.
func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
65+
%0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
66+
%1 = "tosa.concat"(%arg1, %arg1) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
67+
%2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
68+
return %2 : tensor<1x4x7x7xf32>
69+
}
70+
71+
// CHECK-LABEL: func.func @wide_fold(
72+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>,
73+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
74+
// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
75+
// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
76+
// CHECK: }
77+
78+
// -----
79+
80+
// Mixed case: the outer axis-1 concat consumes one axis-1 concat (%0) and
// one axis-2 concat (%1). The expected output that follows inlines only the
// operands of %0 and keeps the axis-2 concat as a separate op.
func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
81+
%0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
82+
%1 = "tosa.concat"(%arg1, %arg1) {axis = 2} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
83+
%2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
84+
return %2 : tensor<1x4x8x8xf32>
85+
}
86+
87+
// CHECK-LABEL: func.func @partially_foldable(
88+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>,
89+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
90+
// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) {axis = 2 : i64} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
91+
// CHECK: %[[VAL_3:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_2]]) {axis = 1 : i64} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
92+
// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
93+
// CHECK: }

0 commit comments

Comments
 (0)