Skip to content

Commit c8311e6

Browse files
authored
Fix HiFi relu for int8
Differential Revision: D69542780 Pull Request resolved: #8424
1 parent e74b141 commit c8311e6

File tree

1 file changed

+7
-26
lines changed

1 file changed

+7
-26
lines changed

backends/cadence/hifi/operators/op_quantized_relu_out.cpp

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@ void quantized_relu_per_tensor_out(
2626
const int64_t out_multiplier,
2727
const int64_t out_shift,
2828
Tensor& output) {
29-
const uint8_t _in_zero_point = static_cast<uint8_t>(in_zero_point);
30-
const uint8_t _out_zero_point = static_cast<uint8_t>(out_zero_point);
31-
const int32_t _out_multiplier = static_cast<int32_t>(out_multiplier);
32-
const int32_t _out_shift = static_cast<int32_t>(out_shift);
29+
const int32_t _out_multiplier = static_cast<int32_t>(out_multiplier);
30+
const int32_t _out_shift = static_cast<int32_t>(out_shift);
31+
3332
if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
33+
const uint8_t _in_zero_point = static_cast<uint8_t>(in_zero_point);
34+
const uint8_t _out_zero_point = static_cast<uint8_t>(out_zero_point);
3435
const uint8_t* p_in = input.const_data_ptr<uint8_t>();
3536
uint8_t* p_out = output.mutable_data_ptr<uint8_t>();
3637

@@ -48,6 +49,8 @@ void quantized_relu_per_tensor_out(
4849
ET_CHECK_MSG(ret_val == 0, "An internal error occured");
4950

5051
} else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
52+
const int8_t _in_zero_point = static_cast<int8_t>(in_zero_point);
53+
const int8_t _out_zero_point = static_cast<int8_t>(out_zero_point);
5154
const int8_t* p_in = input.const_data_ptr<int8_t>();
5255
int8_t* p_out = output.mutable_data_ptr<int8_t>();
5356

@@ -72,28 +75,6 @@ void quantized_relu_per_tensor_out(
7275
}
7376
}
7477

75-
void quantized_relu_per_tensor_out(
76-
KernelRuntimeContext& ctx,
77-
const Tensor& input,
78-
const Tensor& in_zero_point,
79-
const int64_t out_zero_point,
80-
const Tensor& out_multiplier,
81-
const Tensor& out_shift,
82-
Tensor& output) {
83-
int8_t _in_zero_point = in_zero_point.const_data_ptr<int8_t>()[0];
84-
int32_t _out_multiplier = out_multiplier.const_data_ptr<int32_t>()[0];
85-
int32_t _out_shift = out_shift.const_data_ptr<int32_t>()[0];
86-
87-
quantized_relu_per_tensor_out(
88-
ctx,
89-
input,
90-
_in_zero_point,
91-
out_zero_point,
92-
_out_multiplier,
93-
_out_shift,
94-
output);
95-
}
96-
9778
} // namespace native
9879
} // namespace HiFi
9980
} // namespace impl

0 commit comments

Comments
 (0)