Skip to content

Commit d28fe4a

Browse files
kausikmaiti authored and facebook-github-bot committed
Modify signature of dequantize ops for decomposed quantized Tensor
Summary: Note: The initial purpose of this PR is to draw suggestions and feedback regarding a better alternative, if any. At present, the dequantize op for the decomposed quantized Tensor representation, e.g. dequantize_per_tensor(), assumes the output dtype is torch.float and hence does not have the output dtype in its operator argument list. However, this op signature becomes unusable when that assumption breaks: if the output dtype is different from torch.float, there is no way to specify it during dequantization. This change is aimed at generalizing the signature of dequantize ops like dequantize_per_tensor() for wider use-cases where the output dtype can be different from torch.float and needs to be passed during dequantization. The proposal is to use an additional argument named 'output_dtype' to solve the problem. However, we would also welcome suggestions and feedback regarding any better alternative that could be used instead. cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen leslie-fang-intel X-link: pytorch/pytorch#119173 Reviewed By: digantdesai Differential Revision: D53590486 Pulled By: manuelcandales
1 parent 9d6bf72 commit d28fe4a

File tree

5 files changed

+67
-15
lines changed

5 files changed

+67
-15
lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ void check_dequantize_per_tensor_args(
3333
int64_t quant_min,
3434
int64_t quant_max,
3535
ScalarType dtype,
36+
exec_aten::optional<ScalarType>& out_dtype,
3637
Tensor& out) {
3738
ET_CHECK_MSG(
3839
input.scalar_type() == ScalarType::Byte ||
@@ -47,10 +48,11 @@ void check_dequantize_per_tensor_args(
4748
"input.scalar_type() %" PRId8 " is not matching dtype argumenta:",
4849
static_cast<int8_t>(input.scalar_type()));
4950

50-
ET_CHECK_MSG(
51-
out.scalar_type() == ScalarType::Float,
52-
"out.scalar_type() %" PRId8 " is not supported:",
53-
static_cast<int8_t>(out.scalar_type()));
51+
if (out_dtype.has_value()) {
52+
ET_CHECK_MSG(
53+
out.scalar_type() == out_dtype.value(),
54+
"output_dtype must match the dtype of the out tensor");
55+
}
5456

5557
ET_CHECK_MSG(
5658
quant_min <= quant_max,
@@ -77,13 +79,15 @@ Tensor& dequantize_per_tensor_out(
7779
int64_t quant_min,
7880
int64_t quant_max,
7981
ScalarType dtype,
82+
exec_aten::optional<ScalarType> out_dtype,
8083
Tensor& out) {
8184
torch::executor::Error err = resize_tensor(out, input.sizes());
8285
ET_CHECK_MSG(
8386
err == torch::executor::Error::Ok,
8487
"Failed to resize out Tensor in dequantize_per_tensor_out");
8588

86-
check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
89+
check_dequantize_per_tensor_args(
90+
input, quant_min, quant_max, dtype, out_dtype, out);
8791

8892
// calculate the dequantized output, cast scale to float to match fbgemm
8993
// behavior
@@ -128,6 +132,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
128132
int64_t quant_min,
129133
int64_t quant_max,
130134
ScalarType dtype,
135+
exec_aten::optional<ScalarType> out_dtype,
131136
Tensor& out) {
132137
ET_CHECK_MSG(
133138
scale.scalar_type() == ScalarType::Double,
@@ -153,6 +158,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
153158
quant_min,
154159
quant_max,
155160
dtype,
161+
out_dtype,
156162
out);
157163
return out;
158164
}
@@ -165,6 +171,7 @@ Tensor& dequantize_per_channel_out(
165171
int64_t quant_min,
166172
int64_t quant_max,
167173
ScalarType dtype,
174+
exec_aten::optional<ScalarType> out_dtype,
168175
Tensor& out) {
169176
torch::executor::Error err = resize_tensor(out, input.sizes());
170177

@@ -205,7 +212,8 @@ Tensor& dequantize_per_channel_out(
205212
ssize_t(zero_point.numel()),
206213
ssize_t(input.size(axis)));
207214

208-
check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
215+
check_dequantize_per_tensor_args(
216+
input, quant_min, quant_max, dtype, out_dtype, out);
209217

210218
// a list contains all dimensions except axis
211219
int64_t dims[input.dim() - 1];
@@ -281,10 +289,19 @@ Tensor& dequantize_per_channel_out(
281289
int64_t quant_min,
282290
int64_t quant_max,
283291
ScalarType dtype,
292+
exec_aten::optional<ScalarType> out_dtype,
284293
Tensor& out) {
285294
(void)context;
286295
return dequantize_per_channel_out(
287-
input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
296+
input,
297+
scale,
298+
zero_point,
299+
axis,
300+
quant_min,
301+
quant_max,
302+
dtype,
303+
out_dtype,
304+
out);
288305
}
289306

290307
Tensor& dequantize_per_tensor_out(
@@ -295,12 +312,13 @@ Tensor& dequantize_per_tensor_out(
295312
int64_t quant_min,
296313
int64_t quant_max,
297314
ScalarType dtype,
315+
exec_aten::optional<ScalarType> out_dtype,
298316
Tensor& out) {
299317
// TODO(larryliu): Add a context arg to the real op function and remove this
300318
// wrapper
301319
(void)context;
302320
return dequantize_per_tensor_out(
303-
input, scale, zero_point, quant_min, quant_max, dtype, out);
321+
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
304322
}
305323

306324
Tensor& dequantize_per_tensor_tensor_args_out(
@@ -311,12 +329,13 @@ Tensor& dequantize_per_tensor_tensor_args_out(
311329
int64_t quant_min,
312330
int64_t quant_max,
313331
ScalarType dtype,
332+
exec_aten::optional<ScalarType> out_dtype,
314333
Tensor& out) {
315334
// TODO(larryliu): Add a context arg to the real op function and remove this
316335
// wrapper
317336
(void)context;
318337
return dequantize_per_tensor_tensor_args_out(
319-
input, scale, zero_point, quant_min, quant_max, dtype, out);
338+
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
320339
}
321340

322341
} // namespace native

kernels/quantized/quantized.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
- arg_meta: null
1111
kernel_name: torch::executor::choose_qparams_tensor_out
1212

13-
- func: quantized_decomposed::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
13+
- func: quantized_decomposed::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
1414
variants: function
1515
kernels:
1616
- arg_meta: null
1717
kernel_name: torch::executor::dequantize_per_tensor_out
1818

19-
- func: quantized_decomposed::dequantize_per_tensor.Tensor_out(Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
19+
- func: quantized_decomposed::dequantize_per_tensor.Tensor_out(Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
2020
variants: function
2121
kernels:
2222
- arg_meta: null
@@ -28,7 +28,7 @@
2828
- arg_meta: null
2929
kernel_name: torch::executor::quantize_per_channel_out
3030

31-
- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
31+
- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
3232
variants: function
3333
kernels:
3434
- arg_meta: null

kernels/quantized/test/op_add_test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
using namespace ::testing;
2222
using exec_aten::ArrayRef;
23+
using exec_aten::optional;
2324
using exec_aten::RuntimeContext;
2425
using exec_aten::Scalar;
2526
using exec_aten::ScalarType;
@@ -190,6 +191,8 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
190191
Tensor qinput2 = tfo.zeros({3, 5});
191192
Tensor qoutput = tfo.zeros({3, 5});
192193

194+
optional<ScalarType> out_dtype = optional<ScalarType>();
195+
193196
RuntimeContext context{};
194197
// q -> qadd -> dq
195198
// 3.5 / 0.5 + 1 = 8
@@ -235,6 +238,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
235238
quant_min,
236239
quant_max,
237240
ScalarType::Byte,
241+
out_dtype,
238242
reference_op_output);
239243

240244
// now get results for q -> dq -> fp add -> q -> dq
@@ -245,6 +249,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
245249
quant_min,
246250
quant_max,
247251
ScalarType::Byte,
252+
out_dtype,
248253
dq_input1);
249254

250255
dequantize_per_tensor_out(
@@ -254,6 +259,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
254259
quant_min,
255260
quant_max,
256261
ScalarType::Byte,
262+
out_dtype,
257263
dq_input2);
258264

259265
add_out(context, dq_input1, dq_input2, 1.0, fp_output);
@@ -274,6 +280,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
274280
quant_min,
275281
quant_max,
276282
ScalarType::Byte,
283+
out_dtype,
277284
reference_pattern_output);
278285

279286
Tensor expected = tf.full({3, 5}, 7.0);

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
using namespace ::testing;
2020
using exec_aten::ArrayRef;
21+
using exec_aten::optional;
2122
using exec_aten::Scalar;
2223
using exec_aten::ScalarType;
2324
using exec_aten::Tensor;
@@ -43,7 +44,14 @@ void test_dtype() {
4344
// (100 - 30) * 0.5
4445
Tensor expected = tfo.full({3, 5}, 35);
4546
dequantize_per_tensor_out(
46-
input, scale, zero_point, quant_min, quant_max, DTYPE, out);
47+
input,
48+
scale,
49+
zero_point,
50+
quant_min,
51+
quant_max,
52+
DTYPE,
53+
optional<ScalarType>(),
54+
out);
4755

4856
EXPECT_TENSOR_EQ(out, expected);
4957
}
@@ -66,7 +74,14 @@ TEST(OpDequantizeOutTest, NonWholeNumbers) {
6674
// (100 - 30) * 0.5
6775
Tensor expected = tfo.full({3, 5}, 31.5);
6876
dequantize_per_tensor_out(
69-
input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
77+
input,
78+
scale,
79+
zero_point,
80+
quant_min,
81+
quant_max,
82+
ScalarType::Byte,
83+
optional<ScalarType>(),
84+
out);
7085

7186
EXPECT_TENSOR_EQ(out, expected);
7287
}
@@ -87,7 +102,14 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
87102
// (100 - 30) * 0.5
88103
Tensor expected = tfo.full({3, 5}, 31.5);
89104
dequantize_per_tensor_tensor_args_out(
90-
input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
105+
input,
106+
scale,
107+
zero_point,
108+
quant_min,
109+
quant_max,
110+
ScalarType::Byte,
111+
optional<ScalarType>(),
112+
out);
91113

92114
EXPECT_TENSOR_EQ(out, expected);
93115
}
@@ -116,6 +138,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
116138
quant_min,
117139
quant_max,
118140
ScalarType::Byte,
141+
optional<ScalarType>(),
119142
out);
120143

121144
EXPECT_TENSOR_EQ(out, expected);
@@ -136,6 +159,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
136159
quant_min,
137160
quant_max,
138161
ScalarType::Byte,
162+
optional<ScalarType>(),
139163
out);
140164

141165
EXPECT_TENSOR_EQ(out, expected);

kernels/quantized/test/op_embedding_test.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
using namespace ::testing;
2222
using exec_aten::ArrayRef;
23+
using exec_aten::optional;
2324
using exec_aten::RuntimeContext;
2425
using exec_aten::Scalar;
2526
using exec_aten::ScalarType;
@@ -149,6 +150,7 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
149150
quant_min,
150151
quant_max,
151152
ScalarType::Byte,
153+
optional<ScalarType>(),
152154
weight);
153155

154156
embedding_out(

0 commit comments

Comments
 (0)