Skip to content

Commit 20d9c53

Browse files
committed
Reserve space for the transformed tensor elements; drop whitespace at the beginning of functions/ifs
1 parent 576f52f commit 20d9c53

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
3131
APFloat::rmNearestTiesToEven;
3232

3333
APFloat computeReciprocal(const APFloat &floatVal, Type floatTy) const {
34-
3534
auto recipAttr = FloatAttr::get(floatTy, 1.0);
3635
APFloat recip = recipAttr.getValue();
3736
recip.divide(floatVal, reciprocalRoundingMode);
@@ -46,7 +45,10 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
4645
// TODO it would be nicer to do this in-place
4746

4847
// Compute the reciprocal for each tensor element
49-
llvm::SmallVector<APFloat, 10> transformedValues;
48+
llvm::SmallVector<APFloat, 1> transformedValues;
49+
// We already know the amount of values we will insert, reserver space for
50+
// all of them to avoid dynamic resizing
51+
transformedValues.reserve(inputValues.getNumElements());
5052
for (auto it = inputValues.value_begin<APFloat>();
5153
it != inputValues.value_end<APFloat>(); it++) {
5254
auto val = *it;
@@ -64,7 +66,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
6466

6567
LogicalResult matchAndRewrite(ReciprocalOp recip,
6668
PatternRewriter &rewriter) const override {
67-
6869
auto inputTensor = recip.getInput1();
6970
auto elemType = inputTensor.getType().getElementType();
7071
// TOSA only allows for floats as inputs to the reciprocal operation, so
@@ -85,7 +86,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
8586
// In case we have a splat, we only need to calculate the reciprocal once
8687
// and update the tensor to the transformed splat value.
8788
if (auto splatAttrs = dyn_cast<SplatElementsAttr>(inputValues)) {
88-
8989
// Transform the splat value
9090
auto splatVal = splatAttrs.getSplatValue<APFloat>();
9191
auto newSplatRecipAttr = computeReciprocal(splatVal, elemType);

0 commit comments

Comments
 (0)