Commit 9a89ac9

copyrightly authored and facebook-github-bot committed
aten.full.default
Summary:
We implement [`aten.full.default`](https://pytorch.org/docs/stable/generated/torch.full.html), which has the following signature: https://www.internalfb.com/code/fbsource/[8db4b5872791bb88a62ecaa60b667ee4c1b189bf]/fbcode/caffe2/aten/src/ATen/native/native_functions.yaml?lines=2801

To avoid a graph build error, we simply create a null value for args of the following types, since they have no effect on our operator implementation on Vulkan:

- torch.device
- torch.dtype
- torch.layout

(Note that [`torch.layout`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.layout) is a totally different concept from `GPUMemoryLayout` on Vulkan.)

Differential Revision: D56049674
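For illustration (this snippet is not part of the diff), here is a call that exercises the full signature; the three keyword arguments flagged below are exactly the ones the graph builder now maps to null values, since only the size and fill value matter to the Vulkan shader:

import torch

# aten.full.default:
#   full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None)
t = torch.full(
    (2, 3, 4, 5),           # size: sets the output extents
    42.0,                    # fill_value: written to every element
    dtype=torch.float32,     # -> null value in the Vulkan graph
    layout=torch.strided,    # -> null value in the Vulkan graph
    device="cpu",            # -> null value in the Vulkan graph
)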
1 parent 6f47383 commit 9a89ac9

File tree

6 files changed: +150 lines, -1 line

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.native_layer_norm.default,
             # Other
             operator.getitem,
+            exir_ops.edge.aten.full.default,
         ]
         return supported

backends/vulkan/runtime/graph/ops/glsl/full.glsl

Lines changed: 40 additions & 0 deletions
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#include "broadcasting_utils.h"
#include "indexing_utils.h"

#define PRECISION ${PRECISION}

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;

layout(set = 0, binding = 1) uniform PRECISION restrict OutExtents {
  uvec4 data;
}
out_extents;

layout(set = 0, binding = 2) uniform PRECISION restrict FillVal {
  float data;
}
fill_value;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
    return;
  }

  imageStore(image_out, pos, vec4(fill_value.data));
}
backends/vulkan/runtime/graph/ops/glsl/full.yaml

Lines changed: 19 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

full:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
    PACKING: CHANNELS_PACKED
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: full
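A note on how this YAML ties into the runtime code below: the codegen emits one compiled shader per DTYPE variant, and the C++ op picks the right one by appending a dtype suffix to the base kernel name. Conceptually (a sketch of the naming scheme, not the actual codegen):

# Hypothetical illustration of the variant naming this YAML produces:
base_name = "full"
for suffix in ("half", "float"):      # from generate_variant_forall / SUFFIX
    print(f"{base_name}_{suffix}")    # -> full_half, full_float
# At runtime, add_dtype_suffix(kernel_name, *t_out) in the C++ below performs
# the analogous concatenation based on the output tensor's dtype.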
backends/vulkan/runtime/graph/ops/impl/Full.cpp

Lines changed: 68 additions & 0 deletions
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void resize_full_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  std::vector<int64_t> out_sizes = *graph->get_int_list(extra_args[0]);

  out->virtual_resize(out_sizes);
}

void add_full_node(
    ComputeGraph& graph,
    const ValueRef size,
    const ValueRef fill_value,
    const ValueRef out) {
  float fill_value_val = graph.extract_scalar<float>(fill_value);
  vTensorPtr t_out = graph.get_tensor(out);

  api::utils::uvec3 global_size = t_out->extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::string kernel_name("full");
  kernel_name.reserve(kShaderNameReserve);

  add_dtype_suffix(kernel_name, *t_out);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, api::MemoryAccessType::WRITE}},
      // Shader params buffers
      {t_out->extents_ubo(), graph.create_params_buffer(fill_value_val)},
      // Resizing
      resize_full_node,
      {size}));
}

void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  // args follow the aten signature plus the output ref appended at the end:
  // size, fill_value, dtype, layout, device, pin_memory, out.
  // dtype/layout/device/pin_memory arrive as null values and are skipped.
  return add_full_node(graph, args[0], args[1], args[6]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.full.default, full);
}

} // namespace vkcompute

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 6 additions & 1 deletion
@@ -225,7 +225,12 @@ def get_or_create_value_for(self, arg: _Argument):
             if arg in self.node_to_value_ids:
                 return self.node_to_value_ids[arg]
             return self.create_node_value(arg)
-        elif isinstance(arg, NoneType):
+        elif (
+            isinstance(arg, NoneType)
+            or isinstance(arg, torch.device)
+            or isinstance(arg, torch.dtype)
+            or isinstance(arg, torch.layout)
+        ):
             return self.create_null_value()
         elif isinstance(arg, _ScalarType):
             return self.create_scalar_value(arg)
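As a design note, the chained isinstance checks could equivalently be collapsed into one call with a tuple of types, which is the more idiomatic Python spelling; a hypothetical equivalent of the new branch:

elif isinstance(arg, (NoneType, torch.device, torch.dtype, torch.layout)):
    return self.create_null_value()

Both forms behave identically; the or-chain in the diff simply keeps each type on its own line.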

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 16 additions & 0 deletions
@@ -665,3 +665,19 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_full(self):
+        class FullModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.full(x.shape, 42.0)
+
+        sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            FullModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
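For context, `lower_module_and_test_output` wraps the standard ExecuTorch lowering flow. A rough equivalent of what the test exercises, assuming the export APIs of this era (`torch.export.export`, `exir.to_edge`, and the `VulkanPartitioner` updated above) and eliding the runtime output comparison:

import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge

model = FullModule().eval()
sample_inputs = (torch.randn(2, 3, 4, 5),)

# Export to an FX graph, convert to the Edge dialect, then delegate all
# supported nodes (now including aten.full.default) to the Vulkan backend.
exported = torch.export.export(model, sample_inputs)
edge = to_edge(exported)
lowered = edge.to_backend(VulkanPartitioner())
executorch_program = lowered.to_executorch()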
