@@ -3251,3 +3251,110 @@ TEST(VulkanComputeGraphOpsTest, test_transpose_with_mm) {
3251
3251
test_transpose_view_mm (2 , 7 , 17 , 5 , storage_type);
3252
3252
}
3253
3253
}
3254
+
3255
+ void test_to_copy () {
3256
+ GraphConfig config;
3257
+ config.set_storage_type_override (utils::kTexture3D );
3258
+ ComputeGraph graph (config);
3259
+ int M = 8 ;
3260
+ int N = 8 ;
3261
+ int K = 8 ;
3262
+ // Build graph
3263
+ IOValueRef in = graph.add_input_tensor (
3264
+ {1 , M, N, K},
3265
+ vkapi::kFloat ,
3266
+ utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
3267
+
3268
+ std::vector<float > data_in =
3269
+ create_random_float_buffer (M * N * K, -1024 , 1024 );
3270
+ graph.copy_into_staging (in.staging , data_in.data (), data_in.size ());
3271
+
3272
+ IOValueRef out;
3273
+ out.value = graph.add_tensor (
3274
+ {1 , M, N, K},
3275
+ vkapi::kHalf ,
3276
+ utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
3277
+
3278
+ auto op = VK_GET_OP_FN (" aten._to_copy.default" );
3279
+ op (graph,
3280
+ {in.value ,
3281
+ graph.add_none (),
3282
+ graph.add_none (),
3283
+ graph.add_none (),
3284
+ graph.add_none (),
3285
+ graph.add_none (),
3286
+ graph.add_none (),
3287
+ out.value });
3288
+
3289
+ out.staging = graph.set_output_tensor (out.value );
3290
+
3291
+ graph.prepare ();
3292
+ graph.encode_prepack ();
3293
+ graph.prepack ();
3294
+ graph.encode_execute ();
3295
+ graph.propagate_resize ();
3296
+ graph.execute ();
3297
+
3298
+ std::vector<torch::executor::Half> output_data (graph.numel_of (out.value ));
3299
+ graph.copy_from_staging (out.staging , output_data.data (), output_data.size ());
3300
+
3301
+ EXPECT_EQ (data_in.size (), output_data.size ());
3302
+
3303
+ float mse_ex = 0 .0f ;
3304
+ float mse_vk = 0 .0f ;
3305
+
3306
+ // check results
3307
+ for (size_t i = 0 ; i < output_data.size (); ++i) {
3308
+ float input = data_in[i];
3309
+ torch::executor::Half expected_output =
3310
+ static_cast <torch::executor::Half>(input);
3311
+ uint16_t * expected_bits = reinterpret_cast <uint16_t *>(&expected_output);
3312
+ torch::executor::Half output = output_data[i];
3313
+ uint16_t * output_bits = reinterpret_cast <uint16_t *>(&output);
3314
+
3315
+ std::string msg;
3316
+ msg.reserve (64 );
3317
+ msg = " input = " + std::to_string (input) + " (0b"
3318
+ + std::bitset<32 >(*reinterpret_cast <uint32_t *>(&input)).to_string ()
3319
+ + " ), expected output = " + std::to_string (expected_output) +" (0b"
3320
+ + std::bitset<16 >(*expected_bits).to_string ()
3321
+ + " ), recieved output = " + std::to_string (output) + " (0b"
3322
+ + std::bitset<16 >(*output_bits).to_string () + " )" ;
3323
+
3324
+ std::cout << msg<< std::endl;
3325
+
3326
+ // Note: Torch executor half "rounds up" when converting to fp16 whereas
3327
+ // most driver implementations of Vulkan's opFConvert() just truncates the
3328
+ // extra bits for performance (rounding introduces conditional).
3329
+ // Example:
3330
+ // INPUT F32 = 25.248 (sign{0b0}, exp{0b10000011},
3331
+ // mantissa{0b10010011111101111100111}),
3332
+ // TORCH HALF OUTPUT F16 = 25.25 (sign{0b0}, exp{0b10011},
3333
+ // mantissa{0b1001010000}),
3334
+ // VULKAN OUTPUT F16 = 25.2344 (sign{0b0}, exp{0b10011},
3335
+ // mantissa{0b1001001111})
3336
+ // Note:
3337
+ // The vulkan mantissa exactly matches the first 10
3338
+ // bits of the input 23 bit mantissa. But since the 11th bit is 1, the
3339
+ // torch half output is rounded up (essentially adding a 1).
3340
+ // Vulkan mantissa{0b1001001111} + 1 = Torch half mantissa{0b1001010000}
3341
+
3342
+ EXPECT_TRUE (
3343
+ (*output_bits == *expected_bits) ||
3344
+ /* rounding error*/ ((*output_bits + 1u ) == *expected_bits));
3345
+ mse_ex += std::pow (expected_output - input, 2 );
3346
+ mse_vk += std::pow (output - input, 2 );
3347
+ }
3348
+
3349
+ mse_ex /= output_data.size ();
3350
+ mse_vk /= output_data.size ();
3351
+ std::cout << " ========================================================="
3352
+ << std::endl;
3353
+ std::cout << " mse_ex = " << mse_ex << " , mse_vk = " << mse_vk << std::endl;
3354
+ }
3355
+
3356
+ TEST (VulkanComputeGraphOpsTest, test_to_copy) {
3357
+ if (context ()->adapter_ptr ()->has_16bit_storage ()) {
3358
+ test_to_copy ();
3359
+ }
3360
+ }
0 commit comments