@@ -3251,3 +3251,106 @@ TEST(VulkanComputeGraphOpsTest, test_transpose_with_mm) {
3251
3251
test_transpose_view_mm (2 , 7 , 17 , 5 , storage_type);
3252
3252
}
3253
3253
}
3254
+
3255
+ void test_to_copy () {
3256
+ GraphConfig config;
3257
+ config.set_storage_type_override (utils::kTexture3D );
3258
+ ComputeGraph graph (config);
3259
+ int M = 8 ;
3260
+ int N = 8 ;
3261
+ int K = 8 ;
3262
+ // Build graph
3263
+ IOValueRef in = graph.add_input_tensor (
3264
+ {1 , M, N, K},
3265
+ vkapi::kFloat ,
3266
+ utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
3267
+
3268
+ std::vector<float > data_in =
3269
+ create_random_float_buffer (M * N * K, -1024 , 1024 );
3270
+ graph.copy_into_staging (in.staging , data_in.data (), data_in.size ());
3271
+
3272
+ IOValueRef out;
3273
+ out.value = graph.add_tensor (
3274
+ {1 , M, N, K},
3275
+ vkapi::kHalf ,
3276
+ utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
3277
+
3278
+ auto op = VK_GET_OP_FN (" aten._to_copy.default" );
3279
+ op (graph,
3280
+ {in.value ,
3281
+ graph.add_none (),
3282
+ graph.add_none (),
3283
+ graph.add_none (),
3284
+ graph.add_none (),
3285
+ graph.add_none (),
3286
+ graph.add_none (),
3287
+ out.value });
3288
+
3289
+ out.staging = graph.set_output_tensor (out.value );
3290
+
3291
+ graph.prepare ();
3292
+ graph.encode_prepack ();
3293
+ graph.prepack ();
3294
+ graph.encode_execute ();
3295
+ graph.propagate_resize ();
3296
+ graph.execute ();
3297
+
3298
+ std::vector<torch::executor::Half> output_data (graph.numel_of (out.value ));
3299
+ graph.copy_from_staging (out.staging , output_data.data (), output_data.size ());
3300
+
3301
+ EXPECT_EQ (data_in.size (), output_data.size ());
3302
+
3303
+ float mse_ex = 0 .0f ;
3304
+ float mse_vk = 0 .0f ;
3305
+
3306
+ // check results
3307
+ for (size_t i = 0 ; i < output_data.size (); ++i) {
3308
+ float input = data_in[i];
3309
+ torch::executor::Half expected_output =
3310
+ static_cast <torch::executor::Half>(input);
3311
+ uint16_t * expected_bits = reinterpret_cast <uint16_t *>(&expected_output);
3312
+ torch::executor::Half output = output_data[i];
3313
+ uint16_t * output_bits = reinterpret_cast <uint16_t *>(&output);
3314
+
3315
+ std::cout << " input = " << input << " (0b"
3316
+ << std::bitset<32 >(*reinterpret_cast <uint32_t *>(&input))
3317
+ << " ), expected output = " << expected_output << " (0b"
3318
+ << std::bitset<16 >(*expected_bits)
3319
+ << " ), recieved output = " << output << " (0b"
3320
+ << std::bitset<16 >(*output_bits) << " )" << std::endl;
3321
+
3322
+ // Note: Torch executor half "rounds up" when converting to fp16 whereas
3323
+ // most driver implementations of Vulkan's opFConvert() just truncates the
3324
+ // extra bits for performance (rounding introduces conditional).
3325
+ // Example:
3326
+ // INPUT F32 = 25.248 (sign{0b0}, exp{0b10000011},
3327
+ // mantissa{0b10010011111101111100111}),
3328
+ // TORCH HALF OUTPUT F16 = 25.25 (sign{0b0}, exp{0b10011},
3329
+ // mantissa{0b1001010000}),
3330
+ // VULKAN OUTPUT F16 = 25.2344 (sign{0b0}, exp{0b10011},
3331
+ // mantissa{0b1001001111})
3332
+ // Note:
3333
+ // The vulkan mantissa exactly matches the first 10
3334
+ // bits of the input 23 bit mantissa. But since the 11th bit is 1, the
3335
+ // torch half output is rounded up (essentially adding a 1).
3336
+ // Vulkan mantissa{0b1001001111} + 1 = Torch half mantissa{0b1001010000}
3337
+
3338
+ EXPECT_TRUE (
3339
+ (*output_bits == *expected_bits) ||
3340
+ /* rounding error*/ ((*output_bits + 1u ) == *expected_bits));
3341
+ mse_ex += std::pow (expected_output - input, 2 );
3342
+ mse_vk += std::pow (output - input, 2 );
3343
+ }
3344
+
3345
+ mse_ex /= output_data.size ();
3346
+ mse_vk /= output_data.size ();
3347
+ std::cout << " ========================================================="
3348
+ << std::endl;
3349
+ std::cout << " mse_ex = " << mse_ex << " , mse_vk = " << mse_vk << std::endl;
3350
+ }
3351
+
3352
+ TEST (VulkanComputeGraphOpsTest, test_to_copy) {
3353
+ if (context ()->adapter_ptr ()->has_fp16_support ()) {
3354
+ test_to_copy ();
3355
+ }
3356
+ }
0 commit comments