Skip to content

Commit e08f2cd

Browse files
Abhi-hppfacebook-github-bot
authored andcommitted
Aten _To_Copy (#6055)
Summary: Implement aten._to_copy. Currently we are only interested in fp32 <-> fp16 conversions, but it should theoretically support other dtype conversions too. I noticed an issue with int conversions, so I limited it to just fp32 and fp16 for now. Note: Most driver implementations of the fp16 cast do not "round up" the result, so there might be a 1-bit difference between the Vulkan output and CPU torch.to. This is explained in greater detail in the comments. Reviewed By: SS-JIA Differential Revision: D64080303
1 parent 8957dc8 commit e08f2cd

File tree

4 files changed

+169
-0
lines changed

4 files changed

+169
-0
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __contains__(self, op):
7777
exir_ops.edge.aten.sin.default,
7878
exir_ops.edge.aten.sqrt.default,
7979
exir_ops.edge.aten.tanh.default,
80+
exir_ops.edge.aten._to_copy.default,
8081
# Matrix Multiplication
8182
exir_ops.edge.aten.bmm.default,
8283
exir_ops.edge.aten.mm.default,

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
144144

145145
return False
146146

147+
def is_valid_to_copy(self, node: torch.fx.Node) -> bool:
    """Return True when a ``_to_copy`` node requests a floating point target dtype.

    Only nodes whose second positional argument is ``torch.float32`` or
    ``torch.float16`` are accepted; every other node (including one with no
    second positional argument) is rejected.

    Args:
        node: the FX node being considered for lowering.

    Returns:
        True if ``node.args[1]`` exists and is float32 or float16.
    """
    # NOTE(review): only the positional dtype slot is inspected here. If the
    # exported graph passes ``dtype`` as a keyword argument this check will
    # reject the node -- confirm against the graphs this partitioner sees.
    if len(node.args) <= 1:
        return False
    return node.args[1] in (torch.float32, torch.float16)
150+
147151
def is_node_supported(
148152
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
149153
) -> bool:
@@ -172,6 +176,11 @@ def _is_node_supported(
172176

173177
features = VulkanSupportedOperators._ops[target]
174178

179+
if target == exir_ops.edge.aten._to_copy.default and not self.is_valid_to_copy(
180+
node
181+
):
182+
return False
183+
175184
if self.require_dynamic_shapes and not features.supports_dynamic_shape:
176185
return False
177186

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/BlitNode.h>
10+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
13+
#include <set>
14+
15+
namespace vkcompute {
16+
17+
// Resize callback for to_copy: gives the output tensor (args[0].refs[0]) the
// same sizes as the input tensor (args[1].refs[0]). `extra_args` is unused.
// NOTE(review): nothing in the visible code passes this function to a node;
// confirm whether it is still needed.
void resize_to_copy_op_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  const auto dst_ref = args[0].refs[0];
  const auto src_ref = args[1].refs[0];

  // Fetch the destination first, then the source, and mirror the sizes.
  vTensorPtr dst = graph->get_tensor(dst_ref);
  vTensorPtr src = graph->get_tensor(src_ref);
  dst->virtual_resize(src->sizes());
}
27+
28+
// Appends a BlitNode to the graph's execute nodes that copies `in` into `out`.
//
// Both `in` and `out` must have dtype Float or Half; any other combination
// fails the VK_CHECK_COND below. `in` is routed through
// prepack_if_tensor_ref() before being handed to the BlitNode.
void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
  // Query each dtype once; the values feed both the check and its message.
  const vkapi::ScalarType in_dtype = graph.dtype_of(in);
  const vkapi::ScalarType out_dtype = graph.dtype_of(out);

  // Two enum comparisons are enough here; no container or mutable static
  // state is required for a two-value membership test.
  const auto is_supported = [](const vkapi::ScalarType dtype) {
    return dtype == vkapi::ScalarType::Float ||
        dtype == vkapi::ScalarType::Half;
  };

  VK_CHECK_COND(
      is_supported(in_dtype) && is_supported(out_dtype),
      "Unsupported dtype for to_copy, only Float and Half are currently supported, recieved ",
      vkapi::to_string(in_dtype),
      " <-> ",
      vkapi::to_string(out_dtype));

  graph.execute_nodes().emplace_back(
      new BlitNode(graph, prepack_if_tensor_ref(graph, in), out));
}
43+
44+
// Operator entry point for aten._to_copy.default.
// Only two slots of `args` are read: index 0 (the input) and index 7 (the
// output). Slots 1-6 are ignored by this implementation.
void to_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef src = args[0];
  const ValueRef dst = args[7];
  add_to_copy_node(graph, src, dst);
}
47+
48+
// Registers to_copy under the operator name aten._to_copy.default.
REGISTER_OPERATORS {
49+
VK_REGISTER_OP(aten._to_copy.default, to_copy);
50+
}
51+
} // namespace vkcompute

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <gtest/gtest.h>
1010

11+
#include <bitset>
1112
#include <utility>
1213
#include <vector>
1314

@@ -3251,3 +3252,110 @@ TEST(VulkanComputeGraphOpsTest, test_transpose_with_mm) {
32513252
test_transpose_view_mm(2, 7, 17, 5, storage_type);
32523253
}
32533254
}
3255+
3256+
void test_to_copy() {
3257+
GraphConfig config;
3258+
config.set_storage_type_override(utils::kTexture3D);
3259+
ComputeGraph graph(config);
3260+
int M = 8;
3261+
int N = 8;
3262+
int K = 8;
3263+
// Build graph
3264+
IOValueRef in = graph.add_input_tensor(
3265+
{1, M, N, K},
3266+
vkapi::kFloat,
3267+
utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
3268+
3269+
std::vector<float> data_in =
3270+
create_random_float_buffer(M * N * K, -1024, 1024);
3271+
graph.copy_into_staging(in.staging, data_in.data(), data_in.size());
3272+
3273+
IOValueRef out;
3274+
out.value = graph.add_tensor(
3275+
{1, M, N, K},
3276+
vkapi::kHalf,
3277+
utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
3278+
3279+
auto op = VK_GET_OP_FN("aten._to_copy.default");
3280+
op(graph,
3281+
{in.value,
3282+
graph.add_none(),
3283+
graph.add_none(),
3284+
graph.add_none(),
3285+
graph.add_none(),
3286+
graph.add_none(),
3287+
graph.add_none(),
3288+
out.value});
3289+
3290+
out.staging = graph.set_output_tensor(out.value);
3291+
3292+
graph.prepare();
3293+
graph.encode_prepack();
3294+
graph.prepack();
3295+
graph.encode_execute();
3296+
graph.propagate_resize();
3297+
graph.execute();
3298+
3299+
std::vector<torch::executor::Half> output_data(graph.numel_of(out.value));
3300+
graph.copy_from_staging(out.staging, output_data.data(), output_data.size());
3301+
3302+
EXPECT_EQ(data_in.size(), output_data.size());
3303+
3304+
float mse_ex = 0.0f;
3305+
float mse_vk = 0.0f;
3306+
3307+
// check results
3308+
for (size_t i = 0; i < output_data.size(); ++i) {
3309+
float input = data_in[i];
3310+
torch::executor::Half expected_output =
3311+
static_cast<torch::executor::Half>(input);
3312+
uint16_t* expected_bits = reinterpret_cast<uint16_t*>(&expected_output);
3313+
torch::executor::Half output = output_data[i];
3314+
uint16_t* output_bits = reinterpret_cast<uint16_t*>(&output);
3315+
3316+
std::string msg;
3317+
msg.reserve(64);
3318+
msg = "input = " + std::to_string(input) + "(0b" +
3319+
std::bitset<32>(*reinterpret_cast<uint32_t*>(&input)).to_string() +
3320+
"), expected output = " + std::to_string(expected_output) + "(0b" +
3321+
std::bitset<16>(*expected_bits).to_string() +
3322+
"), recieved output = " + std::to_string(output) + "(0b" +
3323+
std::bitset<16>(*output_bits).to_string() + ")";
3324+
3325+
std::cout << msg << std::endl;
3326+
3327+
// Note: Torch executor half "rounds up" when converting to fp16 whereas
3328+
// most driver implementations of Vulkan's opFConvert() just truncates the
3329+
// extra bits for performance (rounding introduces conditional).
3330+
// Example:
3331+
// INPUT F32 = 25.248 (sign{0b0}, exp{0b10000011},
3332+
// mantissa{0b10010011111101111100111}),
3333+
// TORCH HALF OUTPUT F16 = 25.25 (sign{0b0}, exp{0b10011},
3334+
// mantissa{0b1001010000}),
3335+
// VULKAN OUTPUT F16 = 25.2344 (sign{0b0}, exp{0b10011},
3336+
// mantissa{0b1001001111})
3337+
// Note:
3338+
// The vulkan mantissa exactly matches the first 10
3339+
// bits of the input 23 bit mantissa. But since the 11th bit is 1, the
3340+
// torch half output is rounded up (essentially adding a 1).
3341+
// Vulkan mantissa{0b1001001111} + 1 = Torch half mantissa{0b1001010000}
3342+
3343+
EXPECT_TRUE(
3344+
(*output_bits == *expected_bits) ||
3345+
/*rounding error*/ ((*output_bits + 1u) == *expected_bits));
3346+
mse_ex += std::pow(expected_output - input, 2);
3347+
mse_vk += std::pow(output - input, 2);
3348+
}
3349+
3350+
mse_ex /= output_data.size();
3351+
mse_vk /= output_data.size();
3352+
std::cout << "========================================================="
3353+
<< std::endl;
3354+
std::cout << "mse_ex = " << mse_ex << ", mse_vk = " << mse_vk << std::endl;
3355+
}
3356+
3357+
// Runs the fp32 -> fp16 to_copy check when the adapter reports 16-bit storage
// support; otherwise the test is reported as skipped rather than passed.
TEST(VulkanComputeGraphOpsTest, test_to_copy) {
  if (!context()->adapter_ptr()->has_16bit_storage()) {
    GTEST_SKIP() << "Adapter does not report 16-bit storage support";
  }
  test_to_copy();
}

0 commit comments

Comments
 (0)