Skip to content

Commit 51c7050

Browse files
committed
[ET-VK] Clean up shader library and introduce some new conventions
## Context This changeset introduces some fairly mechanical improvements to the Vulkan compute graph shader library in order to introduce some new conventions. **Note that backwards compatibility with existing shader authoring methods is preserved**. ### Only List `VALUE` in the `.yaml` files Previously, to generate variants for a combination of values, the YAML file would contain ``` PACKING: - VALUE: CHANNELS_PACKED SUFFIX: C_packed - VALUE: WIDTH_PACKED SUFFIX: W_packed - VALUE: HEIGHT_PACKED SUFFIX: H_packed ``` However, the shader code generation script will use the `VALUE` as the `SUFFIX` if no `SUFFIX` is provided. Therefore, only the below is needed: ``` PACKING: - VALUE: C_packed - VALUE: W_packed - VALUE: H_packed ``` ### Change indexing utility macros to lowercase Indexing utility macros have been changed to lowercase, and the packing identifiers have been changed due to the change in YAML files. The change to lowercase is to make calls to the macro read more like functions (and indeed they are typically used as functions) in order to help make the code more readable. ``` POS_TO_COORD_${PACKING} -> pos_to_coord_${PACKING} ``` ### Use convention of defining macros in order to reduce Python code blocks usage Previously, Python code blocks were used in the GLSL code itself in order to vary the shader between different settings. However, usage of Python code blocks negatively impacts code readability. Therefore, this diff seeks to introduce a convention of defining macros near the top of the shader to reduce the usage of Python code blocks, i.e. ``` #define pos_to_coord pos_to_coord_${PACKING} #define get_packed_dim get_packed_dim_${PACKING} #define get_packed_stride get_packed_stride_${PACKING} ``` ### Improve GLSL type definitions Previously, the following Python code blocks were used to determine appropriate vectorized and scalar types: ``` ${VEC4_T[DTYPE]} texel = ... ${T[DTYPE]} scalar = ... 
``` This changeset replaces that with: ``` #define BUF_T ${buffer_scalar_type(DTYPE)} #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { BUF_T data[]; } buffer_in; VEC4_T texel = ... SCALAR_T scalar = ... ``` The main differences are as such: * `buffer_scalar_type()` produces the same result as `T[DTYPE]` * `texel_type()` is not determined from a mapping with `DTYPE`, but is determined indirectly based on the image format that is associated with the `DTYPE`. * `texel_component_type()` is based on the result of `texel_type(DTYPE)` Essentially, the mapping is more in line with what happens in code. The reason for this change is to enable FP16 support and is a bit complicated. Basically, we need a way to distinguish the scalar type used for buffer storage, vs the scalar type used to store a component of a vec4 type (hence `BUF_T` vs `SCALAR_T`). The reason this is required is that to support half-precision tensors, the buffer representation will use a 16-bit float type but textures will still extract to `vec4` (i.e. 4x32-bit floats). Differential Revision: [D56082461](https://our.internmc.facebook.com/intern/diff/D56082461/) ghstack-source-id: 222379977 Pull Request resolved: #3024
1 parent 74eb8b3 commit 51c7050

38 files changed

+301
-234
lines changed

backends/vulkan/runtime/api/gen_vulkan_spv.py

Lines changed: 56 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,13 @@
3434
CPP_H_NAME = "spv.h"
3535
CPP_SRC_NAME = "spv.cpp"
3636

37+
# Basic configuration settings for shaders
3738
DEFAULT_ENV: Dict[str, Any] = {
3839
"PRECISION": "highp",
39-
"FLOAT_IMAGE_FORMAT": "rgba16f",
40-
"INT_IMAGE_FORMAT": "rgba32i",
41-
"UINT_IMAGE_FORMAT": "rgba32ui",
4240
}
4341

44-
TYPES_ENV: Dict[str, Any] = {
45-
"IMAGE_FORMAT": {
46-
"float": "rgba32f",
47-
"half": "rgba16f",
48-
"int": "rgba32i",
49-
"uint": "rgba32ui",
50-
"int8": "rgba8i",
51-
"uint8": "rgba8ui",
52-
},
42+
# Establishes relationships between different tensor types and different GLSL types
43+
TYPE_MAPPINGS: Dict[str, Any] = {
5344
"IMAGE_T": {
5445
3: {
5546
"float": "image3D",
@@ -78,6 +69,37 @@
7869
"uint": "usampler2D",
7970
},
8071
},
72+
"IMAGE_FORMAT": {
73+
"float": "rgba32f",
74+
"half": "rgba16f",
75+
"int": "rgba32i",
76+
"uint": "rgba32ui",
77+
"int8": "rgba8i",
78+
"uint8": "rgba8ui",
79+
},
80+
"TEXEL_EXTRACT_TYPE": {
81+
"rgba32f": "vec4",
82+
"rgba16f": "vec4",
83+
"rgba32i": "ivec4",
84+
"rgba32ui": "uvec4",
85+
"int8": "ivec4",
86+
"uint8": "uvec4",
87+
},
88+
"TEXEL_COMPONENT_TYPE": {
89+
"vec4": "float",
90+
"ivec4": "int",
91+
"uvec4": "uint",
92+
},
93+
"BUFFER_SCALAR_TYPE": {
94+
"float": "float",
95+
"half": "float",
96+
"int": "int",
97+
"uint": "uint",
98+
"int8": "int",
99+
"uint8": "uint",
100+
},
101+
# Kept for backwards compatibility
102+
# TODO(ssjia): remove when no more shaders use these
81103
"VEC4_T": {
82104
"float": "vec4",
83105
"half": "vec4",
@@ -96,11 +118,28 @@
96118
},
97119
}
98120

99-
FUNCS_ENV: Dict[str, Any] = {
100-
"GET_POS": {
121+
122+
def get_buffer_scalar_type(dtype: str) -> str:
123+
return TYPE_MAPPINGS["BUFFER_SCALAR_TYPE"][dtype]
124+
125+
126+
def get_texel_type(dtype: str) -> str:
127+
image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
128+
return TYPE_MAPPINGS["TEXEL_EXTRACT_TYPE"][image_format]
129+
130+
131+
def get_texel_component_type(dtype: str) -> str:
132+
return TYPE_MAPPINGS["TEXEL_COMPONENT_TYPE"][get_texel_type(dtype)]
133+
134+
135+
UTILITY_FNS: Dict[str, Any] = {
136+
"get_pos": {
101137
3: lambda pos: pos,
102138
2: lambda pos: f"{pos}.xy",
103-
}
139+
},
140+
"buffer_scalar_type": get_buffer_scalar_type,
141+
"texel_type": get_texel_type,
142+
"texel_component_type": get_texel_component_type,
104143
}
105144

106145

@@ -376,26 +415,6 @@ def create_shader_params(
376415
for key, value in variant_params.items():
377416
shader_params[key] = value
378417

379-
shader_dtype = shader_params.get("DTYPE", "float")
380-
381-
if shader_dtype == "int":
382-
shader_params["FORMAT"] = self.env["INT_IMAGE_FORMAT"]
383-
elif shader_dtype == "uint":
384-
shader_params["FORMAT"] = self.env["UINT_IMAGE_FORMAT"]
385-
elif shader_dtype == "int32":
386-
shader_params["FORMAT"] = "rgba32i"
387-
elif shader_dtype == "uint32":
388-
shader_params["FORMAT"] = "rgba32ui"
389-
elif shader_dtype == "int8":
390-
shader_params["FORMAT"] = "rgba8i"
391-
elif shader_dtype == "uint8":
392-
shader_params["FORMAT"] = "rgba8ui"
393-
elif shader_dtype == "float32":
394-
shader_params["FORMAT"] = "rgba32f"
395-
# Assume float by default
396-
else:
397-
shader_params["FORMAT"] = self.env["FLOAT_IMAGE_FORMAT"]
398-
399418
return shader_params
400419

401420
def constructOutputMap(self) -> None:
@@ -732,9 +751,9 @@ def main(argv: List[str]) -> int:
732751
)
733752
options = parser.parse_args()
734753

735-
DEFAULT_ENV.update(TYPES_ENV)
736-
DEFAULT_ENV.update(FUNCS_ENV)
737754
env = DEFAULT_ENV
755+
env.update(TYPE_MAPPINGS)
756+
env.update(UTILITY_FNS)
738757

739758
for key, value in parse_arg_env(options.env).items():
740759
env[key] = value

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@
1313

1414
#define PRECISION ${PRECISION}
1515

16-
#define OP(X, Y, A) ${OPERATOR}
1716

1817
layout(std430) buffer;
1918

20-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
19+
#define OP(X, Y, A) ${OPERATOR}
20+
21+
#define VEC4_T ${texel_type(DTYPE)}
22+
#define pos_to_coord pos_to_coord_${PACKING}
23+
#define coord_to_pos coord_to_pos_${PACKING}
24+
25+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[ND][DTYPE]} image_out;
2126
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
2227
layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other;
2328

@@ -50,22 +55,22 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
5055

5156
void main() {
5257
const ivec3 pos = ivec3(gl_GlobalInvocationID);
53-
const ivec4 coord = POS_TO_COORD_${PACKING}(pos, out_sizes.data);
58+
const ivec4 coord = pos_to_coord(pos, out_sizes.data);
5459

5560
if (any(greaterThanEqual(coord, out_sizes.data))) {
5661
return;
5762
}
5863

5964
ivec4 in_coord = out_coord_to_in_coord(coord, in_sizes.data);
60-
${VEC4_T[DTYPE]} in_texel = ${VEC4_T[DTYPE]}(texelFetch(
65+
VEC4_T in_texel = VEC4_T(texelFetch(
6166
image_in,
62-
COORD_TO_POS_${PACKING}(in_coord, in_sizes.data),
67+
coord_to_pos(in_coord, in_sizes.data),
6368
0));
6469

6570
ivec4 other_coord = out_coord_to_in_coord(coord, other_sizes.data);
66-
${VEC4_T[DTYPE]} other_texel = ${VEC4_T[DTYPE]}(texelFetch(
71+
VEC4_T other_texel = VEC4_T(texelFetch(
6772
image_other,
68-
COORD_TO_POS_${PACKING}(other_coord, other_sizes.data),
73+
coord_to_pos(other_coord, other_sizes.data),
6974
0));
7075

7176
// Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
@@ -76,5 +81,5 @@ void main() {
7681
other_texel = other_texel.xxxx;
7782
}
7883

79-
imageStore(image_out, pos, ${VEC4_T[DTYPE]}(OP(in_texel, other_texel, alpha.data)));
84+
imageStore(image_out, pos, VEC4_T(OP(in_texel, other_texel, alpha.data)));
8085
}

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,18 @@
77
binary_op:
88
parameter_names_with_default_values:
99
OPERATOR: X + A * Y
10-
NDIM: 3
10+
ND: 3
1111
DTYPE: float
12-
PACKING: CHANNELS_PACKED
12+
PACKING: C_packed
1313
generate_variant_forall:
1414
PACKING:
15-
- VALUE: CHANNELS_PACKED
16-
SUFFIX: C_packed
17-
- VALUE: WIDTH_PACKED
18-
SUFFIX: W_packed
19-
- VALUE: HEIGHT_PACKED
20-
SUFFIX: H_packed
15+
- VALUE: C_packed
16+
- VALUE: W_packed
17+
- VALUE: H_packed
2118
DTYPE:
2219
- VALUE: half
23-
SUFFIX: half
2420
- VALUE: float
25-
SUFFIX: float
2621
- VALUE: int
27-
SUFFIX: int
2822
shader_variants:
2923
- NAME: binary_add
3024
- NAME: binary_sub

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
layout(std430) buffer;
1616

17-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
17+
#define VEC4_T ${texel_type(DTYPE)}
18+
19+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[ND][DTYPE]} image_out;
1820
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
1921
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
2022
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;
@@ -78,12 +80,12 @@ void main() {
7880
kstart.y += pos.z * params.kernel_size.y;
7981

8082
// Perform the convolution by iterating over the overlay region.
81-
${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
83+
VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
8284
const int ic4 = extra_params.in_group_size / 4;
8385
for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) {
8486
for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) {
8587
for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) {
86-
const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
88+
const VEC4_T in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
8789
const ivec4 kxs = kx + ivec4(0, 1, 2, 3);
8890

8991
// To explain the calculation below, the contents of in_texel and the

backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@
66

77
conv2d:
88
parameter_names_with_default_values:
9-
NDIM: 3
9+
ND: 3
1010
DTYPE: float
1111
generate_variant_forall:
1212
DTYPE:
1313
- VALUE: half
14-
SUFFIX: half
1514
- VALUE: float
16-
SUFFIX: float
1715
shader_variants:
1816
- NAME: conv2d

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
layout(std430) buffer;
1616

17-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
17+
#define VEC4_T ${texel_type(DTYPE)}
18+
19+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[ND][DTYPE]} image_out;
1820
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
1921
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
2022
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;
@@ -66,14 +68,14 @@ void main() {
6668
const ivec2 start = ipos;
6769
const ivec2 end = ipos + extra_params.overlay_region.xy;
6870

69-
${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
71+
VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
7072
int kx = 0;
7173
for (int y = start.y; y < end.y; y += params.dilation.y) {
7274
for (int x = start.x; x < end.x; x += params.dilation.x) {
7375
// The weight kernel was rearranged such that every NxN filter is
7476
// flattened to fit in one row. Each filter was then stacked on top of
7577
// each other vertically.
76-
const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
78+
const VEC4_T in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
7779
sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum);
7880
++kx;
7981
}

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@
66

77
conv2d_dw:
88
parameter_names_with_default_values:
9-
NDIM: 3
9+
ND: 3
1010
DTYPE: float
1111
generate_variant_forall:
1212
DTYPE:
1313
- VALUE: half
14-
SUFFIX: half
1514
- VALUE: float
16-
SUFFIX: float
1715
shader_variants:
1816
- NAME: conv2d_dw

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
layout(std430) buffer;
1616

17-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
17+
#define VEC4_T ${texel_type(DTYPE)}
18+
19+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[ND][DTYPE]} image_out;
1820
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
1921
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
2022
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;
@@ -66,7 +68,7 @@ void main() {
6668
const ivec2 start = ipos;
6769
const ivec2 end = ipos + extra_params.overlay_region.xy;
6870

69-
${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
71+
VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
7072
int kx = 0;
7173
for (int y = start.y, i = 0; i < ${TILE_SIZE}; y += params.dilation.y, i++) {
7274
for (int x = start.x, j = 0; j < ${TILE_SIZE}; x += params.dilation.x, j++) {

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66

77
conv2d_dw_output_tile:
88
parameter_names_with_default_values:
9-
NDIM: 3
9+
ND: 3
1010
DTYPE: float
1111
TILE_SIZE: 3
1212
generate_variant_forall:
1313
DTYPE:
1414
- VALUE: half
15-
SUFFIX: half
1615
- VALUE: float
17-
SUFFIX: float
1816
shader_variants:
1917
- NAME: conv2d_dw_output_tile_3x3
2018
- NAME: conv2d_dw_output_tile_5x5

0 commit comments

Comments
 (0)