Skip to content

Commit 6d56361

Browse files
author
morelos
committed
[ET-VK][Ops] choose_qparams_per_token_asymmetric.default test setup
Creates the operator testing framework for choose_qparams_per_token_asymmetric.default, along with a reference implementation to test against.
Differential Revision: [D76436906](https://our.internmc.facebook.com/intern/diff/D76436906/)
ghstack-source-id: 289707211
Pull Request resolved: #11556
1 parent d6f5c49 commit 6d56361

File tree

1 file changed

+285
-0
lines changed

1 file changed

+285
-0
lines changed

backends/vulkan/test/op_tests/choose_qparams_test.cpp

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,111 @@ std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl(
311311
return std::make_tuple(scale_out, zero_point_out);
312312
}
313313

314+
/*
315+
* Reference implementation of choose_qparams_per_token_asymmetric
316+
*/
317+
std::tuple<at::Tensor, at::Tensor>
318+
choose_qparams_per_token_asymmetric_reference_impl(
319+
const at::Tensor& input,
320+
at::ScalarType dtype) {
321+
// For per-token quantization, we need to compute scale and zero_point for
322+
// each token
323+
int64_t quant_min = -128;
324+
int64_t quant_max = 127;
325+
326+
// Calculate output sizes
327+
std::vector<int64_t> output_sizes;
328+
for (int64_t i = 0; i < input.dim() - 1; i++) {
329+
output_sizes.push_back(input.size(i));
330+
}
331+
output_sizes.push_back(1);
332+
333+
// Create output tensors
334+
at::Tensor scale_out =
335+
at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble));
336+
at::Tensor zero_point_out =
337+
at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong));
338+
339+
// Calculate number of tokens
340+
int64_t num_tokens = 1;
341+
for (int64_t i = 0; i < input.dim() - 1; i++) {
342+
num_tokens *= input.size(i);
343+
}
344+
345+
// Reshape input to [num_tokens, last_dim]
346+
at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)});
347+
348+
// Process each token
349+
for (int64_t token_idx = 0; token_idx < num_tokens; token_idx++) {
350+
at::Tensor token = reshaped_input[token_idx];
351+
352+
// Find min and max values for this token
353+
float min_val = token.min().item<float>();
354+
float max_val = token.max().item<float>();
355+
356+
// Extend the [min, max] interval to ensure it contains 0
357+
min_val = std::min(min_val, 0.f);
358+
max_val = std::max(max_val, 0.f);
359+
360+
// Calculate scale
361+
double scale =
362+
(static_cast<double>(max_val) - min_val) / (quant_max - quant_min);
363+
364+
// Handle small scale
365+
constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;
366+
if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
367+
scale = 0.1;
368+
}
369+
370+
if (scale < SMALL_SCALE_THRESHOLD) {
371+
float org_scale = scale;
372+
scale = SMALL_SCALE_THRESHOLD;
373+
// Adjust min and max based on new scale
374+
if (min_val == 0.0f) {
375+
max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
376+
} else if (max_val == 0.0f) {
377+
min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
378+
} else {
379+
float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
380+
min_val *= amplifier;
381+
max_val *= amplifier;
382+
}
383+
}
384+
385+
// Calculate zero point
386+
double zero_point_from_min =
387+
quant_min - min_val / static_cast<double>(scale);
388+
double zero_point_from_max =
389+
quant_max - max_val / static_cast<double>(scale);
390+
double zero_point_from_min_error =
391+
std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale));
392+
double zero_point_from_max_error =
393+
std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale));
394+
double initial_zero_point =
395+
zero_point_from_min_error < zero_point_from_max_error
396+
? zero_point_from_min
397+
: zero_point_from_max;
398+
399+
// Nudge zero point to be an integer
400+
int64_t nudged_zero_point = 0;
401+
if (initial_zero_point < quant_min) {
402+
nudged_zero_point = quant_min;
403+
} else if (initial_zero_point > quant_max) {
404+
nudged_zero_point = quant_max;
405+
} else {
406+
nudged_zero_point =
407+
std::nearbyint(static_cast<float>(initial_zero_point));
408+
}
409+
410+
// Set output values for this token - use index_put_ for safety
411+
scale_out.view({num_tokens, 1}).index_put_({token_idx, 0}, scale);
412+
zero_point_out.view({num_tokens, 1})
413+
.index_put_({token_idx, 0}, nudged_zero_point);
414+
}
415+
416+
return std::make_tuple(scale_out, zero_point_out);
417+
}
418+
314419
// Forward declaration of implementation functions
315420
void test_vulkan_choose_qparams_tensor_impl(
316421
const std::vector<int>& input_sizes,
@@ -320,6 +425,12 @@ void test_vulkan_choose_qparams_tensor_impl(
320425
const vkcompute::utils::StorageType in_storage,
321426
const vkcompute::utils::StorageType out_storage);
322427

428+
void test_vulkan_choose_qparams_per_token_asymmetric_impl(
429+
const std::vector<int>& input_sizes,
430+
at::ScalarType dtype,
431+
const vkcompute::utils::StorageType in_storage,
432+
const vkcompute::utils::StorageType out_storage);
433+
323434
// Wrapper function to test both buffer and texture storage types
324435
void test_vulkan_choose_qparams_tensor(
325436
const std::vector<int>& input_sizes,
@@ -345,6 +456,22 @@ void test_vulkan_choose_qparams_tensor(
345456
vkcompute::utils::kTexture3D);
346457
}
347458

459+
// Wrapper function to test both buffer and texture storage types
460+
void test_vulkan_choose_qparams_per_token_asymmetric(
461+
const std::vector<int>& input_sizes,
462+
at::ScalarType dtype) {
463+
// Test with buffer storage
464+
test_vulkan_choose_qparams_per_token_asymmetric_impl(
465+
input_sizes, dtype, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer);
466+
467+
// Test with texture storage
468+
test_vulkan_choose_qparams_per_token_asymmetric_impl(
469+
input_sizes,
470+
dtype,
471+
vkcompute::utils::kTexture3D,
472+
vkcompute::utils::kTexture3D);
473+
}
474+
348475
void test_reference_choose_qparams_tensor(
349476
const std::vector<int>& input_sizes,
350477
int64_t quant_min,
@@ -506,3 +633,161 @@ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
506633
127, // quant_max
507634
at::kChar);
508635
}
636+
637+
void test_reference_choose_qparams_per_token_asymmetric(
638+
const std::vector<int>& input_sizes,
639+
at::ScalarType dtype) {
640+
std::vector<int64_t> input_sizes_int64(
641+
input_sizes.begin(), input_sizes.end());
642+
at::Tensor input =
643+
at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));
644+
645+
// Get reference output
646+
auto [reference_scale, reference_zero_point] =
647+
choose_qparams_per_token_asymmetric_reference_impl(input, dtype);
648+
649+
// Get implementation output
650+
auto [impl_scale, impl_zero_point] =
651+
torch::executor::native::choose_qparams_per_token_asymmetric_aten(
652+
input, dtype);
653+
654+
// Compare outputs
655+
const bool scale_correct = at::allclose(reference_scale, impl_scale);
656+
const bool zero_point_correct =
657+
at::equal(reference_zero_point, impl_zero_point);
658+
659+
if (!scale_correct || !zero_point_correct) {
660+
std::cout << "\n"
661+
<< "Failed with parameters: " << std::endl;
662+
663+
std::cout << "input:" << std::endl;
664+
std::cout << input << std::endl;
665+
std::cout << "reference scale:" << std::endl;
666+
std::cout << reference_scale << std::endl;
667+
std::cout << "implementation scale:" << std::endl;
668+
std::cout << impl_scale << std::endl;
669+
std::cout << "reference zero_point:" << std::endl;
670+
std::cout << reference_zero_point << std::endl;
671+
std::cout << "implementation zero_point:" << std::endl;
672+
std::cout << impl_zero_point << std::endl;
673+
}
674+
675+
ASSERT_TRUE(scale_correct && zero_point_correct);
676+
}
677+
678+
void test_vulkan_choose_qparams_per_token_asymmetric_impl(
679+
const std::vector<int>& input_sizes,
680+
at::ScalarType dtype,
681+
const vkcompute::utils::StorageType in_storage,
682+
const vkcompute::utils::StorageType out_storage) {
683+
std::vector<int64_t> input_sizes_int64(
684+
input_sizes.begin(), input_sizes.end());
685+
at::Tensor input =
686+
at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));
687+
688+
// Calculate output sizes
689+
std::vector<int64_t> output_sizes;
690+
for (int64_t i = 0; i < input.dim() - 1; i++) {
691+
output_sizes.push_back(input.size(i));
692+
}
693+
output_sizes.push_back(1);
694+
695+
// Get reference output
696+
auto [reference_scale, reference_zero_point] =
697+
torch::executor::native::choose_qparams_per_token_asymmetric_aten(
698+
input, dtype);
699+
700+
// Build Vulkan choose_qparams_per_token_asymmetric graph
701+
using namespace vkcompute;
702+
703+
GraphConfig config;
704+
config.set_storage_type_override(in_storage);
705+
ComputeGraph graph(config);
706+
707+
IOValueRef r_input = graph.add_input_tensor(
708+
input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage);
709+
710+
// Output tensors
711+
const ValueRef r_scale =
712+
graph.add_tensor(output_sizes, vkapi::kFloat, out_storage);
713+
const ValueRef r_zero_point =
714+
graph.add_tensor(output_sizes, vkapi::kInt, out_storage);
715+
716+
VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default")
717+
(graph,
718+
{
719+
r_input.value,
720+
r_scale,
721+
r_zero_point,
722+
});
723+
724+
ValueRef staging_scale = graph.set_output_tensor(r_scale);
725+
ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point);
726+
727+
graph.prepare();
728+
graph.encode_prepack();
729+
graph.prepack();
730+
graph.encode_execute();
731+
732+
// Run Vulkan choose_qparams_per_token_asymmetric
733+
graph.copy_into_staging(
734+
r_input.staging, input.const_data_ptr(), input.numel());
735+
736+
graph.execute();
737+
738+
// Create output tensors to hold the results - use types that match GPU output
739+
at::Tensor vk_scale =
740+
at::empty(output_sizes, at::device(at::kCPU).dtype(at::kFloat))
741+
.contiguous();
742+
at::Tensor vk_zero_point =
743+
at::empty(output_sizes, at::device(at::kCPU).dtype(at::kInt))
744+
.contiguous();
745+
746+
// Copy results from GPU to CPU
747+
graph.copy_from_staging(
748+
staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel());
749+
graph.copy_from_staging(
750+
staging_zero_point,
751+
vk_zero_point.mutable_data_ptr(),
752+
vk_zero_point.numel());
753+
754+
// Convert reference values to match Vulkan output types for comparison
755+
at::Tensor reference_scale_float = reference_scale.to(at::kFloat);
756+
at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt);
757+
758+
// Compare outputs
759+
const bool scale_correct = at::allclose(reference_scale_float, vk_scale);
760+
const bool zero_point_correct =
761+
at::equal(reference_zero_point_int, vk_zero_point);
762+
if (!scale_correct || !zero_point_correct) {
763+
std::cout << "\n"
764+
<< "Failed with parameters: " << std::endl;
765+
std::cout << " storage type: "
766+
<< (in_storage == vkcompute::utils::kBuffer ? "buffer"
767+
: "texture")
768+
<< std::endl;
769+
770+
if (input.numel() < 100) {
771+
std::cout << "input:" << std::endl;
772+
std::cout << input << "\n" << std::endl;
773+
std::cout << "reference scale:" << std::endl;
774+
std::cout << reference_scale << std::endl;
775+
std::cout << "vulkan scale:" << std::endl;
776+
std::cout << vk_scale << "\n" << std::endl;
777+
std::cout << "reference zero_point:" << std::endl;
778+
std::cout << reference_zero_point << std::endl;
779+
std::cout << "vulkan zero_point:" << std::endl;
780+
std::cout << vk_zero_point << std::endl;
781+
}
782+
}
783+
784+
ASSERT_TRUE(scale_correct && zero_point_correct);
785+
}
786+
787+
// Reference-vs-ATen check for a [2, 3, 4] input, i.e. 2 * 3 = 6 tokens of 4
// elements each, with at::kChar as the requested dtype.
TEST(
    VulkanChooseQparamsTest,
    test_reference_choose_qparams_per_token_asymmetric_int8) {
  const std::vector<int> input_sizes = {2, 3, 4};
  test_reference_choose_qparams_per_token_asymmetric(input_sizes, at::kChar);
}

0 commit comments

Comments
 (0)