
Commit ee02c32

[ET-VK][12/n] aten.cat with more codegen
Pull Request resolved: #3388

1. The `aten.cat` operation is straightforward to implement using the `copy_*_node` helpers.
2. The complexity comes from the codegen. We need to introduce an `AT_TENSOR_LIST` type, which contains a list of `AT_TENSOR` with `is_in=True`. The tensor list itself, as a container, is not an `IOValueRef`, but the elements inside it are. This leads to some ugly if-then-else branches in the codegen.

ghstack-source-id: 224249802
Differential Revision: [D56626865](https://our.internmc.facebook.com/intern/diff/D56626865/)
1 parent 49cf7f2 commit ee02c32
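
For reference, the generated test harness for a tensor-list argument is expected to take roughly the following shape. This is a minimal sketch assembled from the codegen strings added in codegen.py below; the `graph` handle and the exact variable names (`tensors`, `tensors_ref`, ...) are illustrative assumptions, not the literal generated output.

// One IOValueRef per element, so each input tensor can be staged individually ...
std::vector<IOValueRef> tensors_ref_io_value_refs;
std::vector<ValueRef> tensors_ref_value_refs;
for (int i = 0; i < tensors.size(); i++) {
  IOValueRef io_value_ref = graph->add_input_tensor(
      tensors[i].sizes().vec(),
      from_at_scalartype(tensors[i].scalar_type()));
  tensors_ref_value_refs.emplace_back(io_value_ref.value);
  tensors_ref_io_value_refs.emplace_back(io_value_ref);
}
// ... while the list container handed to the op is a plain ValueRef, not an IOValueRef.
ValueRef tensors_ref = graph->add_value_list(std::move(tensors_ref_value_refs));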

File tree

7 files changed, +301 -24 lines changed

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Copy.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_cat_default_node(
    ComputeGraph& graph,
    ValueRef in_list_ref,
    ValueRef dim_ref,
    ValueRef out) {
  ValueListPtr input_list = graph.get_value_list(in_list_ref);

  for (ValueRef input_ref : *input_list) {
    vTensorPtr t_in = graph.get_tensor(input_ref);
    VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
  }

  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
  vTensorPtr t_out = graph.get_tensor(out);

  NchwDim nchw_dim = normalize_to_nchw_dim(*t_out, dim);

  // TODO: Find ways to factor out the similar code for width, height, and batch
  if (nchw_dim == DimWidth) {
    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef input_ref : *input_list) {
      vTensorPtr t_in = graph.get_tensor(input_ref);
      api::utils::ivec3 range = t_in->texture_limits();
      add_copy_offset_node(
          graph, input_ref, range, src_offset, dst_offset, out);
      dst_offset.data[0] += range.data[0];
    }

  } else if (nchw_dim == DimHeight) {
    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef input_ref : *input_list) {
      vTensorPtr t_in = graph.get_tensor(input_ref);
      api::utils::ivec3 range = t_in->texture_limits();
      add_copy_offset_node(
          graph, input_ref, range, src_offset, dst_offset, out);
      dst_offset.data[1] += range.data[1];
    }
  } else if (nchw_dim == DimBatch) {
    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef input_ref : *input_list) {
      vTensorPtr t_in = graph.get_tensor(input_ref);
      api::utils::ivec3 range = t_in->texture_limits();
      add_copy_offset_node(
          graph, input_ref, range, src_offset, dst_offset, out);
      dst_offset.data[2] += range.data[2];
    }
  } else if (nchw_dim == DimChannel) {
    int32_t src_offset = 0;
    int32_t dst_offset = 0;

    for (ValueRef input_ref : *input_list) {
      vTensorPtr t_in = graph.get_tensor(input_ref);
      int32_t range = dim_at<Dim4D::Channel>(t_in->sizes());
      add_copy_channel_offset_node(
          graph, input_ref, range, src_offset, dst_offset, out);
      dst_offset += range;
    }
  } else {
    VK_THROW("Unexpected value of nchw_dim=", nchw_dim);
  }
}

void cat_default(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  add_cat_default_node(graph, args[0], args[1], args[2]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.cat.default, cat_default);
}

} // namespace vkcompute
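
To make the width branch above concrete: concatenating two channels-packed tensors whose texture limits are {3, 5, 1} and {4, 5, 1} issues two copy nodes, the first writing at x-offset 0 and the second at x-offset 3, with dst_offset advancing by each input's x-extent. A standalone sketch of that accumulation, using a plain struct in place of api::utils::ivec3:

#include <cstdio>
#include <vector>

// Stand-in for api::utils::ivec3, for illustration only.
struct ivec3 {
  int data[3];
};

int main() {
  // Texture limits (x, y, z) of two inputs to a width-dim cat.
  std::vector<ivec3> ranges = {{{3, 5, 1}}, {{4, 5, 1}}};
  ivec3 dst_offset = {{0, 0, 0}};
  for (const ivec3& range : ranges) {
    // Each iteration corresponds to one add_copy_offset_node call.
    std::printf("copy x-extent %d into the output at x-offset %d\n",
                range.data[0], dst_offset.data[0]);
    dst_offset.data[0] += range.data[0];  // advance along width
  }
  return 0;
}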

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 15 additions & 2 deletions
@@ -93,10 +93,23 @@ void add_copy_channel_offset_node(
 
   VK_CHECK_COND(
       dim_at<Dim4D::Channel>(in_sizes) >= src_channel_offset + channel_range,
-      "Source channel plus range should be less than or equal to input tensor's channel size");
+      "Src channel (",
+      src_channel_offset,
+      ") and range (",
+      channel_range,
+      ") should be less than or equal to input tensor's channel size (",
+      dim_at<Dim4D::Channel>(in_sizes),
+      ")");
+
   VK_CHECK_COND(
       dim_at<Dim4D::Channel>(out_sizes) >= dst_channel_offset + channel_range,
-      "Source channel and range should be less than or equal to input tensor's channel size");
+      "Dst channel (",
+      dst_channel_offset,
+      ") and range (",
+      channel_range,
+      ") should be less than or equal to output tensor's channel size (",
+      dim_at<Dim4D::Channel>(out_sizes),
+      ")");
 
   VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative");
   VK_CHECK_COND(
backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h

Lines changed: 46 additions & 0 deletions
@@ -70,4 +70,50 @@ uint32_t dim_at(const vTensor& v_in) {
   return dim_at<N>(v_in.sizes());
 }
 
+// A canonical way to represent dimensions as an enum. Intended to use the same
+// values as Dim4D for potential future refactoring.
+
+enum NchwDim {
+  DimWidth = 1,
+  DimHeight = 2,
+  DimChannel = 3,
+  DimBatch = 4,
+};
+
+/* This function returns an NchwDim given a tensor and a user-provided dim.
+ * The reason for this normalization is that user tensor coordinates follow a
+ * "big-endian" convention when referring to an NCHW dimension: dim=0 refers
+ * to the batch dimension of a 4-d tensor, but to the height of a 2-d tensor.
+ * In the common channels-packed texture representation, however, a 2-d tensor
+ * has exactly the same layout as a 4-d tensor whose batch and channel sizes
+ * equal 1. This function returns a canonical dimension to simplify dimension
+ * reasoning in the code.
+ *
+ */
+
+inline NchwDim normalize_to_nchw_dim(const vTensor& v_in, int32_t dim) {
+  return static_cast<NchwDim>(v_in.dim() - dim);
+}
+
+inline std::ostream& operator<<(std::ostream& os, NchwDim nchw_dim) {
+  switch (nchw_dim) {
+    case DimWidth:
+      os << "DimWidth";
+      break;
+    case DimHeight:
+      os << "DimHeight";
+      break;
+    case DimChannel:
+      os << "DimChannel";
+      break;
+    case DimBatch:
+      os << "DimBatch";
+      break;
+    default:
+      os << "DimUnknown";
+      break;
+  }
+  return os;
+}
+
 } // namespace vkcompute
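
As a quick sanity check of the normalization formula (`v_in.dim() - dim`), the mapping works out as follows. A self-contained sketch that re-implements the arithmetic over plain integers rather than a vTensor:

#include <cassert>

// Mirrors the NchwDim enum added above.
enum NchwDim { DimWidth = 1, DimHeight = 2, DimChannel = 3, DimBatch = 4 };

// Illustrative re-implementation of normalize_to_nchw_dim over (ndim, dim).
inline NchwDim normalize(int ndim, int dim) {
  return static_cast<NchwDim>(ndim - dim);
}

int main() {
  assert(normalize(4, 0) == DimBatch);   // dim=0 of a 4-d tensor is batch
  assert(normalize(4, 3) == DimWidth);   // dim=3 of a 4-d tensor is width
  assert(normalize(2, 0) == DimHeight);  // dim=0 of a 2-d tensor is height
  assert(normalize(2, 1) == DimWidth);   // dim=1 of a 2-d tensor is width
  return 0;
}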

backends/vulkan/test/op_tests/cases.py

Lines changed: 47 additions & 0 deletions
@@ -428,6 +428,52 @@ def get_repeat_inputs():
     return test_suite
 
 
+def get_cat_inputs():
+    # TensorList must be specified as list of tuples
+    test_suite = VkTestSuite(
+        [
+            # Cat on Height
+            ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2),
+            ([(S1, 3, 5), (S1, 4, 5)], 1),
+            ([(3, 5), (4, 5)], 0),
+            ([(3, 5), (4, 5), (1, 5)], 0),
+            (
+                [
+                    (3, 5),
+                ],
+                0,
+            ),
+            # Cat on Width
+            ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3),
+            ([(S1, 5, 3), (S1, 5, 4)], 2),
+            ([(5, 3), (5, 4)], 1),
+            ([(5, 3), (5, 4), (5, 1)], 1),
+            (
+                [
+                    (5, 4),
+                ],
+                1,
+            ),
+            ([(5,), (6,)], 0),
+            # Cat on Batch
+            ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0),
+            ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0),
+            ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0),
+            # Cat on Channel
+            ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0),
+            ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0),
+            ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1),
+            ([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1),
+        ]
+    )
+    test_suite.layouts = [
+        "api::kChannelsPacked",
+    ]
+    test_suite.data_gen = "make_seq_tensor"
+    test_suite.dtypes = ["at::kFloat"]
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -447,4 +493,5 @@ def get_repeat_inputs():
     "aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
     "aten.clone.default": get_clone_inputs(),
     "aten.repeat.default": get_repeat_inputs(),
+    "aten.cat.default": get_cat_inputs(),
 }
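
Each case above is a (list of input shapes, dim) pair; for example, ([(3, 5), (4, 5)], 0) exercises a concat of a 3x5 and a 4x5 tensor along dim 0, yielding a 7x5 result. A minimal ATen-only sketch of the reference behavior the generated test compares against (random data is used here for brevity instead of the suite's make_seq_tensor generator):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Reference computation for the test case ([(3, 5), (4, 5)], 0).
  at::Tensor a = at::rand({3, 5});
  at::Tensor b = at::rand({4, 5});
  at::Tensor out = at::cat({a, b}, 0);
  std::cout << out.sizes() << std::endl;  // prints [7, 5]
  return 0;
}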

backends/vulkan/test/op_tests/generate_op_tests.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,7 @@
     TestSuite,
     TestSuiteGen,
 )
+from torchgen import local
 
 from torchgen.gen import parse_native_yaml, ParsedYaml
 from torchgen.model import DispatchKey, NativeFunction
@@ -45,6 +46,9 @@ def process_test_suites(
         cpp_generator.add_suite(registry_name, f, op_test_suite)
 
 
+@local.parametrize(
+    use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+)
 def generate_cpp(
     native_functions_yaml_path: str, tags_path: str, output_dir: str
 ) -> None:

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 68 additions & 14 deletions
@@ -12,6 +12,7 @@
     AT_INT_ARRAY_REF,
     AT_SCALAR,
     AT_TENSOR,
+    AT_TENSOR_LIST,
     BOOL,
     CppTestFileGen,
     DOUBLE,
@@ -28,6 +29,7 @@
     THREE_TENSOR_TUPLE,
     TWO_TENSOR_TUPLE,
 )
+
 from torchgen.api import cpp
 from torchgen.api.types import CppSignatureGroup
 
@@ -75,6 +77,8 @@ class ValueRef:
 
 ValueRefList = Union[ValueRef, List[ValueRef]]
 
+InableCppType = frozenset([AT_TENSOR, AT_TENSOR_LIST])
+
 
 class ComputeGraphGen:
     def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
@@ -114,7 +118,7 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
                 name=f"{arg.name}_ref",
                 src_cpp_name=arg.name,
                 src_cpp_type=cpp_type,
-                is_in=(cpp_type == AT_TENSOR),
+                is_in=(cpp_type in InableCppType),
                 requires_prepack=requires_prepack,
                 supports_prepack=supports_prepack,
             )
@@ -244,6 +248,25 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
             ret_str += f"{self.graph}{self.dot}add_scalar<int64_t>"
             ret_str += f"({ref.src_cpp_name}.value());\n"
             return ret_str
+        elif ref.src_cpp_type == AT_TENSOR_LIST:
+            assert ref.is_in, "AT_TENSOR_LIST must be an input"
+            # This logic is a bit convoluted. We need to create an IOValueRef for
+            # each tensor, to facilitate staging. On the other hand, we will
+            # use the .value tensor to create a ValueList, which will be passed
+            # to the corresponding ops.
+            ret_str = f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n"
+            ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n"
+            ret_str += f"for (int i=0; i < {ref.src_cpp_name}.size(); i++) {{\n"
+            ret_str += f" {cpp_type} io_value_ref = {self.graph}{self.dot}add_input_tensor(\n"
+            ret_str += f" {ref.src_cpp_name}[i].sizes().vec(),\n"
+            ret_str += (
+                f" from_at_scalartype({ref.src_cpp_name}[i].scalar_type())); \n"
+            )
+            ret_str += f" {ref.name}_value_refs.emplace_back(io_value_ref.value);\n"
+            ret_str += f" {ref.name}_io_value_refs.emplace_back(io_value_ref);\n"
+            ret_str += "}\n"
+            ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n"
+            return ret_str
 
         ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}"
         if ref.src_cpp_type == AT_TENSOR and not prepack:
@@ -288,11 +311,16 @@ def create_op_call(self) -> str:
 
         for aten_arg in self.args:
             ref = self.refs[aten_arg.name]
-            op_create_code += (
-                f"{ref.name}.value, "
-                if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out
-                else f"{ref.name}, "
-            )
+            if ref.src_cpp_type == AT_TENSOR_LIST:
+                # Special case. Underlying tensors are input tensors, but the
+                # container itself is just a normal value.
+                op_create_code += f"{ref.name}, "
+            else:
+                op_create_code += (
+                    f"{ref.name}.value, "
+                    if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out
+                    else f"{ref.name}, "
+                )
 
         op_create_code += "out_ref});\n"
         return op_create_code
@@ -311,22 +339,46 @@ def set_output(self, ref: ValueRefList) -> str:
 
     def virtual_resize(self, ref: ValueRefList) -> str:
         assert isinstance(ref, ValueRef)
-        assert ref.src_cpp_type == AT_TENSOR and ref.is_in
+        assert ref.src_cpp_type in InableCppType and ref.is_in
        if self.prepack_ref(ref):
            return ""
-        ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)"
-        ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n"
+
+        if ref.src_cpp_type == AT_TENSOR:
+            ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)"
+            ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n"
+        elif ref.src_cpp_type == AT_TENSOR_LIST:
+            ret_str = ""
+            ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n"
+            ret_str += f" {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)"
+            ret_str += f"->virtual_resize({ref.src_cpp_name}[i].sizes().vec());\n"
+            ret_str += "}\n"
+        else:
+            raise AssertionError(f"{ref.src_cpp_type} not expected")
+
         return ret_str
 
     def copy_into_staging(self, ref: ValueRefList) -> str:
         assert isinstance(ref, ValueRef)
-        assert ref.src_cpp_type == AT_TENSOR and ref.is_in
+        assert ref.src_cpp_type in InableCppType and ref.is_in
+
        if self.prepack_ref(ref):
            return ""
-        ret_str = f"{self.graph}{self.dot}copy_into_staging("
-        ret_str += f"{ref.name}.staging, "
-        ret_str += f"{ref.src_cpp_name}.const_data_ptr(), "
-        ret_str += f"{ref.src_cpp_name}.numel());\n"
+
+        if ref.src_cpp_type == AT_TENSOR:
+            ret_str = f"{self.graph}{self.dot}copy_into_staging("
+            ret_str += f"{ref.name}.staging, "
+            ret_str += f"{ref.src_cpp_name}.const_data_ptr(), "
+            ret_str += f"{ref.src_cpp_name}.numel());\n"
+        elif ref.src_cpp_type == AT_TENSOR_LIST:
+            ret_str = ""
+            ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n"
+            ret_str += f" {self.graph}{self.dot}copy_into_staging("
+            ret_str += f"{ref.name}_io_value_refs[i].staging, "
+            ret_str += f"{ref.src_cpp_name}[i].const_data_ptr(), "
+            ret_str += f"{ref.src_cpp_name}[i].numel());\n"
+            ret_str += "}\n"
+        else:
+            raise AssertionError(f"{ref.src_cpp_type} not expected")
         return ret_str
 
     def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str:
@@ -547,8 +599,10 @@ def gen_parameterization(self) -> str:
         if (!is_close && t1.numel() < 500) {
           std::cout << "reference: " << std::endl;
           print(t1, 150);
+          std::cout << std::endl;
          std::cout << "vulkan: " << std::endl;
           print(t2, 150);
+          std::cout << std::endl;
         }
         return is_close;
       }
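
Concretely, for a tensor-list input the copy_into_staging and virtual_resize branches above are expected to expand into per-element loops along these lines. This is a hedged sketch only; `graph`, `tensors`, and `tensors_ref_io_value_refs` are illustrative names following the `{ref.name}` pattern rather than the literal generated output.

for (int i = 0; i < tensors_ref_io_value_refs.size(); i++) {
  // Stage each ATen tensor into its own IOValueRef.
  graph->copy_into_staging(
      tensors_ref_io_value_refs[i].staging,
      tensors[i].const_data_ptr(),
      tensors[i].numel());
}
for (int i = 0; i < tensors_ref_io_value_refs.size(); i++) {
  // Resize each graph tensor to match the corresponding ATen input.
  graph->get_tensor(tensors_ref_io_value_refs[i].value)
      ->virtual_resize(tensors[i].sizes().vec());
}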
