Skip to content

Commit 29e1ec6

Browse files
authored
Merge pull request #164 from Xilinx/matthias.reshape_multi_use
TOSA: tosa-layerwise-constant-fold: Folder for reshapes with multiple uses
2 parents da2c0b9 + 05efa5b commit 29e1ec6

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,37 @@ static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
573573
return recip;
574574
}
575575

576+
/// Fold reshapes. This is similar to ReshapeOp::fold, but also allows
577+
/// to fold with multiple users.
578+
struct TosaFoldConstantReshape
579+
: public TosaFoldConstantUnaryElementwise<TosaFoldConstantReshape,
580+
ReshapeOp> {
581+
using TosaFoldConstantUnaryElementwise::TosaFoldConstantUnaryElementwise;
582+
583+
LogicalResult matchAndRewrite(ReshapeOp op,
584+
PatternRewriter &rewriter) const override {
585+
auto inputTensor = op.getOperand();
586+
// Check that we can apply folding
587+
auto preCondCheck =
588+
notifyIfNoTosaDenseConstantTensor(inputTensor, op, rewriter);
589+
if (failed(preCondCheck))
590+
return preCondCheck;
591+
592+
// Extract the tensor values
593+
DenseElementsAttr inputValues;
594+
matchPattern(inputTensor, m_Constant(&inputValues));
595+
596+
// Check whether this should be folded.
597+
if (!constantUnaryOpShouldBeFolded(op, inputValues)) {
598+
return rewriter.notifyMatchFailure(
599+
op, "expected reshape op to have a single user");
600+
}
601+
DenseElementsAttr newTensor = inputValues.reshape(op.getType());
602+
rewriter.replaceOpWithNewOp<ConstOp>(op, newTensor.getType(), newTensor);
603+
return success();
604+
}
605+
};
606+
576607
struct TosaFoldConstantReciprocal
577608
: public TosaFoldConstantUnaryElementwise<TosaFoldConstantReciprocal, ReciprocalOp> {
578609
using TosaFoldConstantUnaryElementwise<TosaFoldConstantReciprocal,
@@ -1723,8 +1754,10 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
17231754
MLIRContext *ctx, RewritePatternSet &patterns,
17241755
bool foldSplatOrSingleUseOnly,
17251756
bool enableIntCastFolding) {
1757+
17261758
patterns.add<TosaFoldConstantTranspose>(ctx, foldSplatOrSingleUseOnly);
17271759
patterns.add<TosaFoldConstantReciprocal>(ctx, foldSplatOrSingleUseOnly);
1760+
patterns.add<TosaFoldConstantReshape>(ctx, foldSplatOrSingleUseOnly);
17281761
patterns.add<TosaFoldConstantRSQRT>(ctx, foldSplatOrSingleUseOnly);
17291762
patterns.add<TosaFoldConstantLogicalNot>(ctx, foldSplatOrSingleUseOnly);
17301763
patterns.add<TosaFoldConstantPow>(ctx, foldSplatOrSingleUseOnly);
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: mlir-opt --tosa-layerwise-constant-fold %s | FileCheck %s
2+
// RUN: mlir-opt --tosa-layerwise-constant-fold="fold-splat-or-single-use-only=0" %s \
3+
// RUN: | FileCheck %s --check-prefix CHECK-ALWAYS
4+
5+
// CHECK-LABEL: @reshape_single_user
6+
func.func @reshape_single_user() -> tensor<1x2xf32> {
7+
// CHECK: %[[RES:.*]] = "tosa.const"{{.*}}-> tensor<1x2xf32>
8+
// CHECK: return %[[RES]]
9+
%0 = "tosa.const"() {value = dense<4.0> : tensor<2xf32>} : () -> tensor<2xf32>
10+
%1 = tosa.reshape %0 {new_shape = array<i64: 1, 2>}: (tensor<2xf32>) -> tensor<1x2xf32>
11+
return %1 : tensor<1x2xf32>
12+
}
13+
14+
// Splat constants are always folded, even when they have multiple users.
15+
// CHECK-LABEL: @reshape_multi_user_splat
16+
func.func @reshape_multi_user_splat() -> (tensor<1x2xf32>, tensor<2xf32>) {
17+
// CHECK-DAG: %[[RES:.*]] = "tosa.const"{{.*}}-> tensor<2xf32>
18+
// CHECK-DAG: %[[RESHAPED:.*]] = "tosa.const"{{.*}}-> tensor<1x2xf32>
19+
// CHECK: return %[[RESHAPED]], %[[RES]]
20+
%0 = "tosa.const"() {value = dense<4.0> : tensor<2xf32>} : () -> tensor<2xf32>
21+
%1 = tosa.reshape %0 {new_shape = array<i64: 1, 2>}: (tensor<2xf32>) -> tensor<1x2xf32>
22+
return %1, %0 : tensor<1x2xf32>, tensor<2xf32>
23+
}
24+
25+
// Non-splat constants with multiple users are only folded when
26+
// fold-splat-or-single-use-only=0 is set.
27+
// CHECK-LABEL: @reshape_multi_user_non_splat
28+
func.func @reshape_multi_user_non_splat() -> (tensor<1x2xf32>, tensor<2xf32>) {
29+
// CHECK: %[[CONST:.*]] = "tosa.const"{{.*}}-> tensor<2xf32>
30+
// CHECK: %[[RES:.*]] = tosa.reshape
31+
// CHECK: return %[[RES]], %[[CONST]]
32+
// CHECK-ALWAYS-DAG: %[[RES:.*]] = "tosa.const"{{.*}}-> tensor<2xf32>
33+
// CHECK-ALWAYS-DAG: %[[RESHAPED:.*]] = "tosa.const"{{.*}}-> tensor<1x2xf32>
34+
// CHECK-ALWAYS: return %[[RESHAPED]], %[[RES]]
35+
%0 = "tosa.const"() {value = dense<[4.0, 3.0]> : tensor<2xf32>} : () -> tensor<2xf32>
36+
%1 = tosa.reshape %0 {new_shape = array<i64: 1, 2>}: (tensor<2xf32>) -> tensor<1x2xf32>
37+
return %1, %0 : tensor<1x2xf32>, tensor<2xf32>
38+
}

0 commit comments

Comments
 (0)