@@ -232,3 +232,277 @@ vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
          ")");
  }
}
+
+ //
+ // Reference Implementation
+ //
+
+ /*
+  * Reference implementation of choose_qparams_tensor
+  */
+ std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl(
+     const at::Tensor& input,
+     int64_t quant_min,
+     int64_t quant_max) {
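+   // Mirrors the CPU reference algorithm: find the input's [min, max], extend
+   // it to contain 0, derive an affine scale and zero point from that range,
+   // then clamp and round the zero point into [quant_min, quant_max].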
+   // Create output tensors
+   at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble));
+   at::Tensor zero_point_out =
+       at::empty({}, at::device(at::kCPU).dtype(at::kLong));
+
+   // Find min and max values in the input tensor
+   float min_val = input.min().item<float>();
+   float max_val = input.max().item<float>();
+
+   // Extend the [min, max] interval to ensure it contains 0
+   min_val = std::min(min_val, 0.f);
+   max_val = std::max(max_val, 0.f);
+
+   // Calculate scale
+   double scale =
+       (static_cast<double>(max_val) - min_val) / (quant_max - quant_min);
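+   // One integer step covers (max_val - min_val) / (quant_max - quant_min)
+   // real units, so the full float range maps onto the full quantized range.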
+
+   // Handle small scale
+   constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;
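+   // If the scale degenerates (zero, or so small that 1/scale overflows), fall
+   // back to a usable value; scales below the threshold are clamped up and the
+   // float range is widened to match, keeping the quantization grid stable.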
+   if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
+     scale = 0.1;
+   }
+
+   if (scale < SMALL_SCALE_THRESHOLD) {
+     float org_scale = scale;
+     scale = SMALL_SCALE_THRESHOLD;
+     // Adjust min and max based on new scale
+     if (min_val == 0.0f) {
+       max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
+     } else if (max_val == 0.0f) {
+       min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
+     } else {
+       float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
+       min_val *= amplifier;
+       max_val *= amplifier;
+     }
+   }
+
+   // Calculate zero point
+   double zero_point_from_min = quant_min - min_val / static_cast<double>(scale);
+   double zero_point_from_max = quant_max - max_val / static_cast<double>(scale);
+   double zero_point_from_min_error =
+       std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale));
+   double zero_point_from_max_error =
+       std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale));
+   double initial_zero_point =
+       zero_point_from_min_error < zero_point_from_max_error
+       ? zero_point_from_min
+       : zero_point_from_max;
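+   // A zero point candidate is derived from each end of the range; the one
+   // with the smaller associated error is kept.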
+
+   // Nudge zero point to be an integer
+   int64_t nudged_zero_point = 0;
+   if (initial_zero_point < quant_min) {
+     nudged_zero_point = quant_min;
+   } else if (initial_zero_point > quant_max) {
+     nudged_zero_point = quant_max;
+   } else {
+     nudged_zero_point = std::nearbyint(static_cast<float>(initial_zero_point));
+   }
+
+   // Set output values on the scalar output tensors
+   scale_out.fill_(scale);
+   zero_point_out.fill_(nudged_zero_point);
+
+   return std::make_tuple(scale_out, zero_point_out);
+ }
+
+ // Forward declaration of implementation functions
+ void test_vulkan_choose_qparams_tensor_impl(
+     const std::vector<int>& input_sizes,
+     int64_t quant_min,
+     int64_t quant_max,
+     at::ScalarType dtype,
+     const vkcompute::utils::StorageType in_storage,
+     const vkcompute::utils::StorageType out_storage);
+
+ // Wrapper function to test both buffer and texture storage types
+ void test_vulkan_choose_qparams_tensor(
+     const std::vector<int>& input_sizes,
+     int64_t quant_min,
+     int64_t quant_max,
+     at::ScalarType dtype) {
+   // Test with buffer storage
+   test_vulkan_choose_qparams_tensor_impl(
+       input_sizes,
+       quant_min,
+       quant_max,
+       dtype,
+       vkcompute::utils::kBuffer,
+       vkcompute::utils::kBuffer);
+
+   // Test with texture storage
+   test_vulkan_choose_qparams_tensor_impl(
+       input_sizes,
+       quant_min,
+       quant_max,
+       dtype,
+       vkcompute::utils::kTexture3D,
+       vkcompute::utils::kTexture3D);
+ }
+
+ void test_reference_choose_qparams_tensor(
+     const std::vector<int>& input_sizes,
+     int64_t quant_min,
+     int64_t quant_max,
+     at::ScalarType dtype) {
+   std::vector<int64_t> input_sizes_int64(
+       input_sizes.begin(), input_sizes.end());
+   at::Tensor input =
+       at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));
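+   // at::rand samples uniformly from [0, 1), so the reference clamps min_val
+   // to 0 and the derived zero point sits at quant_min for this input.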
+
+   // Get reference output
+   auto [reference_scale, reference_zero_point] =
+       choose_qparams_tensor_reference_impl(input, quant_min, quant_max);
+
+   // Get implementation output
+   auto [impl_scale, impl_zero_point] =
+       torch::executor::native::choose_qparams_tensor_aten(
+           input, quant_min, quant_max, dtype);
+
+   // Compare outputs
+   const bool scale_correct = at::allclose(reference_scale, impl_scale);
+   const bool zero_point_correct =
+       at::equal(reference_zero_point, impl_zero_point);
+
+   if (!scale_correct || !zero_point_correct) {
+     std::cout << "\n"
+               << " Failed with parameters: " << std::endl;
+     std::cout << " quant_min: " << quant_min << std::endl;
+     std::cout << " quant_max: " << quant_max << std::endl;
+
+     std::cout << " input:" << std::endl;
+     std::cout << input << std::endl;
+     std::cout << " reference scale:" << std::endl;
+     std::cout << reference_scale << std::endl;
+     std::cout << " implementation scale:" << std::endl;
+     std::cout << impl_scale << std::endl;
+     std::cout << " reference zero_point:" << std::endl;
+     std::cout << reference_zero_point << std::endl;
+     std::cout << " implementation zero_point:" << std::endl;
+     std::cout << impl_zero_point << std::endl;
+   }
+
+   ASSERT_TRUE(scale_correct && zero_point_correct);
+ }
+
+ void test_vulkan_choose_qparams_tensor_impl(
+     const std::vector<int>& input_sizes,
+     int64_t quant_min,
+     int64_t quant_max,
+     at::ScalarType dtype,
+     const vkcompute::utils::StorageType in_storage,
+     const vkcompute::utils::StorageType out_storage) {
+   std::vector<int64_t> input_sizes_int64(
+       input_sizes.begin(), input_sizes.end());
+   at::Tensor input =
+       at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));
+
+   // Get reference output
+   auto [reference_scale, reference_zero_point] =
+       torch::executor::native::choose_qparams_tensor_aten(
+           input, quant_min, quant_max, dtype);
+
+   // Build Vulkan choose_qparams_tensor graph
+   using namespace vkcompute;
+
+   GraphConfig config;
+   config.set_storage_type_override(in_storage);
+   ComputeGraph graph(config);
+
+   IOValueRef r_input = graph.add_input_tensor(
+       input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage);
+
+   const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
+   const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
+
+   // Output tensors
+   const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage);
+   const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage);
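+   // The GPU kernel writes a float scale and an int zero point, while the ATen
+   // reference returns double/long, so the reference values are converted
+   // before the comparison further down.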
+
+   VK_GET_OP_FN("choose_qparams.tensor")
+   (graph,
+    {
+        r_input.value,
+        r_quant_min,
+        r_quant_max,
+        r_scale,
+        r_zero_point,
+    });
+
+   ValueRef staging_scale = graph.set_output_tensor(r_scale);
+   ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point);
+
+   graph.prepare();
+   graph.encode_prepack();
+   graph.prepack();
+   graph.encode_execute();
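+   // Standard ComputeGraph lifecycle: allocate resources, record and run the
+   // prepacking pass, then record the execute command buffer that execute()
+   // submits below.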
+
+   // Run Vulkan choose_qparams_tensor
+   graph.copy_into_staging(
+       r_input.staging, input.const_data_ptr(), input.numel());
+
+   graph.execute();
+
+   // Create output tensors to hold the results - use types that match GPU output
+   at::Tensor vk_scale =
+       at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous();
+   at::Tensor vk_zero_point =
+       at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous();
+
+   // Copy results from GPU to CPU
+   graph.copy_from_staging(
+       staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel());
+   graph.copy_from_staging(
+       staging_zero_point,
+       vk_zero_point.mutable_data_ptr(),
+       vk_zero_point.numel());
+
+   // Convert reference values to match Vulkan output types for comparison
+   at::Tensor reference_scale_float = reference_scale.to(at::kFloat);
+   at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt);
+
+   // Compare outputs
+   const bool scale_correct = at::allclose(reference_scale_float, vk_scale);
+   const bool zero_point_correct =
+       at::equal(reference_zero_point_int, vk_zero_point);
+
+   if (!scale_correct || !zero_point_correct) {
+     std::cout << "\n"
+               << " Failed with parameters: " << std::endl;
+     std::cout << " quant_min: " << quant_min << std::endl;
+     std::cout << " quant_max: " << quant_max << std::endl;
+     std::cout << " storage type: "
+               << (in_storage == vkcompute::utils::kBuffer ? " buffer"
+                                                           : " texture")
+               << std::endl;
+
+     // Make sure that there aren't a ton of elements in the input tensor
+     if (input.numel() < 100) {
+       std::cout << " input:" << std::endl;
+       std::cout << input << "\n" << std::endl;
+       std::cout << " reference scale:" << std::endl;
+       std::cout << reference_scale << std::endl;
+       std::cout << " vulkan scale:" << std::endl;
+       std::cout << vk_scale << "\n" << std::endl;
+       std::cout << " reference zero_point:" << std::endl;
+       std::cout << reference_zero_point << std::endl;
+       std::cout << " vulkan zero_point:" << std::endl;
+       std::cout << vk_zero_point << std::endl;
+     }
+   }
+
+   ASSERT_TRUE(scale_correct && zero_point_correct);
+ }
+
+ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
+   test_reference_choose_qparams_tensor(
+       {2, 3, 4}, // input sizes
+       -128, // quant_min
+       127, // quant_max
+       at::kChar);
+ }