
Commit 59b56bf

mgehre-amd authored and GitHub Enterprise committed
Merge pull request #6 from ACT/tina.tosareciprocalfolding
[FXML-1727] Implement folding for constant reciprocals
2 parents 41597d5 + e2c13ec, commit 59b56bf

File tree: 5 files changed, 251 additions and 0 deletions
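
Roughly, the new pattern replaces a TOSA reciprocal of a constant tensor with a constant holding the elementwise reciprocals. A minimal before/after sketch (illustrative only, not taken from the patch; the shape and value are made up but mirror the splat test below):

// Before --tosa-layerwise-constant-fold:
%0 = "tosa.const"() {value = dense<4.0> : tensor<2xf32>} : () -> tensor<2xf32>
%1 = "tosa.reciprocal"(%0) : (tensor<2xf32>) -> tensor<2xf32>

// After folding, the reciprocal disappears and only a constant remains:
%1 = "tosa.const"() {value = dense<2.500000e-01> : tensor<2xf32>} : () -> tensor<2xf32>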

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions

@@ -29,6 +29,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
                                         RewritePatternSet &patterns);
 void populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                     RewritePatternSet &patterns);
+void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
+                                                RewritePatternSet &patterns);
 void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
                                                RewritePatternSet &patterns);

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
   TosaDecomposeConv2D.cpp
   TosaDecomposeDepthwise.cpp
+  TosaFoldConstantReciprocal.cpp
   TosaFoldConstantTranspose.cpp
   TosaInferShapes.cpp
   TosaLayerwiseConstantFoldPass.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp (new file)

Lines changed: 130 additions & 0 deletions

//===- TosaFoldConstantReciprocal.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA Reciprocal operation on constant data
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/FloatingPointMode.h>
#include <llvm/ADT/SmallVector.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/Support/LogicalResult.h>

using namespace mlir;
using namespace mlir::tosa;

namespace {

struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

  using OpRewritePattern::OpRewritePattern;

  static constexpr llvm::RoundingMode reciprocalRoundingMode =
      APFloat::rmNearestTiesToEven;

  APFloat computeReciprocal(const APFloat &floatVal, Type floatTy) const {
    auto recipAttr = FloatAttr::get(floatTy, 1.0);
    APFloat recip = recipAttr.getValue();
    recip.divide(floatVal, reciprocalRoundingMode);

    return recip;
  }

  DenseElementsAttr
  replaceTensorWithReciprocal(ConstOp tensorToReplace,
                              const DenseElementsAttr &inputValues) const {
    // TODO: it would be nicer to do this in-place

    // Compute the reciprocal for each tensor element.
    llvm::SmallVector<APFloat, 1> transformedValues;
    // We already know the number of values we will insert, so reserve space
    // for all of them to avoid dynamic resizing.
    transformedValues.reserve(inputValues.getNumElements());
    for (auto val : inputValues.getValues<APFloat>()) {
      auto recipVal = computeReciprocal(val, inputValues.getElementType());
      transformedValues.push_back(recipVal);
    }

    // Replace the current tensor with one containing the computed reciprocals.
    auto newTensor =
        DenseElementsAttr::get(inputValues.getType(), transformedValues);
    return newTensor;
  }

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();
    auto elemType = inputTensor.getType().getElementType();
    // TOSA only allows floats as inputs to the reciprocal operation, so bail
    // out if anything else is contained.
    if (!isa<FloatType>(elemType)) {
      return rewriter.notifyMatchFailure(recip,
                                         "Unexpected input tensor type: the "
                                         "TOSA spec only allows floats");
    }

    // Check whether the tensor is constant and dense.
    DenseElementsAttr inputValues;
    if (!matchPattern(inputTensor, m_Constant(&inputValues))) {
      return rewriter.notifyMatchFailure(
          recip, "Non-const or non-dense input to reciprocal");
    }

    // In case we have a splat, we only need to calculate the reciprocal once
    // and update the tensor to the transformed splat value.
    if (auto splatAttrs = dyn_cast<SplatElementsAttr>(inputValues)) {
      // Transform the splat value.
      auto splatVal = splatAttrs.getSplatValue<APFloat>();
      auto newSplatRecipAttr = computeReciprocal(splatVal, elemType);

      // Create a tensor with the transformed splat value.
      auto newSplatTensor =
          DenseElementsAttr::get(splatAttrs.getType(), newSplatRecipAttr);

      // Replace the reciprocal op with the newly constructed tensor.
      rewriter.replaceOpWithNewOp<ConstOp>(recip, newSplatTensor.getType(),
                                           newSplatTensor);
      return success();
    }

    if (!isa<ConstOp>(inputTensor.getDefiningOp())) {
      return rewriter.notifyMatchFailure(recip,
                                         "The reciprocal can only be folded if "
                                         "it operates on a TOSA constant");
    }
    auto definingConstOp = cast<ConstOp>(inputTensor.getDefiningOp());

    // Our transformation replaces the input tensor with the transformed
    // tensor. If the input has several users, we would need to keep the
    // original input as well, which can significantly increase memory usage,
    // so we currently refrain from applying the transformation in that case.
    if (!definingConstOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values.
    auto newTensor = replaceTensorWithReciprocal(definingConstOp, inputValues);

    // Replace the use of the reciprocal with the transformed tensor.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}
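
One restriction worth noting in the pattern above: for non-splat constants, folding only happens when the constant has a single user, to avoid keeping both the original and the folded tensor alive. A hypothetical snippet (not among the patch's tests) that the pattern would therefore leave untouched:

// %cst is non-splat and feeds both the reciprocal and the add, so folding
// would duplicate the tensor data; the pattern reports a match failure.
%cst = "tosa.const"() {value = dense<[1.0, 2.0, 4.0]> : tensor<3xf32>} : () -> tensor<3xf32>
%rec = "tosa.reciprocal"(%cst) : (tensor<3xf32>) -> tensor<3xf32>
%sum = "tosa.add"(%cst, %rec) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>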

mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
   RewritePatternSet patterns(ctx);
   auto func = getOperation();

+  mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
   mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
   populateTosaOpsCanonicalizationPatterns(ctx, patterns);

Lines changed: 117 additions & 0 deletions

// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s

// CHECK-LABEL: @reciprocal_fold_single_valued
func.func @reciprocal_fold_single_valued() -> tensor<f32> {
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<f32>
  // CHECK-NOT: tosa.reciprocal
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<4.0> : tensor<f32>} : () -> tensor<f32>
  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
  return %1 : tensor<f32>
}

// CHECK-LABEL: @reciprocal_fold_splat
func.func @reciprocal_fold_splat() -> tensor<12x7xf32> {
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<12x7xf32>
  // CHECK-NOT: tosa.reciprocal
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<4.0> : tensor<12x7xf32>} : () -> tensor<12x7xf32>
  %1 = "tosa.reciprocal"(%0) : (tensor<12x7xf32>) -> tensor<12x7xf32>
  return %1 : tensor<12x7xf32>
}

// CHECK-LABEL: @reciprocal_div_zero
func.func @reciprocal_div_zero() -> tensor<f32> {
  // 0x7F800000 is the value for +infinity
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000
  // CHECK-NOT: tosa.reciprocal
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
  return %1 : tensor<f32>
}

// CHECK-LABEL: @reciprocal_div_neg_zero
func.func @reciprocal_div_neg_zero() -> tensor<f32> {
  // 0xFF800000 is the value for -infinity
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0xFF800000
  // CHECK-NOT: tosa.reciprocal
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<-0.0> : tensor<f32>} : () -> tensor<f32>
  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
  return %1 : tensor<f32>
}

// CHECK-LABEL: @reciprocal_div_nan
func.func @reciprocal_div_nan() -> tensor<f32> {
  // 0x7FC00000 is the value for NaN
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7FC00000
  // CHECK-NOT: tosa.reciprocal
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<0x7FC00000> : tensor<f32>} : () -> tensor<f32>
  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
  return %1 : tensor<f32>
}

// CHECK-LABEL: @reciprocal_div_infinity
func.func @reciprocal_div_infinity() -> tensor<f32> {
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0.{{0*}}e+00>
  // CHECK-NOT: tosa.reciprocal
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32>
  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
  return %1 : tensor<f32>
}

// CHECK-LABEL: @reciprocal_div_neg_infinity
func.func @reciprocal_div_neg_infinity() -> tensor<f32> {
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<-0.{{0*}}e+00>
  // CHECK-NOT: tosa.reciprocal
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<f32>
  %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
  return %1 : tensor<f32>
}

// CHECK-LABEL: @reciprocal_no_fold
// The folding optimization works only intra-procedurally, so we won't be able
// to fold anything here
func.func @reciprocal_no_fold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: tosa.reciprocal
  // CHECK-NEXT: return
  %0 = "tosa.reciprocal"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: @reciprocal_fold
func.func @reciprocal_fold() -> tensor<4x6xf32> {
  // CHECK: [[RES:]] ={{.*}}tosa.const
  // CHECK-SAME{LITERAL}: [[5.68828249, 11.4416485, 1.6880486, 0.680272102, -0.875350117, 0.342313349],
  // CHECK-SAME{LITERAL}: [-4.81231928, 0.698080301, 0.65432179, -82.6446304, -4.33651352, -0.747551739],
  // CHECK-SAME{LITERAL}: [-12.4378109, 13.140605, 1.89501607, 0.885582745, 4.08830738, 1.4396776],
  // CHECK-SAME{LITERAL}: [2.02880907, -1.53280187, 0.552730501, 7.15819644, 0.64495325, -0.973709881]]
  // CHECK-NOT: tosa.reciprocal
  // CHECK: return [[RES]]
  %0 = "tosa.const"() { value = dense<[
                        [ 0.1758, 0.0874, 0.5924, 1.4700, -1.1424, 2.9213],
                        [-0.2078, 1.4325, 1.5283, -0.0121, -0.2306, -1.3377],
                        [-0.0804, 0.0761, 0.5277, 1.1292, 0.2446, 0.6946],
                        [ 0.4929, -0.6524, 1.8092, 0.1397, 1.5505, -1.0270]]>
                        : tensor<4x6xf32>
                      } : () -> tensor<4x6xf32>
  %1 = "tosa.reciprocal"(%0) : (tensor<4x6xf32>) -> tensor<4x6xf32>
  return %1 : tensor<4x6xf32>
}

// CHECK-LABEL: @reciprocal_of_const_sparse
// Sparse tensors are currently not supported
func.func @reciprocal_of_const_sparse() -> tensor<32xbf16> {
  // CHECK: tosa.const
  // CHECK: tosa.reciprocal
  %0 = "tosa.const"() { value = sparse<
        [[0], [3], [11], [17], [20], [23], [25], [30], [31]],
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]>
        : tensor<32xbf16> } : () -> tensor<32xbf16>
  %1 = "tosa.reciprocal"(%0) : (tensor<32xbf16>) -> tensor<32xbf16>
  return %1 : tensor<32xbf16>
}
