Skip to content

Commit 73a79e8

Browse files
committed
TOSA: tosa-layerwise-constant-fold: Folder for reshapes with multiple uses
1 parent dfa02b1 commit 73a79e8

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,38 @@ static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
573573
return recip;
574574
}
575575

576+
/// Fold reshapes. This is similar to ReshapeOp::fold, but also allows
577+
/// to fold with multiple users.
578+
struct TosaFoldConstantReshape
579+
: public TosaFoldConstantUnaryElementwise<TosaFoldConstantReshape,
580+
ReshapeOp> {
581+
using TosaFoldConstantUnaryElementwise::TosaFoldConstantUnaryElementwise;
582+
583+
LogicalResult matchAndRewrite(ReshapeOp op,
584+
PatternRewriter &rewriter) const override {
585+
auto inputTensor = op.getOperand();
586+
// Check that we can apply folding
587+
auto preCondCheck =
588+
notifyIfNoTosaDenseConstantTensor(inputTensor, op, rewriter);
589+
if (failed(preCondCheck)) {
590+
return preCondCheck;
591+
}
592+
593+
// Extract the tensor values
594+
DenseElementsAttr inputValues;
595+
matchPattern(inputTensor, m_Constant(&inputValues));
596+
597+
// Check whether this should be folded.
598+
if (!constantUnaryOpShouldBeFolded(op, inputValues)) {
599+
return rewriter.notifyMatchFailure(
600+
op, "expected reshape op to have a single user");
601+
}
602+
DenseElementsAttr newTensor = inputValues.reshape(op.getType());
603+
rewriter.replaceOpWithNewOp<ConstOp>(op, newTensor.getType(), newTensor);
604+
return success();
605+
}
606+
};
607+
576608
struct TosaFoldConstantReciprocal
577609
: public TosaFoldConstantUnaryElementwise<TosaFoldConstantReciprocal, ReciprocalOp> {
578610
using TosaFoldConstantUnaryElementwise<TosaFoldConstantReciprocal,
@@ -1723,6 +1755,8 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
17231755
MLIRContext *ctx, RewritePatternSet &patterns,
17241756
bool foldSplatOrSingleUseOnly,
17251757
bool enableIntCastFolding) {
1758+
1759+
patterns.add<TosaFoldConstantReshape>(ctx, foldSplatOrSingleUseOnly);
17261760
patterns.add<TosaFoldConstantTranspose>(ctx, foldSplatOrSingleUseOnly);
17271761
patterns.add<TosaFoldConstantReciprocal>(ctx, foldSplatOrSingleUseOnly);
17281762
patterns.add<TosaFoldConstantRSQRT>(ctx, foldSplatOrSingleUseOnly);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
2+
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="fold-splat-or-single-use-only=0" %s | FileCheck %s --check-prefix CHECK-MULTI
3+
// CHECK-LABEL: @reshape_single_user
4+
func.func @reshape_single_user() -> tensor<1x2xf32> {
5+
// CHECK: %[[RES:.*]] = "tosa.const"{{.*}}-> tensor<1x2xf32>
6+
// CHECK: return %[[RES]]
7+
%0 = "tosa.const"() {value = dense<4.0> : tensor<2xf32>} : () -> tensor<2xf32>
8+
%1 = tosa.reshape %0 {new_shape = array<i64: 1, 2>}: (tensor<2xf32>) -> tensor<1x2xf32>
9+
return %1 : tensor<1x2xf32>
10+
}
11+
12+
// CHECK-LABEL: @reshape_multi_user
13+
func.func @reshape_multi_user_splat() -> (tensor<1x2xf32>, tensor<2xf32>) {
14+
// CHECK-DAG: %[[RES:.*]] = "tosa.const"{{.*}}-> tensor<2xf32>
15+
// CHECK-DAG: %[[RESHAPED:.*]] = "tosa.const"{{.*}}-> tensor<1x2xf32>
16+
// CHECK: return %[[RESHAPED]], %[[RES]]
17+
%0 = "tosa.const"() {value = dense<4.0> : tensor<2xf32>} : () -> tensor<2xf32>
18+
%1 = tosa.reshape %0 {new_shape = array<i64: 1, 2>}: (tensor<2xf32>) -> tensor<1x2xf32>
19+
return %1, %0 : tensor<1x2xf32>, tensor<2xf32>
20+
}
21+
22+
// CHECK-LABEL: @reshape_multi_user_non_splat
23+
func.func @reshape_multi_user_non_splat() -> (tensor<1x2xf32>, tensor<2xf32>) {
24+
// CHECK: %[[CONST:.*]] = "tosa.const"{{.*}}-> tensor<2xf32>
25+
// CHECK: %[[RES:.*]] = tosa.reshape
26+
// CHECK: return %[[RES]], %[[CONST]]
27+
// CHECK-MULTI-DAG: %[[RES:.*]] = "tosa.const"{{.*}}-> tensor<2xf32>
28+
// CHECK-MULTI-DAG: %[[RESHAPED:.*]] = "tosa.const"{{.*}}-> tensor<1x2xf32>
29+
// CHECK-MULTI: return %[[RESHAPED]], %[[RES]]
30+
%0 = "tosa.const"() {value = dense<[4.0, 3.0]> : tensor<2xf32>} : () -> tensor<2xf32>
31+
%1 = tosa.reshape %0 {new_shape = array<i64: 1, 2>}: (tensor<2xf32>) -> tensor<1x2xf32>
32+
return %1, %0 : tensor<1x2xf32>, tensor<2xf32>
33+
}

0 commit comments

Comments
 (0)