Commit dda3e9e

[ET-VK] Clean up shader library and introduce some new conventions
Pull Request resolved: #3024

## Context

This changeset introduces some fairly mechanical improvements to the Vulkan compute graph shader library in order to introduce some new conventions.

**Note that backwards compatibility with existing shader authoring methods is preserved**.

### Only List `VALUE` in the `.yaml` files

Previously, to generate variants for a combination of values, the YAML file would contain

```
PACKING:
  - VALUE: CHANNELS_PACKED
    SUFFIX: C_packed
  - VALUE: WIDTH_PACKED
    SUFFIX: W_packed
  - VALUE: HEIGHT_PACKED
    SUFFIX: H_packed
```

However, the shader code generation script uses the `VALUE` as the `SUFFIX` if no `SUFFIX` is provided. Therefore, only the below is needed:

```
PACKING:
  - VALUE: C_packed
  - VALUE: W_packed
  - VALUE: H_packed
```

### Change indexing utility macros to lowercase

Indexing utility macros have been changed to lowercase, and the packing identifiers have been changed due to the change in YAML files.

The change to lowercase makes calls to the macros read more like function calls (and indeed they are typically used as functions), which helps make the code more readable.

```
POS_TO_COORD_${PACKING} -> pos_to_coord_${PACKING}
```

### Use convention of defining macros in order to reduce Python code blocks usage

Previously, Python code blocks were used in the GLSL code itself in order to vary the shader between different settings. However, usage of Python code blocks negatively impacts code readability. Therefore, this diff introduces a convention of defining macros near the top of the shader to reduce the usage of Python code blocks, i.e.

```
#define pos_to_coord pos_to_coord_${PACKING}
#define get_packed_dim get_packed_dim_${PACKING}
#define get_packed_stride get_packed_stride_${PACKING}
```

### Improve GLSL type definitions

Previously, the following Python code blocks were used to determine appropriate vectorized and scalar types:

```
${VEC4_T[DTYPE]} texel = ...
${T[DTYPE]} scalar = ...
```

This changeset replaces that with:

```
#define BUF_T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${texel_type(DTYPE)}
#define SCALAR_T ${texel_component_type(DTYPE)}

layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
  BUF_T data[];
}
buffer_in;

VEC4_T texel = ...
SCALAR_T scalar = ...
```

The main differences are as follows:

* `buffer_scalar_type()` produces the same result as `T[DTYPE]`
* `texel_type()` is not determined from a mapping with `DTYPE`, but is determined indirectly based on the image format that is associated with the `DTYPE`.
* `texel_component_type()` is based on the result of `texel_type(DTYPE)`

Essentially, the mapping is more in line with what happens in code.

The reason for this change is to enable FP16 support, and it is a bit complicated. Basically, we need a way to distinguish the scalar type used for buffer storage from the scalar type used to store a component of a vec4 type (hence `BUF_T` vs `SCALAR_T`). This is required because, to support half-precision tensors, the buffer representation will use a 16-bit float type, but textures will still extract to `vec4` (i.e. 4x32-bit floats).

ghstack-source-id: 222551445
Differential Revision: [D56082461](https://our.internmc.facebook.com/intern/diff/D56082461/)
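The relationship between `BUF_T`, `VEC4_T`, and `SCALAR_T` described in the commit message can be illustrated with a small Python sketch. This is not the code generation script itself: the `IMAGE_FORMAT` table and the three helper names are modeled on this commit, the suffix checks are one assumed way of classifying a format string, and `describe` is a hypothetical helper that exists only for this demonstration.

```python
# Sketch only: models the dtype -> GLSL type relationships from the commit
# message. IMAGE_FORMAT mirrors the table in this commit; `describe` is a
# hypothetical helper added for illustration.
from typing import Dict

IMAGE_FORMAT: Dict[str, str] = {
    "float": "rgba32f",
    "half": "rgba16f",
    "int": "rgba32i",
    "uint": "rgba32ui",
    "int8": "rgba8i",
    "uint8": "rgba8ui",
}


def buffer_scalar_type(dtype: str) -> str:
    """Scalar type used for buffer storage (BUF_T)."""
    if dtype == "half":
        return "float"  # stand-in until a 16-bit float type is used
    if dtype.endswith("8"):
        return dtype[:-1]  # "int8" -> "int", "uint8" -> "uint"
    return dtype


def texel_type(dtype: str) -> str:
    """Vector type read from a texture (VEC4_T), derived from the image format."""
    image_format = IMAGE_FORMAT[dtype]
    if image_format.endswith("f"):
        return "vec4"
    # "ui" must be tested before "i", since "rgba32ui" also ends in "i".
    if image_format.endswith("ui"):
        return "uvec4"
    if image_format.endswith("i"):
        return "ivec4"
    raise ValueError(f"Invalid image format: {image_format}")


def texel_component_type(dtype: str) -> str:
    """Scalar type of one texel component (SCALAR_T), derived from VEC4_T."""
    return {"vec4": "float", "ivec4": "int", "uvec4": "uint"}[texel_type(dtype)]


def describe(dtype: str) -> Dict[str, str]:
    """Collect the three macro values a shader for `dtype` would see."""
    return {
        "BUF_T": buffer_scalar_type(dtype),
        "VEC4_T": texel_type(dtype),
        "SCALAR_T": texel_component_type(dtype),
    }


if __name__ == "__main__":
    for dtype in IMAGE_FORMAT:
        print(dtype, describe(dtype))
```

Once the buffer type for `half` becomes a 16-bit float, `BUF_T` and `SCALAR_T` would diverge for that dtype, which is the case the separate names are meant to cover.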
1 parent c61ef44 commit dda3e9e

40 files changed, with 285 additions and 274 deletions

backends/vulkan/runtime/api/gen_vulkan_spv.py

Lines changed: 68 additions & 52 deletions
@@ -34,22 +34,13 @@
 CPP_H_NAME = "spv.h"
 CPP_SRC_NAME = "spv.cpp"

+# Basic configuration settings for shaders
 DEFAULT_ENV: Dict[str, Any] = {
     "PRECISION": "highp",
-    "FLOAT_IMAGE_FORMAT": "rgba16f",
-    "INT_IMAGE_FORMAT": "rgba32i",
-    "UINT_IMAGE_FORMAT": "rgba32ui",
 }

-TYPES_ENV: Dict[str, Any] = {
-    "IMAGE_FORMAT": {
-        "float": "rgba32f",
-        "half": "rgba16f",
-        "int": "rgba32i",
-        "uint": "rgba32ui",
-        "int8": "rgba8i",
-        "uint8": "rgba8ui",
-    },
+# Establishes relationships between different tensor types and different GLSL types
+TYPE_MAPPINGS: Dict[str, Any] = {
     "IMAGE_T": {
         3: {
             "float": "image3D",
@@ -78,29 +69,74 @@
             "uint": "usampler2D",
         },
     },
-    "VEC4_T": {
-        "float": "vec4",
-        "half": "vec4",
-        "int": "ivec4",
-        "uint": "uvec4",
-        "int8": "vec4",
-        "uint8": "uvec4",
-    },
-    "T": {
-        "float": "float",
-        "half": "float",
-        "int": "int",
-        "uint": "uint",
-        "int8": "int",
-        "uint8": "uint8",
+    "IMAGE_FORMAT": {
+        "float": "rgba32f",
+        "half": "rgba16f",
+        "int": "rgba32i",
+        "uint": "rgba32ui",
+        "int8": "rgba8i",
+        "uint8": "rgba8ui",
     },
 }

-FUNCS_ENV: Dict[str, Any] = {
-    "GET_POS": {
+
+def define_variable(name: str) -> str:
+    if name in locals():
+        return f"#define {name} {locals()[name]}"
+    elif name in globals():
+        return f"#define {name} {globals()[name]}"
+    else:
+        raise RuntimeError(f"{name} is not defined")
+
+
+def get_buffer_scalar_type(dtype: str) -> str:
+    # TODO(ssjia): use float16_t for half types
+    if dtype == "half":
+        return "float"
+    # TODO(ssjia): use int8_t for int8 types
+    elif dtype[-1] == "8":
+        return dtype[:-1]
+
+    return dtype
+
+
+def get_texel_type(dtype: str) -> str:
+    image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
+    if image_format[-1] == "f":
+        return "vec4"
+    # Compare the two-character suffix; a single character never equals "ui"
+    elif image_format[-2:] == "ui":
+        return "uvec4"
+    elif image_format[-1] == "i":
+        return "ivec4"
+    raise AssertionError(f"Invalid image format: {image_format}")
+
+
+def get_gvec_type(dtype: str, n: int) -> str:
+    gvec4_type = get_texel_type(dtype)
+    return gvec4_type[:-1] + str(n)
+
+
+def get_texel_component_type(dtype: str) -> str:
+    vec4_type = get_texel_type(dtype)
+    if vec4_type[:3] == "vec":
+        return "float"
+    elif vec4_type[:4] == "ivec":
+        return "int"
+    elif vec4_type[:4] == "uvec":
+        return "uint"
+    raise AssertionError(f"Invalid vec4 type: {vec4_type}")
+
+
+UTILITY_FNS: Dict[str, Any] = {
+    "macro_define": define_variable,
+    "get_pos": {
         3: lambda pos: pos,
         2: lambda pos: f"{pos}.xy",
-    }
+    },
+    "buffer_scalar_type": get_buffer_scalar_type,
+    "texel_type": get_texel_type,
+    "gvec_type": get_gvec_type,
+    "texel_component_type": get_texel_component_type,
 }


@@ -376,26 +412,6 @@ def create_shader_params(
         for key, value in variant_params.items():
             shader_params[key] = value

-        shader_dtype = shader_params.get("DTYPE", "float")
-
-        if shader_dtype == "int":
-            shader_params["FORMAT"] = self.env["INT_IMAGE_FORMAT"]
-        elif shader_dtype == "uint":
-            shader_params["FORMAT"] = self.env["UINT_IMAGE_FORMAT"]
-        elif shader_dtype == "int32":
-            shader_params["FORMAT"] = "rgba32i"
-        elif shader_dtype == "uint32":
-            shader_params["FORMAT"] = "rgba32ui"
-        elif shader_dtype == "int8":
-            shader_params["FORMAT"] = "rgba8i"
-        elif shader_dtype == "uint8":
-            shader_params["FORMAT"] = "rgba8ui"
-        elif shader_dtype == "float32":
-            shader_params["FORMAT"] = "rgba32f"
-        # Assume float by default
-        else:
-            shader_params["FORMAT"] = self.env["FLOAT_IMAGE_FORMAT"]
-
         return shader_params

     def constructOutputMap(self) -> None:
@@ -732,9 +748,9 @@ def main(argv: List[str]) -> int:
     )
     options = parser.parse_args()

-    DEFAULT_ENV.update(TYPES_ENV)
-    DEFAULT_ENV.update(FUNCS_ENV)
     env = DEFAULT_ENV
+    env.update(TYPE_MAPPINGS)
+    env.update(UTILITY_FNS)

     for key, value in parse_arg_env(options.env).items():
         env[key] = value
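
To see why merging the type mappings and utility functions into a single environment lets a shader template write `${texel_type(DTYPE)}`, here is a minimal sketch of expression substitution. The `expand` function and its regular expression are hypothetical and far simpler than the preprocessor in `gen_vulkan_spv.py`; the sketch only assumes that each `${...}` span is evaluated as a Python expression against the environment, and `texel_type` here is a reduced stand-in for the registered helper.

```python
import re
from typing import Any, Dict


def texel_type(dtype: str) -> str:
    # Reduced stand-in for the helper registered under "texel_type".
    return {"float": "vec4", "half": "vec4", "int": "ivec4"}[dtype]


def expand(template: str, env: Dict[str, Any]) -> str:
    """Replace every ${expr} with str(eval(expr)), resolving names in env."""

    def substitute(match: "re.Match[str]") -> str:
        # eval is tolerable here only because shader templates are trusted input.
        return str(eval(match.group(1), {"__builtins__": {}}, env))

    return re.sub(r"\$\{([^}]*)\}", substitute, template)


# Plain settings and callables live side by side in one dictionary.
env: Dict[str, Any] = {"PRECISION": "highp", "DTYPE": "int", "PACKING": "C_packed"}
env.update({"texel_type": texel_type})

template = (
    "#define PRECISION ${PRECISION}\n"
    "#define VEC4_T ${texel_type(DTYPE)}\n"
    "#define to_tensor_idx to_tensor_idx_${PACKING}\n"
)

if __name__ == "__main__":
    print(expand(template, env))
```

Because the callable and the `DTYPE` value are looked up in the same namespace, the template needs no Python block of its own, which is the readability gain the commit message describes.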

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 18 additions & 13 deletions
@@ -8,12 +8,17 @@

 #version 450 core

-#include "broadcasting_utils.h"
-#include "indexing_utils.h"
-
 #define PRECISION ${PRECISION}

-#define OP(X, Y, A) ${OPERATOR}
+#define VEC4_T ${texel_type(DTYPE)}
+
+#define to_tensor_idx to_tensor_idx_${PACKING}
+#define to_texture_pos to_texture_pos_${PACKING}
+
+#define op(X, Y, A) ${OPERATOR}
+
+#include "broadcasting_utils.h"
+#include "indexing_utils.h"

 layout(std430) buffer;

@@ -50,22 +55,22 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 void main() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
-  const ivec4 coord = POS_TO_COORD_${PACKING}(pos, out_sizes.data);
+  const ivec4 idx = to_tensor_idx(pos, out_sizes.data);

-  if (any(greaterThanEqual(coord, out_sizes.data))) {
+  if (any(greaterThanEqual(idx, out_sizes.data))) {
     return;
   }

-  ivec4 in_coord = out_coord_to_in_coord(coord, in_sizes.data);
-  ${VEC4_T[DTYPE]} in_texel = ${VEC4_T[DTYPE]}(texelFetch(
+  ivec4 in_idx = broadcast_indices(idx, in_sizes.data);
+  VEC4_T in_texel = VEC4_T(texelFetch(
     image_in,
-    COORD_TO_POS_${PACKING}(in_coord, in_sizes.data),
+    to_texture_pos(in_idx, in_sizes.data),
     0));

-  ivec4 other_coord = out_coord_to_in_coord(coord, other_sizes.data);
-  ${VEC4_T[DTYPE]} other_texel = ${VEC4_T[DTYPE]}(texelFetch(
+  ivec4 other_idx = broadcast_indices(idx, other_sizes.data);
+  VEC4_T other_texel = VEC4_T(texelFetch(
     image_other,
-    COORD_TO_POS_${PACKING}(other_coord, other_sizes.data),
+    to_texture_pos(other_idx, other_sizes.data),
     0));

   // Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
@@ -76,5 +81,5 @@ void main() {
     other_texel = other_texel.xxxx;
   }

-  imageStore(image_out, pos, ${VEC4_T[DTYPE]}(OP(in_texel, other_texel, alpha.data)));
+  imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha.data)));
 }

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 4 additions & 10 deletions
@@ -9,22 +9,16 @@ binary_op:
     OPERATOR: X + A * Y
     NDIM: 3
     DTYPE: float
-    PACKING: CHANNELS_PACKED
+    PACKING: C_packed
   generate_variant_forall:
     PACKING:
-      - VALUE: CHANNELS_PACKED
-        SUFFIX: C_packed
-      - VALUE: WIDTH_PACKED
-        SUFFIX: W_packed
-      - VALUE: HEIGHT_PACKED
-        SUFFIX: H_packed
+      - VALUE: C_packed
+      - VALUE: W_packed
+      - VALUE: H_packed
     DTYPE:
       - VALUE: half
-        SUFFIX: half
       - VALUE: float
-        SUFFIX: float
       - VALUE: int
-        SUFFIX: int
   shader_variants:
     - NAME: binary_add
     - NAME: binary_sub
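
The shortened `generate_variant_forall` entries rely on `VALUE` doubling as `SUFFIX` when no `SUFFIX` is given. The following is a rough sketch of that defaulting rule, not the generator's actual code: `suffix_of` and `variant_names` are hypothetical helpers, and the `<name>_<suffix>_<suffix>` naming scheme is an assumption made for illustration.

```python
from itertools import product
from typing import Dict, List

Entry = Dict[str, str]


def suffix_of(entry: Entry) -> str:
    """Use the explicit SUFFIX when present, otherwise fall back to VALUE."""
    return entry.get("SUFFIX", entry["VALUE"])


def variant_names(base: str, forall: Dict[str, List[Entry]]) -> List[str]:
    """Build one name per combination of parameter values (assumed scheme)."""
    suffix_lists = [[suffix_of(e) for e in entries] for entries in forall.values()]
    return ["_".join([base, *combo]) for combo in product(*suffix_lists)]


# New style from this commit: VALUE only.
new_style = {
    "PACKING": [{"VALUE": "C_packed"}, {"VALUE": "W_packed"}, {"VALUE": "H_packed"}],
    "DTYPE": [{"VALUE": "half"}, {"VALUE": "float"}, {"VALUE": "int"}],
}

# Old style: separate VALUE and SUFFIX for every entry.
old_style = {
    "PACKING": [
        {"VALUE": "CHANNELS_PACKED", "SUFFIX": "C_packed"},
        {"VALUE": "WIDTH_PACKED", "SUFFIX": "W_packed"},
        {"VALUE": "HEIGHT_PACKED", "SUFFIX": "H_packed"},
    ],
    "DTYPE": [
        {"VALUE": "half", "SUFFIX": "half"},
        {"VALUE": "float", "SUFFIX": "float"},
        {"VALUE": "int", "SUFFIX": "int"},
    ],
}

if __name__ == "__main__":
    for name in variant_names("binary_add", new_style):
        print(name)
```

Under these assumptions both styles yield the same nine names per operator, which is consistent with the commit message's claim that the existing authoring method keeps working.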

backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h

Lines changed: 5 additions & 5 deletions
@@ -6,12 +6,12 @@
  * LICENSE file in the root directory of this source tree.
  */

-ivec4 out_coord_to_in_coord(const ivec4 out_coord, const ivec4 in_sizes) {
-  ivec4 in_coord = out_coord;
+ivec4 broadcast_indices(const ivec4 out_idx, const ivec4 in_sizes) {
+  ivec4 in_idx = out_idx;
   for (int i = 0; i < 4; ++i) {
-    if (out_coord[i] >= in_sizes[i]) {
-      in_coord[i] = 0;
+    if (out_idx[i] >= in_sizes[i]) {
+      in_idx[i] = 0;
     }
   }
-  return in_coord;
+  return in_idx;
 }
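
The renamed `broadcast_indices` helper is small enough to trace by hand. The Python port below is only an illustration of the same loop, with plain lists standing in for `ivec4`; the example sizes are made up.

```python
from typing import List, Sequence


def broadcast_indices(out_idx: Sequence[int], in_sizes: Sequence[int]) -> List[int]:
    """Map an output index to an input index, mirroring the GLSL loop.

    Any component that falls outside the input's extent along a dimension is
    reset to 0, so a size-1 input dimension is reused for every output index.
    """
    return [0 if idx >= size else idx for idx, size in zip(out_idx, in_sizes)]


if __name__ == "__main__":
    # The input has size 1 along the second dimension, so index 2 maps to 0.
    print(broadcast_indices([3, 2, 1, 0], [4, 1, 2, 1]))
```

An index that already lies inside the input's extent is returned unchanged, which is why the shader can call the helper unconditionally for both operands.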

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 4 additions & 2 deletions
@@ -10,6 +10,8 @@

 #define PRECISION ${PRECISION}

+#define VEC4_T ${texel_type(DTYPE)}
+
 #include "indexing_utils.h"

 layout(std430) buffer;
@@ -78,12 +80,12 @@ void main() {
   kstart.y += pos.z * params.kernel_size.y;

   // Perform the convolution by iterating over the overlay region.
-  ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
+  VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
   const int ic4 = extra_params.in_group_size / 4;
   for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) {
     for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) {
       for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) {
-        const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
+        const VEC4_T in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
         const ivec4 kxs = kx + ivec4(0, 1, 2, 3);

         // To explain the calculation below, the contents of in_texel and the

backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml

Lines changed: 0 additions & 2 deletions
@@ -11,8 +11,6 @@ conv2d:
   generate_variant_forall:
     DTYPE:
       - VALUE: half
-        SUFFIX: half
       - VALUE: float
-        SUFFIX: float
   shader_variants:
     - NAME: conv2d

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 4 additions & 2 deletions
@@ -10,6 +10,8 @@

 #define PRECISION ${PRECISION}

+#define VEC4_T ${texel_type(DTYPE)}
+
 #include "indexing_utils.h"

 layout(std430) buffer;
@@ -66,14 +68,14 @@ void main() {
   const ivec2 start = ipos;
   const ivec2 end = ipos + extra_params.overlay_region.xy;

-  ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
+  VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
   int kx = 0;
   for (int y = start.y; y < end.y; y += params.dilation.y) {
     for (int x = start.x; x < end.x; x += params.dilation.x) {
       // The weight kernel was rearranged such that every NxN filter is
       // flattened to fit in one row. Each filter was then stacked on top of
       // each other vertically.
-      const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
+      const VEC4_T in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
       sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum);
       ++kx;
     }

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.yaml

Lines changed: 0 additions & 2 deletions
@@ -11,8 +11,6 @@ conv2d_dw:
   generate_variant_forall:
     DTYPE:
       - VALUE: half
-        SUFFIX: half
       - VALUE: float
-        SUFFIX: float
   shader_variants:
     - NAME: conv2d_dw

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 3 additions & 1 deletion
@@ -10,6 +10,8 @@

 #define PRECISION ${PRECISION}

+#define VEC4_T ${texel_type(DTYPE)}
+
 #include "indexing_utils.h"

 layout(std430) buffer;
@@ -66,7 +68,7 @@ void main() {
   const ivec2 start = ipos;
   const ivec2 end = ipos + extra_params.overlay_region.xy;

-  ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
+  VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
   int kx = 0;
   for (int y = start.y, i = 0; i < ${TILE_SIZE}; y += params.dilation.y, i++) {
     for (int x = start.x, j = 0; j < ${TILE_SIZE}; x += params.dilation.x, j++) {

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 0 additions & 2 deletions
@@ -12,9 +12,7 @@ conv2d_dw_output_tile:
   generate_variant_forall:
     DTYPE:
       - VALUE: half
-        SUFFIX: half
       - VALUE: float
-        SUFFIX: float
   shader_variants:
     - NAME: conv2d_dw_output_tile_3x3
     - NAME: conv2d_dw_output_tile_5x5
