Skip to content

Commit 8a6ae21

Browse files
committed
Update on "[ET-VK] Clean up shader library and introduce some new conventions"
## Context This changeset introduces some fairly mechnical improvements to the Vulkan compute graph shader library in order to introduce some new conventions. **Note that backwards compatibility with existing shader authoring methods is preserved**. ### Only List `VALUE` in the `.yaml` files Previously, to generate variants for a combination of vales, the YAML file will contain ``` PACKING: - VALUE: CHANNELS_PACKED SUFFIX: C_packed - VALUE: WIDTH_PACKED SUFFIX: W_packed - VALUE: HEIGHT_PACKED SUFFIX: H_packed ``` however, the shader code generation script will use the `VALUE` as the `SUFFIX` if no `SUFFIX` is provided. Therefore, only the below is needed: ``` PACKING: - VALUE: C_packed - VALUE: W_packed - VALUE: H_packed ``` ### Change indexing utility macros to lowercase Indexing utility macros have been changed to lowercase, and the packing identifiers have been changed due to the change in YAML files. The change to lowercase is to make calls to the macro read more like functions (and indeed they are typically used as functions) in order to help make the code more readable. ``` POS_TO_COORD_${PACKING} -> pos_to_coord_${PACKING} ``` ### Use convention of defining macros in order to reduce Python code blocks usage Previously python code blocks were used in the GLSL code itself in order to vary the shader between different settings. However, usage of Python code blocks negatively impact code readability. Therefore, this diff seeks to introduce a convention of defining macros near the top of the shader to reduce the usage of Python code blocks, i.e. ``` #define pos_to_coord pos_to_coord_${PACKING} #define get_packed_dim get_packed_dim_${PACKING} #define get_packed_stride get_packed_stride_${PACKING} ``` ### Improve GLSL type definitions Previously, the following Python code blocks were used to determine appropriate vectorized and scalar types: ``` ${VEC4_T[DTYPE}} texel = ... ${T[DTYPE]} scalar = ... ``` This changeset replaces that with: ``` #define BUF_T ${buffer_scalar_type(DTYPE)} #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { BUF_T data[]; } buffer_in; VEC4_T texel = ... SCALAR_T scalar = ... ``` The main differences are as such: * `buffer_scalar_type()` produces the same result as `T[DTYPE]` * `texel_type()` is not determined from a mapping with `DTYPE`, but is determined indirectly based on the image format that is associated with the `DTYPE`. * `texel_component_type()` is based on the result of `texel_type(DTYPE)` Essentially, the mapping is more in-line with what happens in code. The reason for this change is to enable FP16 support and is a bit complicated. Basically, we need a way to distinguish the scalar type used for buffer storage, vs the scalar type used to store a component of a vec4 type (hence `BUF_T` vs `SCALAR_T`). The reason this is required is that to support half-precision tensors, the buffer representation will use a 16-bit float type but textures will still extract to `vec4` (i.e. 4x34bit floats). Differential Revision: [D56082461](https://our.internmc.facebook.com/intern/diff/D56082461/) [ghstack-poisoned]
2 parents 2d454a6 + 858d6fa commit 8a6ae21

File tree

14 files changed

+47
-24
lines changed

14 files changed

+47
-24
lines changed

.ci/docker/ci_commit_pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0a038cf0cff2d071b7359ac0491fd2ba7798a438
1+
868e5ced5df34f1aef3703654f76e03f5126b534

backends/vulkan/runtime/api/Adapter.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,8 +401,7 @@ std::string Adapter::stringize() const {
401401
ss << " Memory Info {" << std::endl;
402402
ss << " Memory Types [" << std::endl;
403403
for (size_t i = 0; i < mem_props.memoryTypeCount; ++i) {
404-
ss << " "
405-
<< " [Heap " << mem_props.memoryTypes[i].heapIndex << "] "
404+
ss << " " << " [Heap " << mem_props.memoryTypes[i].heapIndex << "] "
406405
<< get_memory_properties_str(mem_props.memoryTypes[i].propertyFlags)
407406
<< std::endl;
408407
}

backends/vulkan/runtime/api/gen_vulkan_spv.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@
9898
}
9999

100100

101+
def define_variable(name: str) -> str:
102+
if name in locals():
103+
return f"#define {name} {locals()[name]}"
104+
elif name in globals():
105+
return f"#define {name} {globals()[name]}"
106+
else:
107+
raise RuntimeError(f"{name} is not defined")
108+
109+
101110
def get_buffer_scalar_type(dtype: str) -> str:
102111
# TODO(ssjia): use float16_t for half types
103112
if dtype == "half":
@@ -120,6 +129,11 @@ def get_texel_type(dtype: str) -> str:
120129
raise AssertionError(f"Invalid image format: {image_format}")
121130

122131

132+
def get_gvec_type(dtype: str, n: int) -> str:
133+
gvec4_type = get_texel_type(dtype)
134+
return gvec4_type[:-1] + str(n)
135+
136+
123137
def get_texel_component_type(dtype: str) -> str:
124138
vec4_type = get_texel_type(dtype)
125139
if vec4_type[:3] == "vec":
@@ -132,12 +146,14 @@ def get_texel_component_type(dtype: str) -> str:
132146

133147

134148
UTILITY_FNS: Dict[str, Any] = {
149+
"macro_define": define_variable,
135150
"get_pos": {
136151
3: lambda pos: pos,
137152
2: lambda pos: f"{pos}.xy",
138153
},
139154
"buffer_scalar_type": get_buffer_scalar_type,
140155
"texel_type": get_texel_type,
156+
"gvec_type": get_gvec_type,
141157
"texel_component_type": get_texel_component_type,
142158
}
143159

backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ bool OperatorRegistry::has_op(const std::string& name) {
1616

1717
OperatorRegistry::OpFunction& OperatorRegistry::get_op_fn(
1818
const std::string& name) {
19-
return table_.find(name)->second;
19+
const auto it = table_.find(name);
20+
VK_CHECK_COND(it != table_.end(), "Could not find operator with name ", name);
21+
return it->second;
2022
}
2123

2224
void OperatorRegistry::register_op(const std::string& name, OpFunction& fn) {

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
#version 450 core
1010

1111
#define PRECISION ${PRECISION}
12+
1213
#define op(X, Y, A) ${OPERATOR}
1314

1415
#define VEC4_T ${texel_type(DTYPE)}
16+
1517
#define to_tensor_idx to_tensor_idx_${PACKING}
1618
#define to_texture_pos to_texture_pos_${PACKING}
1719

@@ -59,13 +61,13 @@ void main() {
5961
return;
6062
}
6163

62-
ivec4 in_idx = broadcast(idx, in_sizes.data);
64+
ivec4 in_idx = broadcast_indices(idx, in_sizes.data);
6365
VEC4_T in_texel = VEC4_T(texelFetch(
6466
image_in,
6567
to_texture_pos(in_idx, in_sizes.data),
6668
0));
6769

68-
ivec4 other_idx = broadcast(idx, other_sizes.data);
70+
ivec4 other_idx = broadcast_indices(idx, other_sizes.data);
6971
VEC4_T other_texel = VEC4_T(texelFetch(
7072
image_other,
7173
to_texture_pos(other_idx, other_sizes.data),

backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
ivec4 broadcast(const ivec4 out_idx, const ivec4 in_sizes) {
9+
ivec4 broadcast_indices(const ivec4 out_idx, const ivec4 in_sizes) {
1010
ivec4 in_idx = out_idx;
1111
for (int i = 0; i < 4; ++i) {
1212
if (out_idx[i] >= in_sizes[i]) {

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ void main() {
7171
ivec2 kstart = (start - ipos) / params.dilation;
7272
// During prepacking, the weight tensor was rearranged in order to optimize
7373
// for data access linearity in this shader. Therefore we need to adjust the
74-
// canonical idxinates to the corresponding index in the rearranged weight
75-
// tensor. The x-idxinate is multipled by 4 since each group of 4 channels
76-
// is folded into the X axis. The y-idxinate is offset based on the z-
77-
// idxinate because the 2D planes were stacked atop each other vertically.
74+
// canonical coordinates to the corresponding index in the rearranged weight
75+
// tensor. The x-coordinate is multipled by 4 since each group of 4 channels
76+
// is folded into the X axis. The y-coordinate is offset based on the z-
77+
// coordinate because the 2D planes were stacked atop each other vertically.
7878
kstart.x *= 4;
7979
kstart.y += pos.z * params.kernel_size.y;
8080

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@
4343
// describe sizes. As an example, let's say we want to swap dimensions 0,1 for a
4444
// tensor of shape {4,3,2,24} to obtain {3,4,2,24}. Then, x=4, y=3 and
4545
// plane=2*24=48.
46-
#define swap_adj_dims(cur, x, y, plane) \
47-
cur + \
48-
plane*( \
49-
(1 - y) * ((cur % (x * y * plane)) / (y * plane)) + \
50-
(x - 1) * ((cur % (y * plane)) / plane))
46+
#define swap_adj_dims(cur, x, y, plane) \
47+
cur + \
48+
plane * \
49+
((1 - y) * ((cur % (x * y * plane)) / (y * plane)) + \
50+
(x - 1) * ((cur % (y * plane)) / plane))
5151

5252
// Kept for backwards compatibility
5353
// TODO(ssjia): remove once there are no shaders that use these macros

docs/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ To build the documentation locally:
5757
```bash
5858
pip3 install -r ./.ci/docker/requirements-ci.txt
5959
```
60+
1. Update submodules
6061

62+
```bash
63+
git submodule sync && git submodule update --init
64+
```
6165
1. Run:
6266

6367
```bash

examples/models/llama2/runner/runner.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,8 +472,7 @@ std::string statsToJsonString(const Runner::Stats& stats) {
472472
<< "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
473473
<< "\"first_token_ms\":" << stats.first_token_ms << ","
474474
<< "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
475-
<< ","
476-
<< "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
475+
<< "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
477476
<< stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
478477
return ss.str();
479478
}

kernels/portable/cpu/op_cumsum.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
#include <executorch/runtime/platform/assert.h>
1212
#include <cmath>
1313
#include <cstddef>
14-
//#include <cstdint>
15-
//#include <type_traits>
14+
// #include <cstdint>
15+
// #include <type_traits>
1616

1717
namespace torch {
1818
namespace executor {

runtime/core/portable_type/optional.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ class optional final {
7474
}
7575

7676
optional& operator=(optional&& rhs) noexcept(
77-
std::is_nothrow_move_assignable<T>::value&&
78-
std::is_nothrow_move_constructible<T>::value) {
77+
std::is_nothrow_move_assignable<T>::value &&
78+
std::is_nothrow_move_constructible<T>::value) {
7979
if (init_ && !rhs.init_) {
8080
clear();
8181
} else if (!init_ && rhs.init_) {

sdk/etdump/etdump_flatcc.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ ETDumpGen::ETDumpGen(Span<uint8_t> buffer) {
103103
alloc.set_buffer(
104104
(uint8_t*)buffer_with_builder,
105105
buffer_size,
106-
(size_t)((buffer_size / 4 > max_alloc_buf_size) ? max_alloc_buf_size : buffer_size / 4));
106+
(size_t)((buffer_size / 4 > max_alloc_buf_size) ? max_alloc_buf_size
107+
: buffer_size / 4));
107108
et_flatcc_custom_init(builder, &alloc);
108109
} else {
109110
builder = (struct flatcc_builder*)malloc(sizeof(struct flatcc_builder));

third-party/pytorch

Submodule pytorch updated 589 files

0 commit comments

Comments
 (0)