Skip to content

Commit 9fb73e4

Browse files
author
morelos
committed
[ET-VK][Ops] enabling double support for quantization and dequantization ops
With double support added to the layout template, this diff enables double as an input dtype for the quantization ops and as an output dtype for the dequantization ops. Because 64-bit values can only be supported in a limited way, the expectation is that IO is downgraded to 32-bit. Differential Revision: [D76289197](https://our.internmc.facebook.com/intern/diff/D76289197/) ghstack-source-id: 289707203 Pull Request resolved: #11553
1 parent 6bb8b53 commit 9fb73e4

File tree

6 files changed

+94
-2
lines changed

6 files changed

+94
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize.glsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ $if MODE == "per_tensor":
155155
[[unroll]] for (int i = 0; i < 4; ++i) {
156156
IN_T qvalue = IN_T(intex[i]);
157157
OUT_T value = dequantize_val(qvalue, scale, zero_point);
158-
outtex[i] = value;
158+
$if OUT_DTYPE == "double":
159+
outtex[i] = float(value);
160+
$else:
161+
outtex[i] = value;
159162
}
160163
write_texel(t_out, pos, outtex);
161164

@@ -198,7 +201,10 @@ $if MODE == "per_token":
198201
[[unroll]] for (int i = 0; i < 4; ++i) {
199202
IN_T qvalue = IN_T(intex[i]);
200203
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
201-
outtex[i] = value;
204+
$if OUT_DTYPE == "double":
205+
outtex[i] = float(value);
206+
$else:
207+
outtex[i] = value;
202208
}
203209

204210
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/dequantize.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dequantize:
1515
OUT_DTYPE:
1616
- VALUE: half
1717
- VALUE: float
18+
- VALUE: double
1819
shader_variants:
1920
- NAME: dequantize_per_tensor
2021
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ quantize:
1111
IN_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
OUT_DTYPE:
1516
- VALUE: uint8
1617
- VALUE: int8

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ void quantize_per_tensor_impl(
191191

192192
// Verify input is a floating point type
193193
VK_CHECK_COND(
194+
graph.dtype_of(input) == vkapi::kDouble ||
194195
graph.dtype_of(input) == vkapi::kFloat ||
195196
graph.dtype_of(input) == vkapi::kHalf);
196197

@@ -214,6 +215,7 @@ void quantize_per_token_impl(
214215

215216
// Verify input is a floating point type
216217
VK_CHECK_COND(
218+
graph.dtype_of(input) == vkapi::kDouble ||
217219
graph.dtype_of(input) == vkapi::kFloat ||
218220
graph.dtype_of(input) == vkapi::kHalf);
219221

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,12 @@ void test_vulkan_dequantize_per_tensor(
462462
vkcompute::utils::kBuffer,
463463
vkcompute::utils::kBuffer);
464464

465+
// Telling the system to expect a float instead of a double
466+
// since the shader can only return 32bit anyways
467+
if (out_dtype == at::kDouble) {
468+
out_dtype = at::kFloat;
469+
}
470+
465471
// Test with texture storage
466472
test_vulkan_dequantize_per_tensor_impl(
467473
input_sizes,
@@ -496,6 +502,12 @@ void test_vulkan_dequantize_per_token(
496502
vkcompute::utils::kBuffer,
497503
vkcompute::utils::kBuffer);
498504

505+
// Telling the system to expect a float instead of a double
506+
// since the shader can only return 32bit anyways
507+
if (out_dtype == at::kDouble) {
508+
out_dtype = at::kFloat;
509+
}
510+
499511
// Test with texture storage
500512
test_vulkan_dequantize_per_token_impl(
501513
input_sizes,
@@ -790,6 +802,19 @@ TEST(
790802
at::kFloat); // output dtype
791803
}
792804

805+
TEST(
806+
VulkanDequantizePerTensorTest,
807+
test_vulkan_dequantize_per_tensor_int32_to_double) {
808+
test_vulkan_dequantize_per_tensor(
809+
{2, 4, 3}, // input sizes
810+
0.0001, // scale
811+
100, // zero_point
812+
-2147483648, // quant_min
813+
2147483647, // quant_max
814+
at::kInt, // input dtype
815+
at::kDouble); // output dtype
816+
}
817+
793818
void test_reference_dequantize_per_token(
794819
const std::vector<int>& input_sizes,
795820
const std::vector<float>& scales,
@@ -1165,3 +1190,19 @@ TEST(
11651190
at::kInt, // input dtype
11661191
at::kFloat); // output dtype
11671192
}
1193+
1194+
TEST(
1195+
VulkanDequantizePerTokenTest,
1196+
test_vulkan_dequantize_per_token_int32_to_double) {
1197+
std::vector<float> scales = {0.0001, 0.0002, 0.0003, 0.0};
1198+
std::vector<int> zero_points = {100, -100, 50, -50};
1199+
1200+
test_vulkan_dequantize_per_token(
1201+
{2, 2, 8}, // input sizes (2*2=4 tokens)
1202+
scales,
1203+
zero_points,
1204+
-2147483648, // quant_min
1205+
2147483647, // quant_max
1206+
at::kInt, // input dtype
1207+
at::kDouble); // output dtype
1208+
}

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,12 @@ void test_vulkan_quantize_per_tensor(
432432
vkcompute::utils::kBuffer,
433433
vkcompute::utils::kBuffer);
434434

435+
// If the in_dtype is a double, convert to float for texture implementation
436+
// since they don't support 64bit as inputs
437+
if (in_dtype == at::kDouble) {
438+
in_dtype = at::kFloat;
439+
}
440+
435441
// Test with texture storage
436442
test_vulkan_quantize_per_tensor_impl(
437443
input_sizes,
@@ -466,6 +472,12 @@ void test_vulkan_quantize_per_token(
466472
vkcompute::utils::kBuffer,
467473
vkcompute::utils::kBuffer);
468474

475+
// If the in_dtype is a double, convert to float for texture implementation
476+
// since they don't support 64bit as inputs
477+
if (in_dtype == at::kDouble) {
478+
in_dtype = at::kFloat;
479+
}
480+
469481
// Test with texture storage
470482
test_vulkan_quantize_per_token_impl(
471483
input_sizes,
@@ -718,6 +730,19 @@ TEST(
718730
at::kChar); // output dtype
719731
}
720732

733+
TEST(
734+
VulkanQuantizePerTensorTest,
735+
test_vulkan_quantize_per_tensor_double_to_int8) {
736+
test_vulkan_quantize_per_tensor(
737+
{2, 3}, // input sizes
738+
0.01, // scale
739+
1, // zero_point
740+
-128, // quant_min
741+
127, // quant_max
742+
at::kDouble, // input dtype
743+
at::kChar); // output dtype
744+
}
745+
721746
void test_reference_quantize_per_token(
722747
const std::vector<int>& input_sizes,
723748
const std::vector<float>& pre_scales,
@@ -1064,3 +1089,19 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10641089
at::kHalf, // input dtype
10651090
at::kChar); // output dtype
10661091
}
1092+
1093+
TEST(
1094+
VulkanQuantizePerTensorTest,
1095+
test_vulkan_quantize_per_token_double_to_int8) {
1096+
std::vector<float> scales = {0.1, 0.2};
1097+
std::vector<int> zero_points = {0, 5};
1098+
1099+
test_vulkan_quantize_per_token(
1100+
{2, 2}, // input sizes (2 tokens of 2 elements each)
1101+
scales,
1102+
zero_points,
1103+
-128, // quant_min
1104+
127, // quant_max
1105+
at::kDouble, // input dtype
1106+
at::kChar); // output dtype
1107+
}

0 commit comments

Comments
 (0)