Skip to content

Commit 9dcee22

Browse files
authored
[Cadence] add reference requantize out and tests
Differential Revision: D70906707 Pull Request resolved: #9097
1 parent 1b2c60c commit 9dcee22

File tree

11 files changed

+456
-9
lines changed

11 files changed

+456
-9
lines changed

backends/cadence/aot/TARGETS

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,23 @@ python_library(
115115
],
116116
deps = [
117117
"fbcode//caffe2:torch",
118-
"fbcode//executorch/exir:scalar_type",
119118
"fbcode//executorch/backends/cadence/aot:utils",
120119
],
121120
)
122121

122+
python_library(
123+
name = "ref_implementations",
124+
srcs = [
125+
"ref_implementations.py",
126+
],
127+
typing = True,
128+
deps = [
129+
"fbcode//caffe2:torch",
130+
"fbcode//executorch/exir:scalar_type",
131+
],
132+
)
133+
134+
123135
export_file(name = "functions.yaml")
124136

125137
executorch_generated_lib(

backends/cadence/aot/export_example.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def export_model(
3838
example_inputs: Tuple[Any, ...],
3939
file_name: str = "CadenceDemoModel",
4040
run_and_compare: bool = True,
41+
eps_error: float = 1e-1,
42+
eps_warn: float = 1e-5,
4143
):
4244
# create work directory for outputs and model binary
4345
working_dir = tempfile.mkdtemp(dir="/tmp")
@@ -89,4 +91,6 @@ def export_model(
8991
inputs=example_inputs,
9092
ref_outputs=ref_outputs,
9193
working_dir=working_dir,
94+
eps_error=eps_error,
95+
eps_warn=eps_warn,
9296
)

backends/cadence/aot/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,8 @@
248248
kernels:
249249
- arg_meta: null
250250
kernel_name: impl::reference::quantized_fully_connected_per_tensor_out
251+
252+
- func: cadence::requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)
253+
kernels:
254+
- arg_meta: null
255+
kernel_name: impl::reference::requantize_out

backends/cadence/aot/ops_registrations.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@
9494
"int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False) -> (Tensor Y)"
9595
)
9696
lib.define("dequantize(Tensor X, Tensor X_scale, Tensor X_zero_point) -> (Tensor Y)")
97-
# cadence::quantized_relu is defined in OSS
9897
lib.define(
9998
"quantized_add(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
10099
"Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
@@ -119,8 +118,6 @@
119118
"quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
120119
"Tensor indices, bool pruned_weights=False) -> (Tensor X)"
121120
)
122-
# cadence::quantized_layer_norm is defined in OSS
123-
# cadence::quantized_conv is defined is OSS
124121
lib.define(
125122
"quantized_transposed_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
126123
"int[] dilation, SymInt[] output_padding, int groups, int input_zero_point, Tensor weight_zero_point, "
@@ -156,7 +153,7 @@
156153
)
157154

158155
# ------------------------------------ #
159-
# Migrated from custom_ops.ymal #
156+
# Migrated from custom_ops.yaml #
160157
# ------------------------------------ #
161158
# Migrated from the custom_ops.yaml files containing different operator variants (e.g., .out, .tensor_out)
162159
lib.define(
@@ -167,7 +164,6 @@
167164
"transposed_convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
168165
"int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
169166
)
170-
# cadence::quantized_relu.out is defined in OSS
171167
lib.define(
172168
"quantized_relu.per_tensor(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift) -> Tensor"
173169
)
@@ -265,14 +261,12 @@
265261
"_cat_nop.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)"
266262
)
267263

268-
# Custom ops with jarvis_nn_ops namespace
264+
# Custom ops with cadence_nn_ops namespace
269265
jarvis_nn_lib = Library("jarvis_nn_ops", "DEF")
270266
jarvis_nn_lib.define(
271267
"attention_mask.out(Tensor input, Tensor start, Tensor stop, *, Tensor(a!) out) -> Tensor(a!)"
272268
)
273269

274-
m = Library("cadence", "IMPL", "Meta")
275-
276270

277271
@register_fake("cadence::quantize_per_tensor")
278272
def quantize_per_tensor_meta(
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import torch
10+
from executorch.exir.scalar_type import ScalarType
11+
from torch.library import impl, Library
12+
13+
14+
m = Library("cadence", "IMPL", "CompositeExplicitAutograd")
15+
16+
qdtype_map: dict[ScalarType, torch.dtype] = {
17+
ScalarType.QINT8: torch.qint8,
18+
ScalarType.QUINT8: torch.quint8,
19+
ScalarType.QINT32: torch.qint32,
20+
}
21+
22+
23+
@impl(m, "requantize")
24+
def requantize(
25+
input: torch.Tensor,
26+
in_scale: torch.Tensor,
27+
in_zero_point: torch.Tensor,
28+
out_scale: torch.Tensor,
29+
out_zero_point: torch.Tensor,
30+
dtype: ScalarType,
31+
) -> torch.Tensor:
32+
if dtype in qdtype_map:
33+
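# Legacy path: dtype maps to a torch quantized dtype (see qdtype_map), so requantize via torch.dequantize + torch.quantize_per_tensor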
34+
return torch.quantize_per_tensor(
35+
torch.dequantize(input), out_scale, out_zero_point, qdtype_map[dtype]
36+
)
37+
38+
# Only scalar (per-tensor) scales are handled: a non-scalar in_scale or out_scale
39+
# would need per-channel quant/dequant, but this op takes no channel-dimension argument
40+
if in_scale.numel() > 1 or out_scale.numel() > 1:
41+
raise NotImplementedError("Only scalar scales are supported")
42+
43+
quant_min = torch.iinfo(input.dtype).min
44+
quant_max = torch.iinfo(input.dtype).max
45+
# pyre-fixme[6]: This dtype is actually the right one.
46+
out_quant_min = torch.iinfo(dtype).min
47+
# pyre-fixme[6]: This dtype is actually the right one.
48+
out_quant_max = torch.iinfo(dtype).max
49+
return torch.ops.quantized_decomposed.quantize_per_tensor(
50+
torch.ops.quantized_decomposed.dequantize_per_tensor(
51+
input,
52+
in_scale.flatten()[0],
53+
in_zero_point.flatten()[0],
54+
quant_min,
55+
quant_max,
56+
input.dtype,
57+
),
58+
out_scale.flatten()[0],
59+
out_zero_point.flatten()[0],
60+
out_quant_min,
61+
out_quant_max,
62+
dtype,
63+
)

backends/cadence/reference/kernels/kernels.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,36 @@ void dequantize(
5858
}
5959
}
6060

61+
// Requantize a single quantized value: dequantize `in` with (in_scale, in_zero_point),
62+
// then quantize the result to OT with (inv_out_scale, out_zero_point). Instantiated below for int8_t, uint8_t, int16_t and uint16_t.
63+
template <typename IT, typename OT>
64+
OT requantize(
65+
const IT in,
66+
float in_scale,
67+
int32_t in_zero_point,
68+
float inv_out_scale,
69+
int32_t out_zero_point) {
70+
float dequant = dequantize<IT>(in, in_scale, in_zero_point);
71+
return quantize<OT>(dequant, inv_out_scale, out_zero_point);
72+
}
73+
74+
// Requantize an array of `size` quantized values from `in` into `out`, element by element,
75+
// applying the same (in_scale, in_zero_point) and (inv_out_scale, out_zero_point) to every element.
76+
template <typename IT, typename OT>
77+
void requantize(
78+
OT* __restrict__ out,
79+
const IT* __restrict__ in,
80+
float in_scale,
81+
int32_t in_zero_point,
82+
float inv_out_scale,
83+
int32_t out_zero_point,
84+
size_t size) {
85+
for (size_t i = 0; i < size; ++i) {
86+
out[i] = requantize<IT, OT>(
87+
in[i], in_scale, in_zero_point, inv_out_scale, out_zero_point);
88+
}
89+
}
90+
6191
// explicit template instantiation
6292

6393
#define typed_quantize_val(dtype) \
@@ -106,6 +136,58 @@ typed_dequantize_vec(uint16_t);
106136
typed_dequantize_vec(int32_t);
107137
#undef typed_dequantize_vec
108138

139+
#define typed_requantize_val(itype, otype) \
140+
template otype requantize( \
141+
const itype in, \
142+
float in_scale, \
143+
int32_t in_zero_point, \
144+
float inv_out_scale, \
145+
int32_t out_zero_point);
146+
typed_requantize_val(int8_t, int8_t);
147+
typed_requantize_val(int8_t, uint8_t);
148+
typed_requantize_val(int8_t, int16_t);
149+
typed_requantize_val(int8_t, uint16_t);
150+
typed_requantize_val(uint8_t, int8_t);
151+
typed_requantize_val(uint8_t, uint8_t);
152+
typed_requantize_val(uint8_t, int16_t);
153+
typed_requantize_val(uint8_t, uint16_t);
154+
typed_requantize_val(int16_t, int8_t);
155+
typed_requantize_val(int16_t, uint8_t);
156+
typed_requantize_val(int16_t, int16_t);
157+
typed_requantize_val(int16_t, uint16_t);
158+
typed_requantize_val(uint16_t, int8_t);
159+
typed_requantize_val(uint16_t, uint8_t);
160+
typed_requantize_val(uint16_t, int16_t);
161+
typed_requantize_val(uint16_t, uint16_t);
162+
#undef typed_requantize_val
163+
164+
#define typed_requantize_vec(itype, otype) \
165+
template void requantize( \
166+
otype* __restrict__ out, \
167+
const itype* __restrict__ in, \
168+
float in_scale, \
169+
int32_t in_zero_point, \
170+
float inv_out_scale, \
171+
int32_t out_zero_point, \
172+
size_t size);
173+
typed_requantize_vec(int8_t, int8_t);
174+
typed_requantize_vec(int8_t, uint8_t);
175+
typed_requantize_vec(int8_t, int16_t);
176+
typed_requantize_vec(int8_t, uint16_t);
177+
typed_requantize_vec(uint8_t, int8_t);
178+
typed_requantize_vec(uint8_t, uint8_t);
179+
typed_requantize_vec(uint8_t, int16_t);
180+
typed_requantize_vec(uint8_t, uint16_t);
181+
typed_requantize_vec(int16_t, int8_t);
182+
typed_requantize_vec(int16_t, uint8_t);
183+
typed_requantize_vec(int16_t, int16_t);
184+
typed_requantize_vec(int16_t, uint16_t);
185+
typed_requantize_vec(uint16_t, int8_t);
186+
typed_requantize_vec(uint16_t, uint8_t);
187+
typed_requantize_vec(uint16_t, int16_t);
188+
typed_requantize_vec(uint16_t, uint16_t);
189+
#undef typed_requantize_vec
190+
109191
}; // namespace kernels
110192
}; // namespace reference
111193
}; // namespace impl

backends/cadence/reference/kernels/kernels.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,24 @@ void dequantize(
3636
int32_t zero_point,
3737
size_t size);
3838

39+
template <typename IT, typename OT>
40+
OT requantize(
41+
const IT in,
42+
float in_scale,
43+
int32_t in_zero_point,
44+
float inv_out_scale,
45+
int32_t out_zero_point);
46+
47+
template <typename IT, typename OT>
48+
void requantize(
49+
OT* __restrict__ out,
50+
const IT* __restrict__ in,
51+
float in_scale,
52+
int32_t in_zero_point,
53+
float inv_out_scale,
54+
int32_t out_zero_point,
55+
size_t size);
56+
3957
}; // namespace kernels
4058
}; // namespace reference
4159
}; // namespace impl

backends/cadence/reference/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ add_library(
9090
"quantized_fully_connected_out.cpp"
9191
"dequantize_per_tensor.cpp"
9292
"quantized_matmul_out.cpp"
93+
"requantize_out.cpp"
9394
"im2row_out.cpp"
9495
)
9596
target_include_directories(

0 commit comments

Comments
 (0)