|
17 | 17 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
|
18 | 18 | #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
|
19 | 19 | #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
|
| 20 | +#include "mlir/IR/BuiltinAttributes.h" |
20 | 21 | #include "mlir/IR/BuiltinTypes.h"
|
21 | 22 | #include "mlir/IR/DialectImplementation.h"
|
22 | 23 | #include "mlir/IR/Matchers.h"
|
@@ -62,6 +63,57 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
62 | 63 | results.add<ConcatOptimization>(context);
|
63 | 64 | }
|
64 | 65 |
|
| 66 | +struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::PowOp> { |
| 67 | + using OpRewritePattern<tosa::PowOp>::OpRewritePattern; |
| 68 | + // Pattern that matches a Sqrt + Reciprocal to replace them by a rsqrt. |
| 69 | + // Sqrt is represented in tosa by a Pow so we check for Pow + reciprocal. |
| 70 | + LogicalResult matchAndRewrite(tosa::PowOp op, |
| 71 | + PatternRewriter &rewriter) const override { |
| 72 | + // Check that the PowOp has a single user |
| 73 | + if (!op->hasOneUse()) |
| 74 | + return rewriter.notifyMatchFailure(op, "pow operator has more than one user"); |
| 75 | + |
| 76 | + Operation* user = *op->user_begin(); |
| 77 | + // Check that this user is a reciprocal |
| 78 | + if (!isa<tosa::ReciprocalOp>(user)) |
| 79 | + return rewriter.notifyMatchFailure(op, "expected a pow + reciprocal pattern"); |
| 80 | + |
| 81 | + // Check that the Pow op is an Sqrt - its second input should be the scale, 0.5 for Sqrt. |
| 82 | + Operation* powScale = op.getInput2().getDefiningOp(); |
| 83 | + if (!powScale || !isa<tosa::ConstOp>(powScale)) |
| 84 | + return rewriter.notifyMatchFailure(op, "expected the pow to have a constant scale input"); |
| 85 | + |
| 86 | + auto scale = cast<DenseElementsAttr>(cast<tosa::ConstOp>(powScale).getValue()); |
| 87 | + if (!scale.isSplat()) |
| 88 | + return rewriter.notifyMatchFailure(op, "expected the pow scale to be a splat tensor"); |
| 89 | + |
| 90 | + auto constantType = scale.getElementType(); |
| 91 | + float scaleValue = 0.; |
| 92 | + if (constantType.isF32()) |
| 93 | + scaleValue = scale.getSplatValue<float>(); |
| 94 | + else |
| 95 | + return rewriter.notifyMatchFailure(op, "unexpected type for scale value of the pow op"); |
| 96 | + if(scaleValue != 0.5) |
| 97 | + return rewriter.notifyMatchFailure(op, "expected the pow to have a scale of 0.5 to be a sqrt"); |
| 98 | + |
| 99 | + auto inputType = cast<ShapedType>(op.getOperand(0).getType()); |
| 100 | + auto outputType = cast<ShapedType>(op.getType()); |
| 101 | + // If the operator needs tiling, fail to match |
| 102 | + // An improvement for the future would be to generate a tile operator here instead |
| 103 | + if (inputType != outputType) |
| 104 | + return rewriter.notifyMatchFailure(op, "input type and output type are different, tiling is not supported for this canonicalization"); |
| 105 | + |
| 106 | + rewriter.replaceOpWithNewOp<tosa::RsqrtOp>(user, outputType, op.getInput1()); |
| 107 | + |
| 108 | + return success(); |
| 109 | + } |
| 110 | +}; |
| 111 | + |
| 112 | +void PowOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 113 | + MLIRContext *context) { |
| 114 | + results.add<SqrtReciprocalOptimization>(context); |
| 115 | +} |
| 116 | + |
65 | 117 | LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
|
66 | 118 | auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
|
67 | 119 | if (!notOp)
|
|
0 commit comments