@@ -26,11 +26,12 @@ void quantized_relu_per_tensor_out(
26
26
const int64_t out_multiplier,
27
27
const int64_t out_shift,
28
28
Tensor& output) {
29
- const uint8_t _in_zero_point = static_cast <uint8_t >(in_zero_point);
30
- const uint8_t _out_zero_point = static_cast <uint8_t >(out_zero_point);
31
- const int32_t _out_multiplier = static_cast <int32_t >(out_multiplier);
32
- const int32_t _out_shift = static_cast <int32_t >(out_shift);
29
+ const int32_t _out_multiplier = static_cast <int32_t >(out_multiplier);
30
+ const int32_t _out_shift = static_cast <int32_t >(out_shift);
31
+
33
32
if (input.scalar_type () == executorch::aten::ScalarType::Byte) {
33
+ const uint8_t _in_zero_point = static_cast <uint8_t >(in_zero_point);
34
+ const uint8_t _out_zero_point = static_cast <uint8_t >(out_zero_point);
34
35
const uint8_t * p_in = input.const_data_ptr <uint8_t >();
35
36
uint8_t * p_out = output.mutable_data_ptr <uint8_t >();
36
37
@@ -48,6 +49,8 @@ void quantized_relu_per_tensor_out(
48
49
ET_CHECK_MSG (ret_val == 0 , " An internal error occured" );
49
50
50
51
} else if (input.scalar_type () == executorch::aten::ScalarType::Char) {
52
+ const int8_t _in_zero_point = static_cast <int8_t >(in_zero_point);
53
+ const int8_t _out_zero_point = static_cast <int8_t >(out_zero_point);
51
54
const int8_t * p_in = input.const_data_ptr <int8_t >();
52
55
int8_t * p_out = output.mutable_data_ptr <int8_t >();
53
56
@@ -72,28 +75,6 @@ void quantized_relu_per_tensor_out(
72
75
}
73
76
}
74
77
75
- void quantized_relu_per_tensor_out (
76
- KernelRuntimeContext& ctx,
77
- const Tensor& input,
78
- const Tensor& in_zero_point,
79
- const int64_t out_zero_point,
80
- const Tensor& out_multiplier,
81
- const Tensor& out_shift,
82
- Tensor& output) {
83
- int8_t _in_zero_point = in_zero_point.const_data_ptr <int8_t >()[0 ];
84
- int32_t _out_multiplier = out_multiplier.const_data_ptr <int32_t >()[0 ];
85
- int32_t _out_shift = out_shift.const_data_ptr <int32_t >()[0 ];
86
-
87
- quantized_relu_per_tensor_out (
88
- ctx,
89
- input,
90
- _in_zero_point,
91
- out_zero_point,
92
- _out_multiplier,
93
- _out_shift,
94
- output);
95
- }
96
-
97
78
} // namespace native
98
79
} // namespace HiFi
99
80
} // namespace impl
0 commit comments