@@ -193,6 +193,111 @@ std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl(
193
193
return std::make_tuple (scale_out, zero_point_out);
194
194
}
195
195
196
+ /*
197
+ * Reference implementation of choose_qparams_per_token_asymmetric
198
+ */
199
+ std::tuple<at::Tensor, at::Tensor>
200
+ choose_qparams_per_token_asymmetric_reference_impl (
201
+ const at::Tensor& input,
202
+ at::ScalarType dtype) {
203
+ // For per-token quantization, we need to compute scale and zero_point for
204
+ // each token
205
+ int64_t quant_min = -128 ;
206
+ int64_t quant_max = 127 ;
207
+
208
+ // Calculate output sizes
209
+ std::vector<int64_t > output_sizes;
210
+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
211
+ output_sizes.push_back (input.size (i));
212
+ }
213
+ output_sizes.push_back (1 );
214
+
215
+ // Create output tensors
216
+ at::Tensor scale_out =
217
+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kDouble ));
218
+ at::Tensor zero_point_out =
219
+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kLong ));
220
+
221
+ // Calculate number of tokens
222
+ int64_t num_tokens = 1 ;
223
+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
224
+ num_tokens *= input.size (i);
225
+ }
226
+
227
+ // Reshape input to [num_tokens, last_dim]
228
+ at::Tensor reshaped_input = input.reshape ({num_tokens, input.size (-1 )});
229
+
230
+ // Process each token
231
+ for (int64_t token_idx = 0 ; token_idx < num_tokens; token_idx++) {
232
+ at::Tensor token = reshaped_input[token_idx];
233
+
234
+ // Find min and max values for this token
235
+ float min_val = token.min ().item <float >();
236
+ float max_val = token.max ().item <float >();
237
+
238
+ // Extend the [min, max] interval to ensure it contains 0
239
+ min_val = std::min (min_val, 0 .f );
240
+ max_val = std::max (max_val, 0 .f );
241
+
242
+ // Calculate scale
243
+ double scale =
244
+ (static_cast <double >(max_val) - min_val) / (quant_max - quant_min);
245
+
246
+ // Handle small scale
247
+ constexpr float SMALL_SCALE_THRESHOLD = 6 .1e-5f ;
248
+ if (float (scale) == 0 .0f || std::isinf (1 .0f / float (scale))) {
249
+ scale = 0.1 ;
250
+ }
251
+
252
+ if (scale < SMALL_SCALE_THRESHOLD) {
253
+ float org_scale = scale;
254
+ scale = SMALL_SCALE_THRESHOLD;
255
+ // Adjust min and max based on new scale
256
+ if (min_val == 0 .0f ) {
257
+ max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
258
+ } else if (max_val == 0 .0f ) {
259
+ min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
260
+ } else {
261
+ float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
262
+ min_val *= amplifier;
263
+ max_val *= amplifier;
264
+ }
265
+ }
266
+
267
+ // Calculate zero point
268
+ double zero_point_from_min =
269
+ quant_min - min_val / static_cast <double >(scale);
270
+ double zero_point_from_max =
271
+ quant_max - max_val / static_cast <double >(scale);
272
+ double zero_point_from_min_error =
273
+ std::abs (quant_min) - std::abs (min_val / static_cast <double >(scale));
274
+ double zero_point_from_max_error =
275
+ std::abs (quant_max) - std::abs (max_val / static_cast <double >(scale));
276
+ double initial_zero_point =
277
+ zero_point_from_min_error < zero_point_from_max_error
278
+ ? zero_point_from_min
279
+ : zero_point_from_max;
280
+
281
+ // Nudge zero point to be an integer
282
+ int64_t nudged_zero_point = 0 ;
283
+ if (initial_zero_point < quant_min) {
284
+ nudged_zero_point = quant_min;
285
+ } else if (initial_zero_point > quant_max) {
286
+ nudged_zero_point = quant_max;
287
+ } else {
288
+ nudged_zero_point =
289
+ std::nearbyint (static_cast <float >(initial_zero_point));
290
+ }
291
+
292
+ // Set output values for this token - use index_put_ for safety
293
+ scale_out.view ({num_tokens, 1 }).index_put_ ({token_idx, 0 }, scale);
294
+ zero_point_out.view ({num_tokens, 1 })
295
+ .index_put_ ({token_idx, 0 }, nudged_zero_point);
296
+ }
297
+
298
+ return std::make_tuple (scale_out, zero_point_out);
299
+ }
300
+
196
301
// Forward declaration of implementation functions
197
302
void test_vulkan_choose_qparams_tensor_impl (
198
303
const std::vector<int >& input_sizes,
@@ -202,6 +307,12 @@ void test_vulkan_choose_qparams_tensor_impl(
202
307
const vkcompute::utils::StorageType in_storage,
203
308
const vkcompute::utils::StorageType out_storage);
204
309
310
+ void test_vulkan_choose_qparams_per_token_asymmetric_impl (
311
+ const std::vector<int >& input_sizes,
312
+ at::ScalarType dtype,
313
+ const vkcompute::utils::StorageType in_storage,
314
+ const vkcompute::utils::StorageType out_storage);
315
+
205
316
// Wrapper function to test both buffer and texture storage types
206
317
void test_vulkan_choose_qparams_tensor (
207
318
const std::vector<int >& input_sizes,
@@ -227,6 +338,22 @@ void test_vulkan_choose_qparams_tensor(
227
338
vkcompute::utils::kTexture3D );
228
339
}
229
340
341
+ // Wrapper function to test both buffer and texture storage types
342
+ void test_vulkan_choose_qparams_per_token_asymmetric (
343
+ const std::vector<int >& input_sizes,
344
+ at::ScalarType dtype) {
345
+ // Test with buffer storage
346
+ test_vulkan_choose_qparams_per_token_asymmetric_impl (
347
+ input_sizes, dtype, vkcompute::utils::kBuffer , vkcompute::utils::kBuffer );
348
+
349
+ // Test with texture storage
350
+ test_vulkan_choose_qparams_per_token_asymmetric_impl (
351
+ input_sizes,
352
+ dtype,
353
+ vkcompute::utils::kTexture3D ,
354
+ vkcompute::utils::kTexture3D );
355
+ }
356
+
230
357
void test_reference_choose_qparams_tensor (
231
358
const std::vector<int >& input_sizes,
232
359
int64_t quant_min,
@@ -388,3 +515,161 @@ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
388
515
127 , // quant_max
389
516
at::kChar );
390
517
}
518
+
519
+ void test_reference_choose_qparams_per_token_asymmetric (
520
+ const std::vector<int >& input_sizes,
521
+ at::ScalarType dtype) {
522
+ std::vector<int64_t > input_sizes_int64 (
523
+ input_sizes.begin (), input_sizes.end ());
524
+ at::Tensor input =
525
+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
526
+
527
+ // Get reference output
528
+ auto [reference_scale, reference_zero_point] =
529
+ choose_qparams_per_token_asymmetric_reference_impl (input, dtype);
530
+
531
+ // Get implementation output
532
+ auto [impl_scale, impl_zero_point] =
533
+ torch::executor::native::choose_qparams_per_token_asymmetric_aten (
534
+ input, dtype);
535
+
536
+ // Compare outputs
537
+ const bool scale_correct = at::allclose (reference_scale, impl_scale);
538
+ const bool zero_point_correct =
539
+ at::equal (reference_zero_point, impl_zero_point);
540
+
541
+ if (!scale_correct || !zero_point_correct) {
542
+ std::cout << " \n "
543
+ << " Failed with parameters: " << std::endl;
544
+
545
+ std::cout << " input:" << std::endl;
546
+ std::cout << input << std::endl;
547
+ std::cout << " reference scale:" << std::endl;
548
+ std::cout << reference_scale << std::endl;
549
+ std::cout << " implementation scale:" << std::endl;
550
+ std::cout << impl_scale << std::endl;
551
+ std::cout << " reference zero_point:" << std::endl;
552
+ std::cout << reference_zero_point << std::endl;
553
+ std::cout << " implementation zero_point:" << std::endl;
554
+ std::cout << impl_zero_point << std::endl;
555
+ }
556
+
557
+ ASSERT_TRUE (scale_correct && zero_point_correct);
558
+ }
559
+
560
+ void test_vulkan_choose_qparams_per_token_asymmetric_impl (
561
+ const std::vector<int >& input_sizes,
562
+ at::ScalarType dtype,
563
+ const vkcompute::utils::StorageType in_storage,
564
+ const vkcompute::utils::StorageType out_storage) {
565
+ std::vector<int64_t > input_sizes_int64 (
566
+ input_sizes.begin (), input_sizes.end ());
567
+ at::Tensor input =
568
+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
569
+
570
+ // Calculate output sizes
571
+ std::vector<int64_t > output_sizes;
572
+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
573
+ output_sizes.push_back (input.size (i));
574
+ }
575
+ output_sizes.push_back (1 );
576
+
577
+ // Get reference output
578
+ auto [reference_scale, reference_zero_point] =
579
+ torch::executor::native::choose_qparams_per_token_asymmetric_aten (
580
+ input, dtype);
581
+
582
+ // Build Vulkan choose_qparams_per_token_asymmetric graph
583
+ using namespace vkcompute ;
584
+
585
+ GraphConfig config;
586
+ config.set_storage_type_override (in_storage);
587
+ ComputeGraph graph (config);
588
+
589
+ IOValueRef r_input = graph.add_input_tensor (
590
+ input.sizes ().vec (), from_at_scalartype (input.scalar_type ()), in_storage);
591
+
592
+ // Output tensors
593
+ const ValueRef r_scale =
594
+ graph.add_tensor (output_sizes, vkapi::kFloat , out_storage);
595
+ const ValueRef r_zero_point =
596
+ graph.add_tensor (output_sizes, vkapi::kInt , out_storage);
597
+
598
+ VK_GET_OP_FN (" choose_qparams_per_token_asymmetric.default" )
599
+ (graph,
600
+ {
601
+ r_input.value ,
602
+ r_scale,
603
+ r_zero_point,
604
+ });
605
+
606
+ ValueRef staging_scale = graph.set_output_tensor (r_scale);
607
+ ValueRef staging_zero_point = graph.set_output_tensor (r_zero_point);
608
+
609
+ graph.prepare ();
610
+ graph.encode_prepack ();
611
+ graph.prepack ();
612
+ graph.encode_execute ();
613
+
614
+ // Run Vulkan choose_qparams_per_token_asymmetric
615
+ graph.copy_into_staging (
616
+ r_input.staging , input.const_data_ptr (), input.numel ());
617
+
618
+ graph.execute ();
619
+
620
+ // Create output tensors to hold the results - use types that match GPU output
621
+ at::Tensor vk_scale =
622
+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kFloat ))
623
+ .contiguous ();
624
+ at::Tensor vk_zero_point =
625
+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kInt ))
626
+ .contiguous ();
627
+
628
+ // Copy results from GPU to CPU
629
+ graph.copy_from_staging (
630
+ staging_scale, vk_scale.mutable_data_ptr (), vk_scale.numel ());
631
+ graph.copy_from_staging (
632
+ staging_zero_point,
633
+ vk_zero_point.mutable_data_ptr (),
634
+ vk_zero_point.numel ());
635
+
636
+ // Convert reference values to match Vulkan output types for comparison
637
+ at::Tensor reference_scale_float = reference_scale.to (at::kFloat );
638
+ at::Tensor reference_zero_point_int = reference_zero_point.to (at::kInt );
639
+
640
+ // Compare outputs
641
+ const bool scale_correct = at::allclose (reference_scale_float, vk_scale);
642
+ const bool zero_point_correct =
643
+ at::equal (reference_zero_point_int, vk_zero_point);
644
+ if (!scale_correct || !zero_point_correct) {
645
+ std::cout << " \n "
646
+ << " Failed with parameters: " << std::endl;
647
+ std::cout << " storage type: "
648
+ << (in_storage == vkcompute::utils::kBuffer ? " buffer"
649
+ : " texture" )
650
+ << std::endl;
651
+
652
+ if (input.numel () < 100 ) {
653
+ std::cout << " input:" << std::endl;
654
+ std::cout << input << " \n " << std::endl;
655
+ std::cout << " reference scale:" << std::endl;
656
+ std::cout << reference_scale << std::endl;
657
+ std::cout << " vulkan scale:" << std::endl;
658
+ std::cout << vk_scale << " \n " << std::endl;
659
+ std::cout << " reference zero_point:" << std::endl;
660
+ std::cout << reference_zero_point << std::endl;
661
+ std::cout << " vulkan zero_point:" << std::endl;
662
+ std::cout << vk_zero_point << std::endl;
663
+ }
664
+ }
665
+
666
+ ASSERT_TRUE (scale_correct && zero_point_correct);
667
+ }
668
+
669
+ TEST (
670
+ VulkanChooseQparamsTest,
671
+ test_reference_choose_qparams_per_token_asymmetric_int8) {
672
+ test_reference_choose_qparams_per_token_asymmetric (
673
+ {2 , 3 , 4 }, // input sizes (2*3=6 tokens)
674
+ at::kChar );
675
+ }
0 commit comments