Skip to content

Commit 20684a4

Browse files
authored
Canonicalize tosa sqrt + reciprocal into rsqrt (#88)
* Add draft for canonicalization of sqrt + reciprocal in rsqrt * Address comments: Add error message, handle tile case and reject non-float scales
1 parent d0868e8 commit 20684a4

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,8 @@ def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> {
804804
let results = (outs
805805
Tosa_Tensor:$z
806806
);
807+
808+
let hasCanonicalizer = 1;
807809
}
808810

809811
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1818
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
1919
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20+
#include "mlir/IR/BuiltinAttributes.h"
2021
#include "mlir/IR/BuiltinTypes.h"
2122
#include "mlir/IR/DialectImplementation.h"
2223
#include "mlir/IR/Matchers.h"
@@ -62,6 +63,57 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
6263
results.add<ConcatOptimization>(context);
6364
}
6465

66+
/// Canonicalization pattern that rewrites `reciprocal(pow(x, 0.5))` into
/// `rsqrt(x)`.
///
/// A square root is represented in TOSA by a pow whose exponent is the
/// constant 0.5, so the pattern is rooted on the pow and inspects its single
/// user. It only fires when all of the following hold:
///   * the pow result has exactly one use, and that use is a
///     tosa.reciprocal;
///   * the exponent is produced by a tosa.const holding a dense f32 splat
///     equal to 0.5;
///   * the pow does not broadcast its first input, i.e. its input and result
///     types are identical (otherwise a tile would be required).
/// On success the reciprocal is replaced by the rsqrt; the pow is left
/// without uses.
struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::PowOp> {
  using OpRewritePattern<tosa::PowOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PowOp op,
                                PatternRewriter &rewriter) const override {
    // The pow result must feed exactly one operation, otherwise the pow has
    // to stay alive and nothing is gained by fusing.
    if (!op->hasOneUse())
      return rewriter.notifyMatchFailure(op, "pow operator has more than one user");

    // That single user has to be the reciprocal we want to fuse with.
    auto reciprocalOp = dyn_cast<tosa::ReciprocalOp>(*op->user_begin());
    if (!reciprocalOp)
      return rewriter.notifyMatchFailure(op, "expected a pow + reciprocal pattern");

    // The pow is a sqrt only if its second input (the exponent) is the
    // constant 0.5.
    auto powScale = op.getInput2().getDefiningOp<tosa::ConstOp>();
    if (!powScale)
      return rewriter.notifyMatchFailure(op, "expected the pow to have a constant scale input");

    // The constant may carry an elements attribute that is not dense; reject
    // it gracefully instead of asserting in a hard cast.
    auto scale = dyn_cast<DenseElementsAttr>(powScale.getValue());
    if (!scale)
      return rewriter.notifyMatchFailure(op, "expected the pow scale to be a dense elements attribute");
    if (!scale.isSplat())
      return rewriter.notifyMatchFailure(op, "expected the pow scale to be a splat tensor");

    // Only f32 exponents are handled.
    if (!scale.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "unexpected type for scale value of the pow op");
    // 0.5 is exactly representable, so an exact comparison is intentional.
    if (scale.getSplatValue<float>() != 0.5f)
      return rewriter.notifyMatchFailure(op, "expected the pow to have a scale of 0.5 to be a sqrt");

    // If the operator needs tiling, fail to match.
    // An improvement for the future would be to generate a tile operator
    // here instead.
    Type inputType = op.getInput1().getType();
    Type outputType = op.getType();
    if (inputType != outputType)
      return rewriter.notifyMatchFailure(op, "input type and output type are different, tiling is not supported for this canonicalization");

    // Build the rsqrt with the result type of the op it replaces so that the
    // users of the reciprocal keep seeing the exact same type.
    rewriter.replaceOpWithNewOp<tosa::RsqrtOp>(
        reciprocalOp, reciprocalOp.getType(), op.getInput1());

    return success();
  }
};
111+
112+
/// Canonicalization hook for tosa.pow: contributes the pattern that rewrites
/// a sqrt-style pow followed by a reciprocal into a single rsqrt.
void PowOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<SqrtReciprocalOptimization>(context);
}
116+
65117
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
66118
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
67119
if (!notOp)

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,40 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :
584584

585585
// -----
586586

587+
// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal
588+
func.func @canonicalize_optimize_sqrt_reciprocal(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> {
  // pow(x, 0.5) followed by reciprocal must be rewritten into one rsqrt.
  // CHECK: %[[RSQRT:.*]] = "tosa.rsqrt"(%arg{{.*}}) : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
  // CHECK: return %[[RSQRT]] : tensor<1x5x1x1xf32>
  %half = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
  %sqrt = "tosa.pow"(%arg0, %half) : (tensor<1x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x5x1x1xf32>
  %inv = "tosa.reciprocal"(%sqrt) : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
  return %inv : tensor<1x5x1x1xf32>
}
596+
597+
// -----
598+
599+
// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal_no_match
600+
func.func @canonicalize_optimize_sqrt_reciprocal_no_match(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> {
  // The exponent is 0.4, not 0.5, so the pow is not a sqrt: both original
  // ops must survive and no rsqrt may be created.
  // CHECK: tosa.pow
  // CHECK-NOT: tosa.rsqrt
  // CHECK: tosa.reciprocal
  // CHECK-NOT: tosa.rsqrt
  %0 = "tosa.const"() <{value = dense<4.000000e-01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
  %1 = "tosa.pow"(%arg0, %0) : (tensor<1x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x5x1x1xf32>
  %2 = "tosa.reciprocal"(%1) : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
  return %2 : tensor<1x5x1x1xf32>
}
607+
608+
// -----
609+
610+
// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal_tile_no_match
611+
func.func @canonicalize_optimize_sqrt_reciprocal_tile_no_match(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x7x1xf32> {
  // The exponent is 0.5 but the pow broadcasts its input from 1x5x1x1 to
  // 1x5x7x1. Replacing it would require a tile, which the pattern does not
  // generate, so both original ops must survive and no rsqrt may be created.
  // CHECK: tosa.pow
  // CHECK-NOT: tosa.rsqrt
  // CHECK: tosa.reciprocal
  // CHECK-NOT: tosa.rsqrt
  %0 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x7x1xf32>}> : () -> tensor<1x1x7x1xf32>
  %1 = "tosa.pow"(%arg0, %0) : (tensor<1x5x1x1xf32>, tensor<1x1x7x1xf32>) -> tensor<1x5x7x1xf32>
  %2 = "tosa.reciprocal"(%1) : (tensor<1x5x7x1xf32>) -> tensor<1x5x7x1xf32>
  return %2 : tensor<1x5x7x1xf32>
}
618+
619+
// -----
620+
587621
// CHECK-LABEL: @fold_log_exp
588622
func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
589623
// CHECK: return %arg{{.*}} : tensor<?x1xf32>

0 commit comments

Comments
 (0)