
Commit b8d97d8

Enable int8 support for quantized_linear reference
Differential Revision: D64553726
Pull Request resolved: #6334
1 parent ec27667

6 files changed: +67 -12 lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ def quantized_relu_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
 ) -> torch.Tensor:
-    return X.new_empty(X.size(), dtype=torch.uint8)
+    return X.new_empty(X.size(), dtype=X.dtype)


 @register_fake("cadence::quantized_matmul")

backends/cadence/hifi/operators/dequantize_per_tensor.cpp

Lines changed: 4 additions & 1 deletion
@@ -45,7 +45,10 @@ void dequantize_per_tensor_out(
     const int32_t* input_data = input.const_data_ptr<int32_t>();
     dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);
   } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
   }
 }

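For context, the dequantize<T> kernel called above is not part of this diff; a minimal sketch of what such an affine dequantize routine plausibly looks like, assuming the standard mapping out[i] = scale * (in[i] - zero_point) (the actual cadence::impl::HiFi::kernels implementation and signature may differ):

#include <cstddef>
#include <cstdint>

// Hypothetical affine dequantize kernel:
// out[i] = scale * (in[i] - zero_point). Sketch only; the real HiFi
// kernel is not shown in this diff.
template <typename T>
void dequantize_sketch(
    float* __restrict__ out,
    const T* __restrict__ in,
    float scale,
    int32_t zero_point,
    size_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    out[i] = scale * (static_cast<int32_t>(in[i]) - zero_point);
  }
}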
backends/cadence/hifi/operators/quantize_per_tensor.cpp

Lines changed: 4 additions & 1 deletion
@@ -49,7 +49,10 @@ void quantize_per_tensor_out(
     cadence::impl::HiFi::kernels::quantize<int32_t>(
         out_data, input_data, 1. / scale, zero_point, numel);
   } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", out.scalar_type());
+    ET_CHECK_MSG(
+        false,
+        "Unhandled output dtype %hhd",
+        static_cast<int8_t>(out.scalar_type()));
   }
 }

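Note that the call site passes the reciprocal 1. / scale, so the kernel can multiply rather than divide per element. A minimal sketch of an affine quantize routine under that convention, assuming round-to-nearest and clamping to T's representable range (again, the real HiFi kernel may differ):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>

// Hypothetical affine quantize kernel:
// out[i] = clamp(round(in[i] * inv_scale) + zero_point, min(T), max(T)).
template <typename T>
void quantize_sketch(
    T* __restrict__ out,
    const float* __restrict__ in,
    float inv_scale,
    int32_t zero_point,
    size_t numel) {
  const int32_t min_v = std::numeric_limits<T>::min();
  const int32_t max_v = std::numeric_limits<T>::max();
  for (size_t i = 0; i < numel; ++i) {
    int32_t q =
        static_cast<int32_t>(std::lround(in[i] * inv_scale)) + zero_point;
    out[i] = static_cast<T>(std::min(max_v, std::max(min_v, q)));
  }
}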
backends/cadence/reference/operators/quantized_conv_out.cpp

Lines changed: 5 additions & 0 deletions
@@ -248,6 +248,11 @@ void quantized_conv_out(
         output_scale,
         (int8_t)output_zero_point,
         per_tensor_quantized);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
   }
 }

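The recurring static_cast<int8_t> in the error branches above is the substantive part of these fixes: ScalarType is an enum, and handing it straight to a printf-style %hhd specifier is a type mismatch that compilers flag. A minimal sketch of the pattern with plain printf (the enum here is a stand-in, not executorch's actual definition):

#include <cstdint>
#include <cstdio>

// Stand-in for executorch::aten::ScalarType (the real enum has more values).
enum class ScalarType : int8_t { Byte = 0, Char = 1, Float = 6 };

int main() {
  ScalarType st = ScalarType::Float;
  // A scoped enum has no implicit conversion, so it is cast explicitly
  // before traveling through printf's varargs to match %hhd.
  std::printf("Unhandled input dtype %hhd\n", static_cast<int8_t>(st));
  return 0;
}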
backends/cadence/reference/operators/quantized_linear_out.cpp

Lines changed: 48 additions & 9 deletions
@@ -17,8 +17,8 @@ using executorch::aten::Tensor;
 using executorch::runtime::getLeadingDims;
 using executorch::runtime::KernelRuntimeContext;

-void quantized_linear_out(
-    KernelRuntimeContext& ctx,
+template <typename T>
+void inline _typed_quantized_linear(
     const Tensor& src,
     const Tensor& weight,
     const Tensor& bias,
@@ -27,14 +27,11 @@ void quantized_linear_out(
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     int64_t out_zero_point,
-    const executorch::aten::optional<Tensor>& offset,
     Tensor& out) {
-  // Assuming uint8_t for now, but needs to be updated for other quantization
-  // types
-  const uint8_t* __restrict__ src_data = src.const_data_ptr<uint8_t>();
-  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
+  const T* __restrict__ src_data = src.const_data_ptr<T>();
+  const T* __restrict__ weight_data = weight.const_data_ptr<T>();
   const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
-  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
+  T* __restrict__ out_data = out.mutable_data_ptr<T>();

   int32_t weight_zero_point = weight_zero_point_t.const_data_ptr<int32_t>()[0];

@@ -71,11 +68,53 @@ void quantized_linear_out(
             (weight_data[j * N + k] - weight_zero_point);
       }
       out_data[i * M + j] =
-          kernels::quantize<uint8_t>(sum, out_scale, out_zero_point);
+          kernels::quantize<T>(sum, out_scale, out_zero_point);
     }
   }
 }

+void quantized_linear_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& src,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t src_zero_point,
+    const Tensor& weight_zero_point_t,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const executorch::aten::optional<Tensor>& offset,
+    Tensor& out) {
+  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
+    _typed_quantized_linear<uint8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
+    _typed_quantized_linear<int8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(src.scalar_type()));
+  }
+}
+
 }; // namespace native
 }; // namespace reference
 }; // namespace impl

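Read together, the hunks above turn the former uint8_t-only body into a dtype-generic kernel: an int32 accumulator gathers bias[j] plus the zero-point-corrected products, and the result is requantized to the output type T. A standalone sketch of that inner computation, pieced together from the visible context lines (dimension roles and the requantization arithmetic are assumptions; the real kernel derives the output scale from out_multiplier/out_shift):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical requantize helper mirroring kernels::quantize<T>(...).
template <typename T>
T requantize_sketch(int32_t acc, float out_scale, int32_t out_zero_point) {
  int32_t q =
      static_cast<int32_t>(std::lround(acc * out_scale)) + out_zero_point;
  const int32_t min_v = std::numeric_limits<T>::min();
  const int32_t max_v = std::numeric_limits<T>::max();
  return static_cast<T>(std::min(max_v, std::max(min_v, q)));
}

// Standalone sketch of the typed quantized-linear loop: src is
// [leading_dims x N], weight is [M x N], out is [leading_dims x M],
// matching the j * N + k and i * M + j indexing in the hunk above.
template <typename T>
void typed_quantized_linear_sketch(
    const T* src,
    const T* weight,
    const int32_t* bias,
    int32_t src_zero_point,
    int32_t weight_zero_point,
    float out_scale,
    int32_t out_zero_point,
    int64_t leading_dims,
    int64_t M,
    int64_t N,
    T* out) {
  for (int64_t i = 0; i < leading_dims; ++i) {
    for (int64_t j = 0; j < M; ++j) {
      int32_t sum = bias[j];
      for (int64_t k = 0; k < N; ++k) {
        sum += (static_cast<int32_t>(src[i * N + k]) - src_zero_point) *
            (static_cast<int32_t>(weight[j * N + k]) - weight_zero_point);
      }
      out[i * M + j] = requantize_sketch<T>(sum, out_scale, out_zero_point);
    }
  }
}

The dispatcher then instantiates this template for uint8_t (ScalarType::Byte) and int8_t (ScalarType::Char), which is exactly the int8 support the commit title promises.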
backends/cadence/reference/operators/quantized_matmul_out.cpp

Lines changed: 5 additions & 0 deletions
@@ -144,6 +144,11 @@ void quantized_matmul_out(
         out_zero_point,
         transposed,
         out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(X.scalar_type()));
   }
 }
