@@ -311,6 +311,111 @@ std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl(
311
311
return std::make_tuple (scale_out, zero_point_out);
312
312
}
313
313
314
+ /*
315
+ * Reference implementation of choose_qparams_per_token_asymmetric
316
+ */
317
+ std::tuple<at::Tensor, at::Tensor>
318
+ choose_qparams_per_token_asymmetric_reference_impl (
319
+ const at::Tensor& input,
320
+ at::ScalarType dtype) {
321
+ // For per-token quantization, we need to compute scale and zero_point for
322
+ // each token
323
+ int64_t quant_min = -128 ;
324
+ int64_t quant_max = 127 ;
325
+
326
+ // Calculate output sizes
327
+ std::vector<int64_t > output_sizes;
328
+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
329
+ output_sizes.push_back (input.size (i));
330
+ }
331
+ output_sizes.push_back (1 );
332
+
333
+ // Create output tensors
334
+ at::Tensor scale_out =
335
+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kDouble ));
336
+ at::Tensor zero_point_out =
337
+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kLong ));
338
+
339
+ // Calculate number of tokens
340
+ int64_t num_tokens = 1 ;
341
+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
342
+ num_tokens *= input.size (i);
343
+ }
344
+
345
+ // Reshape input to [num_tokens, last_dim]
346
+ at::Tensor reshaped_input = input.reshape ({num_tokens, input.size (-1 )});
347
+
348
+ // Process each token
349
+ for (int64_t token_idx = 0 ; token_idx < num_tokens; token_idx++) {
350
+ at::Tensor token = reshaped_input[token_idx];
351
+
352
+ // Find min and max values for this token
353
+ float min_val = token.min ().item <float >();
354
+ float max_val = token.max ().item <float >();
355
+
356
+ // Extend the [min, max] interval to ensure it contains 0
357
+ min_val = std::min (min_val, 0 .f );
358
+ max_val = std::max (max_val, 0 .f );
359
+
360
+ // Calculate scale
361
+ double scale =
362
+ (static_cast <double >(max_val) - min_val) / (quant_max - quant_min);
363
+
364
+ // Handle small scale
365
+ constexpr float SMALL_SCALE_THRESHOLD = 6 .1e-5f ;
366
+ if (float (scale) == 0 .0f || std::isinf (1 .0f / float (scale))) {
367
+ scale = 0.1 ;
368
+ }
369
+
370
+ if (scale < SMALL_SCALE_THRESHOLD) {
371
+ float org_scale = scale;
372
+ scale = SMALL_SCALE_THRESHOLD;
373
+ // Adjust min and max based on new scale
374
+ if (min_val == 0 .0f ) {
375
+ max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
376
+ } else if (max_val == 0 .0f ) {
377
+ min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
378
+ } else {
379
+ float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
380
+ min_val *= amplifier;
381
+ max_val *= amplifier;
382
+ }
383
+ }
384
+
385
+ // Calculate zero point
386
+ double zero_point_from_min =
387
+ quant_min - min_val / static_cast <double >(scale);
388
+ double zero_point_from_max =
389
+ quant_max - max_val / static_cast <double >(scale);
390
+ double zero_point_from_min_error =
391
+ std::abs (quant_min) - std::abs (min_val / static_cast <double >(scale));
392
+ double zero_point_from_max_error =
393
+ std::abs (quant_max) - std::abs (max_val / static_cast <double >(scale));
394
+ double initial_zero_point =
395
+ zero_point_from_min_error < zero_point_from_max_error
396
+ ? zero_point_from_min
397
+ : zero_point_from_max;
398
+
399
+ // Nudge zero point to be an integer
400
+ int64_t nudged_zero_point = 0 ;
401
+ if (initial_zero_point < quant_min) {
402
+ nudged_zero_point = quant_min;
403
+ } else if (initial_zero_point > quant_max) {
404
+ nudged_zero_point = quant_max;
405
+ } else {
406
+ nudged_zero_point =
407
+ std::nearbyint (static_cast <float >(initial_zero_point));
408
+ }
409
+
410
+ // Set output values for this token - use index_put_ for safety
411
+ scale_out.view ({num_tokens, 1 }).index_put_ ({token_idx, 0 }, scale);
412
+ zero_point_out.view ({num_tokens, 1 })
413
+ .index_put_ ({token_idx, 0 }, nudged_zero_point);
414
+ }
415
+
416
+ return std::make_tuple (scale_out, zero_point_out);
417
+ }
418
+
314
419
// Forward declaration of implementation functions
315
420
void test_vulkan_choose_qparams_tensor_impl (
316
421
const std::vector<int >& input_sizes,
@@ -320,6 +425,12 @@ void test_vulkan_choose_qparams_tensor_impl(
320
425
const vkcompute::utils::StorageType in_storage,
321
426
const vkcompute::utils::StorageType out_storage);
322
427
428
+ void test_vulkan_choose_qparams_per_token_asymmetric_impl (
429
+ const std::vector<int >& input_sizes,
430
+ at::ScalarType dtype,
431
+ const vkcompute::utils::StorageType in_storage,
432
+ const vkcompute::utils::StorageType out_storage);
433
+
323
434
// Wrapper function to test both buffer and texture storage types
324
435
void test_vulkan_choose_qparams_tensor (
325
436
const std::vector<int >& input_sizes,
@@ -345,6 +456,22 @@ void test_vulkan_choose_qparams_tensor(
345
456
vkcompute::utils::kTexture3D );
346
457
}
347
458
459
+ // Wrapper function to test both buffer and texture storage types
460
+ void test_vulkan_choose_qparams_per_token_asymmetric (
461
+ const std::vector<int >& input_sizes,
462
+ at::ScalarType dtype) {
463
+ // Test with buffer storage
464
+ test_vulkan_choose_qparams_per_token_asymmetric_impl (
465
+ input_sizes, dtype, vkcompute::utils::kBuffer , vkcompute::utils::kBuffer );
466
+
467
+ // Test with texture storage
468
+ test_vulkan_choose_qparams_per_token_asymmetric_impl (
469
+ input_sizes,
470
+ dtype,
471
+ vkcompute::utils::kTexture3D ,
472
+ vkcompute::utils::kTexture3D );
473
+ }
474
+
348
475
void test_reference_choose_qparams_tensor (
349
476
const std::vector<int >& input_sizes,
350
477
int64_t quant_min,
@@ -506,3 +633,161 @@ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
506
633
127 , // quant_max
507
634
at::kChar );
508
635
}
636
+
637
+ void test_reference_choose_qparams_per_token_asymmetric (
638
+ const std::vector<int >& input_sizes,
639
+ at::ScalarType dtype) {
640
+ std::vector<int64_t > input_sizes_int64 (
641
+ input_sizes.begin (), input_sizes.end ());
642
+ at::Tensor input =
643
+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
644
+
645
+ // Get reference output
646
+ auto [reference_scale, reference_zero_point] =
647
+ choose_qparams_per_token_asymmetric_reference_impl (input, dtype);
648
+
649
+ // Get implementation output
650
+ auto [impl_scale, impl_zero_point] =
651
+ torch::executor::native::choose_qparams_per_token_asymmetric_aten (
652
+ input, dtype);
653
+
654
+ // Compare outputs
655
+ const bool scale_correct = at::allclose (reference_scale, impl_scale);
656
+ const bool zero_point_correct =
657
+ at::equal (reference_zero_point, impl_zero_point);
658
+
659
+ if (!scale_correct || !zero_point_correct) {
660
+ std::cout << " \n "
661
+ << " Failed with parameters: " << std::endl;
662
+
663
+ std::cout << " input:" << std::endl;
664
+ std::cout << input << std::endl;
665
+ std::cout << " reference scale:" << std::endl;
666
+ std::cout << reference_scale << std::endl;
667
+ std::cout << " implementation scale:" << std::endl;
668
+ std::cout << impl_scale << std::endl;
669
+ std::cout << " reference zero_point:" << std::endl;
670
+ std::cout << reference_zero_point << std::endl;
671
+ std::cout << " implementation zero_point:" << std::endl;
672
+ std::cout << impl_zero_point << std::endl;
673
+ }
674
+
675
+ ASSERT_TRUE (scale_correct && zero_point_correct);
676
+ }
677
+
678
+ void test_vulkan_choose_qparams_per_token_asymmetric_impl (
679
+ const std::vector<int >& input_sizes,
680
+ at::ScalarType dtype,
681
+ const vkcompute::utils::StorageType in_storage,
682
+ const vkcompute::utils::StorageType out_storage) {
683
+ std::vector<int64_t > input_sizes_int64 (
684
+ input_sizes.begin (), input_sizes.end ());
685
+ at::Tensor input =
686
+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
687
+
688
+ // Calculate output sizes
689
+ std::vector<int64_t > output_sizes;
690
+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
691
+ output_sizes.push_back (input.size (i));
692
+ }
693
+ output_sizes.push_back (1 );
694
+
695
+ // Get reference output
696
+ auto [reference_scale, reference_zero_point] =
697
+ torch::executor::native::choose_qparams_per_token_asymmetric_aten (
698
+ input, dtype);
699
+
700
+ // Build Vulkan choose_qparams_per_token_asymmetric graph
701
+ using namespace vkcompute ;
702
+
703
+ GraphConfig config;
704
+ config.set_storage_type_override (in_storage);
705
+ ComputeGraph graph (config);
706
+
707
+ IOValueRef r_input = graph.add_input_tensor (
708
+ input.sizes ().vec (), from_at_scalartype (input.scalar_type ()), in_storage);
709
+
710
+ // Output tensors
711
+ const ValueRef r_scale =
712
+ graph.add_tensor (output_sizes, vkapi::kFloat , out_storage);
713
+ const ValueRef r_zero_point =
714
+ graph.add_tensor (output_sizes, vkapi::kInt , out_storage);
715
+
716
+ VK_GET_OP_FN (" choose_qparams_per_token_asymmetric.default" )
717
+ (graph,
718
+ {
719
+ r_input.value ,
720
+ r_scale,
721
+ r_zero_point,
722
+ });
723
+
724
+ ValueRef staging_scale = graph.set_output_tensor (r_scale);
725
+ ValueRef staging_zero_point = graph.set_output_tensor (r_zero_point);
726
+
727
+ graph.prepare ();
728
+ graph.encode_prepack ();
729
+ graph.prepack ();
730
+ graph.encode_execute ();
731
+
732
+ // Run Vulkan choose_qparams_per_token_asymmetric
733
+ graph.copy_into_staging (
734
+ r_input.staging , input.const_data_ptr (), input.numel ());
735
+
736
+ graph.execute ();
737
+
738
+ // Create output tensors to hold the results - use types that match GPU output
739
+ at::Tensor vk_scale =
740
+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kFloat ))
741
+ .contiguous ();
742
+ at::Tensor vk_zero_point =
743
+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kInt ))
744
+ .contiguous ();
745
+
746
+ // Copy results from GPU to CPU
747
+ graph.copy_from_staging (
748
+ staging_scale, vk_scale.mutable_data_ptr (), vk_scale.numel ());
749
+ graph.copy_from_staging (
750
+ staging_zero_point,
751
+ vk_zero_point.mutable_data_ptr (),
752
+ vk_zero_point.numel ());
753
+
754
+ // Convert reference values to match Vulkan output types for comparison
755
+ at::Tensor reference_scale_float = reference_scale.to (at::kFloat );
756
+ at::Tensor reference_zero_point_int = reference_zero_point.to (at::kInt );
757
+
758
+ // Compare outputs
759
+ const bool scale_correct = at::allclose (reference_scale_float, vk_scale);
760
+ const bool zero_point_correct =
761
+ at::equal (reference_zero_point_int, vk_zero_point);
762
+ if (!scale_correct || !zero_point_correct) {
763
+ std::cout << " \n "
764
+ << " Failed with parameters: " << std::endl;
765
+ std::cout << " storage type: "
766
+ << (in_storage == vkcompute::utils::kBuffer ? " buffer"
767
+ : " texture" )
768
+ << std::endl;
769
+
770
+ if (input.numel () < 100 ) {
771
+ std::cout << " input:" << std::endl;
772
+ std::cout << input << " \n " << std::endl;
773
+ std::cout << " reference scale:" << std::endl;
774
+ std::cout << reference_scale << std::endl;
775
+ std::cout << " vulkan scale:" << std::endl;
776
+ std::cout << vk_scale << " \n " << std::endl;
777
+ std::cout << " reference zero_point:" << std::endl;
778
+ std::cout << reference_zero_point << std::endl;
779
+ std::cout << " vulkan zero_point:" << std::endl;
780
+ std::cout << vk_zero_point << std::endl;
781
+ }
782
+ }
783
+
784
+ ASSERT_TRUE (scale_correct && zero_point_correct);
785
+ }
786
+
787
+ TEST (
788
+ VulkanChooseQparamsTest,
789
+ test_reference_choose_qparams_per_token_asymmetric_int8) {
790
+ test_reference_choose_qparams_per_token_asymmetric (
791
+ {2 , 3 , 4 }, // input sizes (2*3=6 tokens)
792
+ at::kChar );
793
+ }
0 commit comments