
Commit eb44e88

copyrightly authored and facebook-github-bot committed
aten.full.default (#3013)
Summary: Pull Request resolved: #3013

We implement [`aten.full.default`](https://pytorch.org/docs/stable/generated/torch.full.html), which has the following signature:

```
func: full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
```

To bypass graph build errors, we simply create null values for the following arg types:

- torch.device
- torch.dtype
- torch.layout

since they have no effect on our operator implementation on Vulkan. (Note that [`torch.layout`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.layout) is a totally different concept from `GPUMemoryLayout` on Vulkan.)

Reviewed By: jorgep31415

Differential Revision: D56049674

fbshipit-source-id: dc2a27b4e702829e077e874ccf697f6c4196756d
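As a point of reference for the semantics being implemented, a minimal eager-mode sketch (plain PyTorch, not part of this diff):

```python
import torch

# torch.full materializes a tensor of the given size, filled with a scalar.
# The keyword-only dtype/layout/device/pin_memory args are exactly the ones
# the Vulkan graph builder below maps to null values.
t = torch.full((2, 3), 42.0)
assert t.shape == (2, 3) and bool((t == 42.0).all())
```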
1 parent 74576e8 commit eb44e88

File tree

9 files changed: +210 −7 lines changed


backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -54,6 +54,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.native_layer_norm.default,
             # Other
             operator.getitem,
+            exir_ops.edge.aten.full.default,
         ]
         return supported
 
```
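How the partitioner consumes this list, as a simplified sketch (illustrative names, not code from this diff):

```python
# A node is delegated to Vulkan iff its target is in the supported-ops list;
# this commit adds aten.full.default to that list.
supported = {
    "aten.native_layer_norm.default",
    "getitem",
    "aten.full.default",
}

def is_supported(target: str) -> bool:
    return target in supported

assert is_supported("aten.full.default")
```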

backends/vulkan/runtime/graph/ops/glsl/full.glsl

Lines changed: 61 additions & 0 deletions (new file)

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#define to_tensor_idx to_tensor_idx_${PACKING}
#define get_packed_dim get_packed_dim_${PACKING}

#include "broadcasting_utils.h"
#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;

layout(set = 0, binding = 1) uniform PRECISION restrict GpuSizes {
  ivec4 data;
}
gpu_sizes;

layout(set = 0, binding = 2) uniform PRECISION restrict CpuSizes {
  ivec4 data;
}
cpu_sizes;

layout(set = 0, binding = 3) uniform PRECISION restrict FillVal {
  float data;
}
fill_value;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data);

  if (any(greaterThanEqual(idx, gpu_sizes.data))) {
    return;
  }

  VEC4_T outtex = VEC4_T(fill_value.data);
  const int packed_dim_size = get_packed_dim(cpu_sizes.data);
  int packed_idx = get_packed_dim(idx);

  if (packed_idx + 3 >= packed_dim_size) {
    ivec4 packed_ind = ivec4(packed_idx) + ivec4(0, 1, 2, 3);
    VEC4_T valid_idx = VEC4_T(lessThan(packed_ind, ivec4(packed_dim_size)));
    outtex = outtex * valid_idx;
  }

  imageStore(image_out, ${get_pos[NDIM]("pos")}, outtex);
}
```
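The last-texel masking above can be read as follows (a plain-Python model of the shader's logic, not part of the commit):

```python
def fill_texel(packed_idx: int, packed_dim_size: int, fill_value: float) -> list:
    # Each texel packs 4 consecutive elements along the packed dimension;
    # components past the end of the tensor are zeroed, mirroring the
    # lessThan() mask in the shader.
    return [
        fill_value if packed_idx + i < packed_dim_size else 0.0
        for i in range(4)
    ]

# With a packed dim of size 6, the texel starting at index 4 keeps only
# its first two components.
assert fill_texel(4, 6, 42.0) == [42.0, 42.0, 0.0, 0.0]
```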
backends/vulkan/runtime/graph/ops/glsl/full.yaml

Lines changed: 17 additions & 0 deletions (new file)

```yaml
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

full:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
    PACKING: C_packed
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: full
```
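Assuming the shader codegen's usual dtype-suffix naming (the same convention `add_dtype_suffix` applies in Full.cpp below), this YAML should yield one shader variant per DTYPE value. A sketch of the expected names:

```python
# Hypothetical illustration only; the authoritative naming comes from the
# shader codegen, not this snippet.
base, dtypes = "full", ["half", "float"]
variants = [f"{base}_{d}" for d in dtypes]
assert variants == ["full_half", "full_float"]
```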
backends/vulkan/runtime/graph/ops/impl/Full.cpp

Lines changed: 68 additions & 0 deletions (new file)

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void resize_full_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  std::vector<int64_t> out_sizes = *graph->get_int_list(extra_args[0]);

  out->virtual_resize(out_sizes);
}

void add_full_node(
    ComputeGraph& graph,
    const ValueRef size,
    const ValueRef fill_value,
    const ValueRef out) {
  float fill_value_val = graph.extract_scalar<float>(fill_value);
  vTensorPtr t_out = graph.get_tensor(out);

  api::utils::uvec3 global_size = t_out->extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::string kernel_name("full");
  kernel_name.reserve(kShaderNameReserve);

  add_dtype_suffix(kernel_name, *t_out);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, api::MemoryAccessType::WRITE}},
      // Shader params buffers
      {t_out->gpu_sizes_ubo(),
       t_out->cpu_sizes_ubo(),
       graph.create_params_buffer(fill_value_val)},
      // Resizing
      resize_full_node,
      {size}));
}

void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_full_node(graph, args[0], args[1], args[6]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.full.default, full);
}

} // namespace vkcompute
```
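Why `full()` reads `args[6]`: per the schema quoted in the summary, the six inputs come first and the output tensor is appended after them. A sketch of the assumed ordering:

```python
# Assumed ValueRef ordering for aten.full.default as dispatched to full();
# dtype/layout/device/pin_memory arrive as null values (see the graph
# builder change below).
ARG_ORDER = ["size", "fill_value", "dtype", "layout", "device", "pin_memory", "out"]
assert ARG_ORDER[1] == "fill_value" and ARG_ORDER[6] == "out"
```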

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -225,7 +225,12 @@ def get_or_create_value_for(self, arg: _Argument):
             if arg in self.node_to_value_ids:
                 return self.node_to_value_ids[arg]
             return self.create_node_value(arg)
-        elif isinstance(arg, NoneType):
+        elif (
+            isinstance(arg, NoneType)
+            or isinstance(arg, torch.device)
+            or isinstance(arg, torch.dtype)
+            or isinstance(arg, torch.layout)
+        ):
             return self.create_null_value()
         elif isinstance(arg, _ScalarType):
             return self.create_scalar_value(arg)
```
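The added branch is equivalent to a single tuple-based isinstance check; a runnable sketch (assumes Python 3.10+ for `types.NoneType`):

```python
from types import NoneType  # Python 3.10+; earlier, use type(None)

import torch

# These argument types all map to a null graph value, since they have no
# effect on the Vulkan implementation of the operator.
def becomes_null(arg) -> bool:
    return isinstance(arg, (NoneType, torch.device, torch.dtype, torch.layout))

assert becomes_null(None) and becomes_null(torch.float32)
assert becomes_null(torch.device("cpu")) and becomes_null(torch.strided)
```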

backends/vulkan/test/op_tests/cases.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -130,6 +130,18 @@ def get_native_layer_norm_inputs():
     return test_suite
 
 
+def get_full_inputs():
+    test_suite = VkTestSuite(
+        [
+            ([S1, S2], 42.0),
+            ([M, M1, M2], 3.14),
+            ([L, M, M1, M2], 2.72),
+        ]
+    )
+    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -139,6 +151,7 @@ def get_native_layer_norm_inputs():
     "aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
     "aten.convolution.default": get_conv2d_inputs(),
     "aten.native_layer_norm.default": get_native_layer_norm_inputs(),
+    "aten.full.default": get_full_inputs(),
 }
 
 prepacked_args = {"aten.mm.default": {"mat2"}}
```
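Each test case pairs a size list with a fill value; the eager reference each one checks against looks like this (concrete sizes are hypothetical stand-ins for the symbolic S1, S2, ...):

```python
import torch

S1, S2 = 7, 11  # hypothetical concrete values for the symbolic sizes
reference = torch.full([S1, S2], 42.0)
assert reference.shape == (S1, S2) and bool((reference == 42.0).all())
```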

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -12,11 +12,15 @@
     AT_INT_ARRAY_REF,
     AT_SCALAR,
     AT_TENSOR,
-    AT_TENSOR_OPT,
     BOOL,
     CppTestFileGen,
     DOUBLE,
     INT,
+    OPT_AT_TENSOR,
+    OPT_BOOL,
+    OPT_DEVICE,
+    OPT_LAYOUT,
+    OPT_SCALARTYPE,
     TestSuite,
     TestSuiteGen,
     THREE_TENSOR_TUPLE,
@@ -180,7 +184,6 @@ def create_aten_fn_call(self) -> str:
         func_call = generate_static_dispatch_backend_call(
             self.f_sig, self.f, TestSuiteGen.backend_key
         )[7:].replace("::cpu", "")
-
         return func_call
 
     def create_out_src(self) -> str:
@@ -205,7 +208,7 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
 
         cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"
 
-        if ref.src_cpp_type == AT_TENSOR_OPT:
+        if ref.src_cpp_type == OPT_AT_TENSOR:
             ret_str = f"{cpp_type} {ref.name} = "
             ret_str += f"!{ref.src_cpp_name}.has_value() ? "
             ret_str += f"{self.graph}{self.dot}add_none() : "
@@ -241,6 +244,13 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
             ret_str += f"add_scalar<int64_t>({ref.src_cpp_name}); \n"
         elif ref.src_cpp_type == DOUBLE:
             ret_str += f"add_scalar<double>({ref.src_cpp_name}); \n"
+        elif (
+            ref.src_cpp_type == OPT_SCALARTYPE
+            or ref.src_cpp_type == OPT_LAYOUT
+            or ref.src_cpp_type == OPT_DEVICE
+            or ref.src_cpp_type == OPT_BOOL
+        ):
+            ret_str += "add_none(); \n"
         elif ref.src_cpp_type == TWO_TENSOR_TUPLE:
             ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n"
         elif ref.src_cpp_type == THREE_TENSOR_TUPLE:
@@ -457,6 +467,7 @@ def gen_parameterization(self) -> str:
 #include <tuple>
 
 using namespace vkcompute;
+using TensorOptions = at::TensorOptions;
 
 api::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
   switch(at_scalartype) {
```
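A hedged reconstruction of what the new `add_none()` branch emits, assuming (as the neighboring branches suggest) that `ret_str` has already been prefixed with `{cpp_type} {ref.name} = {self.graph}{self.dot}`:

```python
# Hypothetical example; the names ("graph", "dtype") are illustrative and
# not taken from the diff.
cpp_type, name, graph_prefix = "ValueRef", "dtype", "graph."
ret_str = f"{cpp_type} {name} = {graph_prefix}"
ret_str += "add_none(); \n"
assert ret_str == "ValueRef dtype = graph.add_none(); \n"
```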

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -18,10 +18,14 @@
 AT_INT_ARRAY_REF = "at::IntArrayRef"
 AT_SCALAR = "at::Scalar"
 AT_TENSOR = "at::Tensor"
-AT_TENSOR_OPT = "::std::optional<at::Tensor>"
 BOOL = "bool"
-INT = "int64_t"
 DOUBLE = "double"
+INT = "int64_t"
+OPT_AT_TENSOR = "::std::optional<at::Tensor>"
+OPT_BOOL = "::std::optional<bool>"
+OPT_DEVICE = "::std::optional<at::Device>"
+OPT_LAYOUT = "::std::optional<at::Layout>"
+OPT_SCALARTYPE = "::std::optional<at::ScalarType>"
 TWO_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
 THREE_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor,at::Tensor>"
 
@@ -120,7 +124,7 @@ def create_input_data(self, arg: Argument, data: Any) -> str:
 
         if cpp_type == AT_TENSOR:
             ret_str += f"make_rand_tensor({init_list_str(data)}, test_dtype);"
-        elif cpp_type == AT_TENSOR_OPT:
+        elif cpp_type == OPT_AT_TENSOR:
             if str(data) == "None":
                 ret_str += "std::nullopt;"
             else:
@@ -135,6 +139,13 @@ def create_input_data(self, arg: Argument, data: Any) -> str:
             ret_str += f"{str(data).lower()};"
         elif cpp_type == DOUBLE:
             ret_str += f"{str(data).lower()};"
+        elif (
+            cpp_type == OPT_SCALARTYPE
+            or cpp_type == OPT_LAYOUT
+            or cpp_type == OPT_DEVICE
+            or cpp_type == OPT_BOOL
+        ):
+            ret_str += "std::nullopt;"
         else:
             raise RuntimeError(f"Unsupported cpp type {cpp_type}")
         return ret_str + "\n"
```
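Similarly hedged, the generated C++ input declaration for an optional argument presumably ends up as a `std::nullopt` initializer (the variable name here is illustrative):

```python
# Hypothetical reconstruction of one generated declaration.
cpp_type, name = "::std::optional<at::Layout>", "layout"
decl = f"{cpp_type} {name} = " + "std::nullopt;"
assert decl == "::std::optional<at::Layout> layout = std::nullopt;"
```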

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -665,3 +665,19 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_full(self):
+        class FullModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.full(x.shape, 42.0)
+
+        sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            FullModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
```
