18
18
#include < llvm/ADT/FloatingPointMode.h>
19
19
#include < llvm/ADT/SmallVector.h>
20
20
#include < mlir/IR/BuiltinAttributes.h>
21
+ #include < mlir/Support/LogicalResult.h>
21
22
22
23
using namespace mlir ;
23
24
using namespace mlir ::tosa;
@@ -38,10 +39,9 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
38
39
return recip;
39
40
}
40
41
41
- ConstOp replaceTensorWithReciprocal (ConstOp tensorToReplace,
42
- const DenseElementsAttr &inputValues,
43
- PatternRewriter &rewriter) const {
44
-
42
+ DenseElementsAttr
43
+ replaceTensorWithReciprocal (ConstOp tensorToReplace,
44
+ const DenseElementsAttr &inputValues) const {
45
45
// TODO it would be nicer to do this in-place
46
46
47
47
// Compute the reciprocal for each tensor element
@@ -57,9 +57,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
57
57
// Replace the current tensor with one containing the computed reciprocals
58
58
auto newTensor =
59
59
DenseElementsAttr::get (inputValues.getType (), transformedValues);
60
- auto newOp = rewriter.replaceOpWithNewOp <ConstOp>(
61
- tensorToReplace, newTensor.getType (), newTensor);
62
- return newOp;
60
+ return newTensor;
63
61
}
64
62
65
63
LogicalResult matchAndRewrite (ReciprocalOp recip,
@@ -116,14 +114,10 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
116
114
}
117
115
118
116
// Create a new tensor with the updated values
119
- auto newOp =
120
- replaceTensorWithReciprocal (definingConstOp, inputValues, rewriter);
117
+ auto newTensor = replaceTensorWithReciprocal (definingConstOp, inputValues);
121
118
122
119
// Replace the use of the reciprocal with the transformed tensor
123
- auto updateUse = [&recip, &newOp]() { recip->replaceAllUsesWith (newOp); };
124
- rewriter.updateRootInPlace (*(recip->getUsers ().begin ()), updateUse);
125
- // Remove the reciprocal operation
126
- rewriter.eraseOp (recip);
120
+ rewriter.replaceOpWithNewOp <ConstOp>(recip, newTensor.getType (), newTensor);
127
121
return success ();
128
122
}
129
123
};
0 commit comments