Skip to content

Commit 9b67e54

Browse files
authored
TOSA: Fold concat where one argument has zero elements (#41)
1 parent 9bccb5b commit 9b67e54

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,6 +1460,8 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
14601460
Tosa_Tensor:$output
14611461
);
14621462

1463+
let hasFolder = 1;
1464+
14631465
let hasCanonicalizer = 1;
14641466

14651467
let extraClassDeclaration = [{

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,30 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
494494
// Operator Folders.
495495
//===----------------------------------------------------------------------===//
496496

497+
static bool hasZeroSize(Type ty) {
498+
auto ranked = dyn_cast<RankedTensorType>(ty);
499+
if (!ranked)
500+
return false;
501+
return any_of(ranked.getShape(), [](auto d) { return d == 0; });
502+
}
503+
504+
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
505+
/// Remove operands that have zero elements.
506+
bool changed = false;
507+
for (size_t i = 0; i < getInput1().size(); ) {
508+
auto input = getInput1()[i];
509+
if (hasZeroSize(input.getType())) {
510+
getInput1Mutable().erase(i);
511+
changed = true;
512+
} else {
513+
++i;
514+
}
515+
}
516+
if (changed)
517+
return getResult();
518+
return {};
519+
}
520+
497521
template <typename IntFolder, typename FloatFolder>
498522
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
499523
RankedTensorType returnTy) {

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
8686
return %1 : tensor<4xi8>
8787
}
8888

89+
// CHECK-LABEL: @concat_fold_zero
90+
func.func @concat_fold_zero(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x2xf32>) -> tensor<?x3xf32> {
91+
// CHECK: "tosa.concat"(%arg1, %arg2) <{axis = 1 : i64}>
92+
%0 = "tosa.concat"(%arg0, %arg1, %arg2) {axis = 1 : i64}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
93+
return %0 : tensor<?x3xf32>
94+
}
95+
8996
// CHECK-LABEL: @concat_fold
9097
func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
9198
// CHECK: return %arg0

0 commit comments

Comments
 (0)