
Commit ea0c137

Abhi-hpp authored and facebook-github-bot committed
Aten _To_Copy (#6055)
Summary:
Implement aten._to_copy. Currently we are only interested in fp32 <-> fp16 conversions, but it should theoretically support other dtype conversions too. I noticed an issue with int conversions, so it is limited to just fp32 and fp16 for now.

Note: Most driver implementations of the fp16 cast do not "round up" the result, so there may be a 1-bit difference between the Vulkan output and CPU torch.to. This is explained in greater detail in the comments.

Reviewed By: SS-JIA

Differential Revision: D64080303
1 parent e95aa9d commit ea0c137
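
To make the rounding note above concrete, here is a minimal Python/numpy sketch (illustration only, not part of this change) comparing a round-to-nearest fp32 -> fp16 cast, which is what torch.to and torch::executor::Half do, against a truncating cast, which is what many Vulkan drivers appear to do in OpFConvert. The two results differ by at most one fp16 ULP, which is exactly the 1-bit tolerance the new test allows.

import numpy as np

def to_f16_bits(x: float, truncate: bool) -> int:
    # Cast a single fp32 value to fp16 and return its raw 16-bit pattern.
    f32 = np.array([x], dtype=np.float32)
    if truncate:
        # Zero the low 13 fp32 mantissa bits so the fp16 conversion is exact,
        # mimicking drivers whose OpFConvert truncates instead of rounding.
        f32 = (f32.view(np.uint32) & np.uint32(0xFFFFE000)).view(np.float32)
    return int(f32.astype(np.float16).view(np.uint16)[0])

rounded = to_f16_bits(25.248, truncate=False)   # torch.to / executor::Half -> 25.25
truncated = to_f16_bits(25.248, truncate=True)  # typical Vulkan driver     -> 25.234375
print(bin(rounded), bin(truncated), rounded - truncated)  # differ by exactly 1 ULP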

File tree: 4 files changed (+163, -0 lines)


backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ def __contains__(self, op):
     exir_ops.edge.aten.sin.default,
     exir_ops.edge.aten.sqrt.default,
     exir_ops.edge.aten.tanh.default,
+    exir_ops.edge.aten._to_copy.default,
     # Matrix Multiplication
     exir_ops.edge.aten.bmm.default,
     exir_ops.edge.aten.mm.default,

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 7 additions & 0 deletions
@@ -144,6 +144,10 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
 
         return False
 
+    def is_valid_to_copy(self, node: torch.fx.node) -> bool:
+        # lower only if floating point dtype conversion
+        return len(node.args) > 1 and node.args[1] in (torch.float32, torch.float16)
+
     def is_node_supported(
         self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
@@ -172,6 +176,9 @@ def _is_node_supported(
 
         features = VulkanSupportedOperators._ops[target]
 
+        if target == exir_ops.edge.aten._to_copy.default and not self.is_valid_to_copy(node):
+            return False
+
         if self.require_dynamic_shapes and not features.supports_dynamic_shape:
             return False
 
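For illustration, a standalone mirror of the check added above (the argument tuples here are hypothetical stand-ins for node.args, not real torch.fx nodes): only _to_copy nodes carrying an explicit fp32 or fp16 dtype are lowered to Vulkan; everything else falls back to the default backend.

import torch

def is_valid_to_copy(args) -> bool:
    # Mirrors the new partitioner check: lower only float32/float16 conversions.
    return len(args) > 1 and args[1] in (torch.float32, torch.float16)

print(is_valid_to_copy(("x", torch.float16)))  # True  -> partitioned to Vulkan
print(is_valid_to_copy(("x", torch.int32)))    # False -> int conversions stay on CPU
print(is_valid_to_copy(("x",)))                # False -> no explicit dtype argument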

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/BlitNode.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+#include <set>
+
+namespace vkcompute {
+
+void resize_to_copy_op_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)extra_args;
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr self = graph->get_tensor(args[1].refs[0]);
+
+  out->virtual_resize(self->sizes());
+}
+
+void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
+  static std::set<vkapi::ScalarType> supported_types = {
+      vkapi::ScalarType::Float, vkapi::ScalarType::Half};
+
+  VK_CHECK_COND(
+      supported_types.find(graph.dtype_of(in)) != supported_types.end() &&
+          supported_types.find(graph.dtype_of(out)) != supported_types.end(),
+      "Unsupported dtype for to_copy, only Float and Half are currently supported, received ", vkapi::to_string(graph.dtype_of(in)), " <-> ", vkapi::to_string(graph.dtype_of(out)));
+
+  graph.execute_nodes().emplace_back(
+      new BlitNode(graph, prepack_if_tensor_ref(graph, in), out));
+}
+
+void to_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_to_copy_node(graph, args[0], args[7]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten._to_copy.default, to_copy);
+}
+} // namespace vkcompute
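As context for the args[0] / args[7] indexing in to_copy above, here is a hypothetical Python sketch of the value layout the graph builder is assumed to pass for aten._to_copy.default: the seven schema inputs (self plus six optional keyword arguments) followed by the output reference appended at the end, matching how the test below assembles its argument list. The names are stand-ins, not real ValueRefs.

# Hypothetical stand-ins for graph ValueRefs, for illustration only.
in_ref, none_ref, out_ref = "in", "none", "out"

args = [
    in_ref,    # 0: self           -> add_to_copy_node(graph, args[0], ...)
    none_ref,  # 1: dtype          (the cast is inferred from the output tensor's dtype)
    none_ref,  # 2: layout
    none_ref,  # 3: device
    none_ref,  # 4: pin_memory
    none_ref,  # 5: non_blocking
    none_ref,  # 6: memory_format
    out_ref,   # 7: output         -> add_to_copy_node(graph, ..., args[7])
]
assert args[0] == in_ref and args[7] == out_ref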

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 107 additions & 0 deletions
@@ -3251,3 +3251,110 @@ TEST(VulkanComputeGraphOpsTest, test_transpose_with_mm) {
     test_transpose_view_mm(2, 7, 17, 5, storage_type);
   }
 }
+
+void test_to_copy() {
+  GraphConfig config;
+  config.set_storage_type_override(utils::kTexture3D);
+  ComputeGraph graph(config);
+  int M = 8;
+  int N = 8;
+  int K = 8;
+  // Build graph
+  IOValueRef in = graph.add_input_tensor(
+      {1, M, N, K},
+      vkapi::kFloat,
+      utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
+
+  std::vector<float> data_in =
+      create_random_float_buffer(M * N * K, -1024, 1024);
+  graph.copy_into_staging(in.staging, data_in.data(), data_in.size());
+
+  IOValueRef out;
+  out.value = graph.add_tensor(
+      {1, M, N, K},
+      vkapi::kHalf,
+      utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
+
+  auto op = VK_GET_OP_FN("aten._to_copy.default");
+  op(graph,
+     {in.value,
+      graph.add_none(),
+      graph.add_none(),
+      graph.add_none(),
+      graph.add_none(),
+      graph.add_none(),
+      graph.add_none(),
+      out.value});
+
+  out.staging = graph.set_output_tensor(out.value);
+
+  graph.prepare();
+  graph.encode_prepack();
+  graph.prepack();
+  graph.encode_execute();
+  graph.propagate_resize();
+  graph.execute();
+
+  std::vector<torch::executor::Half> output_data(graph.numel_of(out.value));
+  graph.copy_from_staging(out.staging, output_data.data(), output_data.size());
+
+  EXPECT_EQ(data_in.size(), output_data.size());
+
+  float mse_ex = 0.0f;
+  float mse_vk = 0.0f;
+
+  // check results
+  for (size_t i = 0; i < output_data.size(); ++i) {
+    float input = data_in[i];
+    torch::executor::Half expected_output =
+        static_cast<torch::executor::Half>(input);
+    uint16_t* expected_bits = reinterpret_cast<uint16_t*>(&expected_output);
+    torch::executor::Half output = output_data[i];
+    uint16_t* output_bits = reinterpret_cast<uint16_t*>(&output);
+
+    std::string msg;
+    msg.reserve(64);
+    msg = "input = " + std::to_string(input) + " (0b"
+        + std::bitset<32>(*reinterpret_cast<uint32_t*>(&input)).to_string()
+        + "), expected output = " + std::to_string(expected_output) + " (0b"
+        + std::bitset<16>(*expected_bits).to_string()
+        + "), received output = " + std::to_string(output) + " (0b"
+        + std::bitset<16>(*output_bits).to_string() + ")";
+
+    std::cout << msg << std::endl;
+
+    // Note: Torch executor half "rounds up" when converting to fp16, whereas
+    // most driver implementations of Vulkan's OpFConvert() just truncate the
+    // extra bits for performance (rounding introduces a conditional).
+    // Example:
+    //   INPUT F32 = 25.248 (sign{0b0}, exp{0b10000011},
+    //     mantissa{0b10010011111101111100111}),
+    //   TORCH HALF OUTPUT F16 = 25.25 (sign{0b0}, exp{0b10011},
+    //     mantissa{0b1001010000}),
+    //   VULKAN OUTPUT F16 = 25.2344 (sign{0b0}, exp{0b10011},
+    //     mantissa{0b1001001111})
+    // Note:
+    //   The Vulkan mantissa exactly matches the first 10 bits of the input's
+    //   23-bit mantissa. But since the 11th bit is 1, the torch half output is
+    //   rounded up (essentially adding a 1).
+    //   Vulkan mantissa{0b1001001111} + 1 = Torch half mantissa{0b1001010000}
+
+    EXPECT_TRUE(
+        (*output_bits == *expected_bits) ||
+        /*rounding error*/ ((*output_bits + 1u) == *expected_bits));
+    mse_ex += std::pow(expected_output - input, 2);
+    mse_vk += std::pow(output - input, 2);
+  }
+
+  mse_ex /= output_data.size();
+  mse_vk /= output_data.size();
+  std::cout << "========================================================="
+            << std::endl;
+  std::cout << "mse_ex = " << mse_ex << ", mse_vk = " << mse_vk << std::endl;
+}
+
+TEST(VulkanComputeGraphOpsTest, test_to_copy) {
+  if (context()->adapter_ptr()->has_16bit_storage()) {
+    test_to_copy();
+  }
+}
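
For a sense of how this operator arrives from the PyTorch side, here is a small sketch assuming the standard torch.export flow (illustration only, not part of this change): a module that calls .to(torch.float16) is expected to decompose into an aten._to_copy.default node, which is what the partitioner check and the Vulkan kernel above then handle.

import torch

class CastToHalf(torch.nn.Module):
    def forward(self, x):
        # .to(dtype) is expected to decompose into aten._to_copy.default on export.
        return x.to(torch.float16)

ep = torch.export.export(CastToHalf(), (torch.randn(1, 8, 8, 8),))
# Inspect the graph for a torch.ops.aten._to_copy.default node with dtype=torch.float16.
print(ep.graph)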
