Skip to content

Commit e2c13ec

Browse files
committed
Make helper function return the newly constructed tensor instead of replacing the old one in there (make it more consistent with splat case)
1 parent 2ab1d7b commit e2c13ec

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <llvm/ADT/FloatingPointMode.h>
1919
#include <llvm/ADT/SmallVector.h>
2020
#include <mlir/IR/BuiltinAttributes.h>
21+
#include <mlir/Support/LogicalResult.h>
2122

2223
using namespace mlir;
2324
using namespace mlir::tosa;
@@ -38,10 +39,9 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
3839
return recip;
3940
}
4041

41-
ConstOp replaceTensorWithReciprocal(ConstOp tensorToReplace,
42-
const DenseElementsAttr &inputValues,
43-
PatternRewriter &rewriter) const {
44-
42+
DenseElementsAttr
43+
replaceTensorWithReciprocal(ConstOp tensorToReplace,
44+
const DenseElementsAttr &inputValues) const {
4545
// TODO it would be nicer to do this in-place
4646

4747
// Compute the reciprocal for each tensor element
@@ -57,9 +57,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
5757
// Replace the current tensor with one containing the computed reciprocals
5858
auto newTensor =
5959
DenseElementsAttr::get(inputValues.getType(), transformedValues);
60-
auto newOp = rewriter.replaceOpWithNewOp<ConstOp>(
61-
tensorToReplace, newTensor.getType(), newTensor);
62-
return newOp;
60+
return newTensor;
6361
}
6462

6563
LogicalResult matchAndRewrite(ReciprocalOp recip,
@@ -116,14 +114,10 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
116114
}
117115

118116
// Create a new tensor with the updated values
119-
auto newOp =
120-
replaceTensorWithReciprocal(definingConstOp, inputValues, rewriter);
117+
auto newTensor = replaceTensorWithReciprocal(definingConstOp, inputValues);
121118

122119
// Replace the use of the reciprocal with the transformed tensor
123-
auto updateUse = [&recip, &newOp]() { recip->replaceAllUsesWith(newOp); };
124-
rewriter.updateRootInPlace(*(recip->getUsers().begin()), updateUse);
125-
// Remove the reciprocal operation
126-
rewriter.eraseOp(recip);
120+
rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
127121
return success();
128122
}
129123
};

0 commit comments

Comments
 (0)