Skip to content

Commit 285333d

Browse files
author
morelos
committed
[ET-VK][Ops] enabling double support for quantization and dequantization ops
Pull Request resolved: #11553 With the added double support in the layout template, this diff is enabling it as input/output for dequantization. Since there are limitations with how 64bit can be supported, the expectation is that IO be downgraded to 32bit ghstack-source-id: 289798618 @exported-using-ghexport Differential Revision: [D76289197](https://our.internmc.facebook.com/intern/diff/D76289197/)
1 parent 9b20d26 commit 285333d

File tree

6 files changed

+94
-2
lines changed

6 files changed

+94
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize.glsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ $if MODE == "per_tensor":
155155
[[unroll]] for (int i = 0; i < 4; ++i) {
156156
IN_T qvalue = IN_T(intex[i]);
157157
OUT_T value = dequantize_val(qvalue, scale, zero_point);
158-
outtex[i] = value;
158+
$if OUT_DTYPE == "double":
159+
outtex[i] = float(value);
160+
$else:
161+
outtex[i] = value;
159162
}
160163
write_texel(t_out, pos, outtex);
161164

@@ -198,7 +201,10 @@ $if MODE == "per_token":
198201
[[unroll]] for (int i = 0; i < 4; ++i) {
199202
IN_T qvalue = IN_T(intex[i]);
200203
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
201-
outtex[i] = value;
204+
$if OUT_DTYPE == "double":
205+
outtex[i] = float(value);
206+
$else:
207+
outtex[i] = value;
202208
}
203209

204210
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/dequantize.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dequantize:
1515
OUT_DTYPE:
1616
- VALUE: half
1717
- VALUE: float
18+
- VALUE: double
1819
shader_variants:
1920
- NAME: dequantize_per_tensor
2021
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ quantize:
1111
IN_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
OUT_DTYPE:
1516
- VALUE: uint8
1617
- VALUE: int8

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ void quantize_per_tensor_impl(
191191

192192
// Verify input is a floating point type
193193
VK_CHECK_COND(
194+
graph.dtype_of(input) == vkapi::kDouble ||
194195
graph.dtype_of(input) == vkapi::kFloat ||
195196
graph.dtype_of(input) == vkapi::kHalf);
196197

@@ -214,6 +215,7 @@ void quantize_per_token_impl(
214215

215216
// Verify input is a floating point type
216217
VK_CHECK_COND(
218+
graph.dtype_of(input) == vkapi::kDouble ||
217219
graph.dtype_of(input) == vkapi::kFloat ||
218220
graph.dtype_of(input) == vkapi::kHalf);
219221

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,12 @@ void test_vulkan_dequantize_per_tensor(
369369
vkcompute::utils::kBuffer,
370370
vkcompute::utils::kBuffer);
371371

372+
// Telling the system to expect a float instead of a double
373+
// since the shader can only return 32bit anyways
374+
if (out_dtype == at::kDouble) {
375+
out_dtype = at::kFloat;
376+
}
377+
372378
// Test with texture storage
373379
test_vulkan_dequantize_per_tensor_impl(
374380
input_sizes,
@@ -403,6 +409,12 @@ void test_vulkan_dequantize_per_token(
403409
vkcompute::utils::kBuffer,
404410
vkcompute::utils::kBuffer);
405411

412+
// Telling the system to expect a float instead of a double
413+
// since the shader can only return 32bit anyways
414+
if (out_dtype == at::kDouble) {
415+
out_dtype = at::kFloat;
416+
}
417+
406418
// Test with texture storage
407419
test_vulkan_dequantize_per_token_impl(
408420
input_sizes,
@@ -772,6 +784,19 @@ TEST(
772784
at::kHalf); // output dtype
773785
}
774786

787+
TEST(
788+
VulkanDequantizePerTensorTest,
789+
test_vulkan_dequantize_per_tensor_int32_to_double) {
790+
test_vulkan_dequantize_per_tensor(
791+
{2, 4, 3}, // input sizes
792+
0.0001, // scale
793+
100, // zero_point
794+
-2147483648, // quant_min
795+
2147483647, // quant_max
796+
at::kInt, // input dtype
797+
at::kDouble); // output dtype
798+
}
799+
775800
void test_reference_dequantize_per_token(
776801
const std::vector<int>& input_sizes,
777802
const std::vector<float>& scales,
@@ -1237,3 +1262,19 @@ TEST(
12371262
at::kInt, // input dtype
12381263
at::kHalf); // output dtype
12391264
}
1265+
1266+
TEST(
1267+
VulkanDequantizePerTokenTest,
1268+
test_vulkan_dequantize_per_token_int32_to_double) {
1269+
std::vector<float> scales = {0.0001, 0.0002, 0.0003, 0.0};
1270+
std::vector<int> zero_points = {100, -100, 50, -50};
1271+
1272+
test_vulkan_dequantize_per_token(
1273+
{2, 2, 8}, // input sizes (2*2=4 tokens)
1274+
scales,
1275+
zero_points,
1276+
-2147483648, // quant_min
1277+
2147483647, // quant_max
1278+
at::kInt, // input dtype
1279+
at::kDouble); // output dtype
1280+
}

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,12 @@ void test_vulkan_quantize_per_tensor(
314314
vkcompute::utils::kBuffer,
315315
vkcompute::utils::kBuffer);
316316

317+
// If the in_dtype is a double, convert to float for texture implementation
318+
// since they don't support 64bit as inputs
319+
if (in_dtype == at::kDouble) {
320+
in_dtype = at::kFloat;
321+
}
322+
317323
// Test with texture storage
318324
test_vulkan_quantize_per_tensor_impl(
319325
input_sizes,
@@ -348,6 +354,12 @@ void test_vulkan_quantize_per_token(
348354
vkcompute::utils::kBuffer,
349355
vkcompute::utils::kBuffer);
350356

357+
// If the in_dtype is a double, convert to float for texture implementation
358+
// since they don't support 64bit as inputs
359+
if (in_dtype == at::kDouble) {
360+
in_dtype = at::kFloat;
361+
}
362+
351363
// Test with texture storage
352364
test_vulkan_quantize_per_token_impl(
353365
input_sizes,
@@ -639,6 +651,19 @@ TEST(
639651
at::kChar); // output dtype
640652
}
641653

654+
TEST(
655+
VulkanQuantizePerTensorTest,
656+
test_vulkan_quantize_per_tensor_double_to_int8) {
657+
test_vulkan_quantize_per_tensor(
658+
{2, 3}, // input sizes
659+
0.01, // scale
660+
1, // zero_point
661+
-128, // quant_min
662+
127, // quant_max
663+
at::kDouble, // input dtype
664+
at::kChar); // output dtype
665+
}
666+
642667
void test_reference_quantize_per_token(
643668
const std::vector<int>& input_sizes,
644669
const std::vector<float>& pre_scales,
@@ -1033,3 +1058,19 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10331058
at::kHalf, // input dtype
10341059
at::kChar); // output dtype
10351060
}
1061+
1062+
TEST(
1063+
VulkanQuantizePerTensorTest,
1064+
test_vulkan_quantize_per_token_double_to_int8) {
1065+
std::vector<float> scales = {0.1, 0.2};
1066+
std::vector<int> zero_points = {0, 5};
1067+
1068+
test_vulkan_quantize_per_token(
1069+
{2, 2}, // input sizes (2*2=4 tokens)
1070+
scales,
1071+
zero_points,
1072+
-128, // quant_min
1073+
127, // quant_max
1074+
at::kDouble, // input dtype
1075+
at::kChar); // output dtype
1076+
}

0 commit comments

Comments
 (0)