Skip to content

[ET-VK] Clean up shader library and introduce some new conventions #3024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 68 additions & 52 deletions backends/vulkan/runtime/api/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,13 @@
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"

# Basic configuration settings for shaders
DEFAULT_ENV: Dict[str, Any] = {
"PRECISION": "highp",
"FLOAT_IMAGE_FORMAT": "rgba16f",
"INT_IMAGE_FORMAT": "rgba32i",
"UINT_IMAGE_FORMAT": "rgba32ui",
}

TYPES_ENV: Dict[str, Any] = {
"IMAGE_FORMAT": {
"float": "rgba32f",
"half": "rgba16f",
"int": "rgba32i",
"uint": "rgba32ui",
"int8": "rgba8i",
"uint8": "rgba8ui",
},
# Establishes relationships between different tensor types and different GLSL types
TYPE_MAPPINGS: Dict[str, Any] = {
"IMAGE_T": {
3: {
"float": "image3D",
Expand Down Expand Up @@ -78,29 +69,74 @@
"uint": "usampler2D",
},
},
"VEC4_T": {
"float": "vec4",
"half": "vec4",
"int": "ivec4",
"uint": "uvec4",
"int8": "vec4",
"uint8": "uvec4",
},
"T": {
"float": "float",
"half": "float",
"int": "int",
"uint": "uint",
"int8": "int",
"uint8": "uint8",
"IMAGE_FORMAT": {
"float": "rgba32f",
"half": "rgba16f",
"int": "rgba32i",
"uint": "rgba32ui",
"int8": "rgba8i",
"uint8": "rgba8ui",
},
}

FUNCS_ENV: Dict[str, Any] = {
"GET_POS": {

def define_variable(name: str) -> str:
if name in locals():
return f"#define {name} {locals()[name]}"
elif name in globals():
return f"#define {name} {globals()[name]}"
else:
raise RuntimeError(f"{name} is not defined")


def get_buffer_scalar_type(dtype: str) -> str:
# TODO(ssjia): use float16_t for half types
if dtype == "half":
return "float"
# TODO(ssjia): use int8_t for int8 types
elif dtype[-1] == "8":
return dtype[:-1]

return dtype


def get_texel_type(dtype: str) -> str:
image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
if image_format[-1] == "f":
return "vec4"
elif image_format[-2] == "ui":
return "uvec4"
elif image_format[-1] == "i":
return "ivec4"
raise AssertionError(f"Invalid image format: {image_format}")


def get_gvec_type(dtype: str, n: int) -> str:
gvec4_type = get_texel_type(dtype)
return gvec4_type[:-1] + str(n)


def get_texel_component_type(dtype: str) -> str:
vec4_type = get_texel_type(dtype)
if vec4_type[:3] == "vec":
return "float"
elif vec4_type[:4] == "ivec":
return "int"
elif vec4_type[:4] == "uvec":
return "uint"
raise AssertionError(f"Invalid vec4 type: {vec4_type}")


UTILITY_FNS: Dict[str, Any] = {
"macro_define": define_variable,
"get_pos": {
3: lambda pos: pos,
2: lambda pos: f"{pos}.xy",
}
},
"buffer_scalar_type": get_buffer_scalar_type,
"texel_type": get_texel_type,
"gvec_type": get_gvec_type,
"texel_component_type": get_texel_component_type,
}


Expand Down Expand Up @@ -376,26 +412,6 @@ def create_shader_params(
for key, value in variant_params.items():
shader_params[key] = value

shader_dtype = shader_params.get("DTYPE", "float")

if shader_dtype == "int":
shader_params["FORMAT"] = self.env["INT_IMAGE_FORMAT"]
elif shader_dtype == "uint":
shader_params["FORMAT"] = self.env["UINT_IMAGE_FORMAT"]
elif shader_dtype == "int32":
shader_params["FORMAT"] = "rgba32i"
elif shader_dtype == "uint32":
shader_params["FORMAT"] = "rgba32ui"
elif shader_dtype == "int8":
shader_params["FORMAT"] = "rgba8i"
elif shader_dtype == "uint8":
shader_params["FORMAT"] = "rgba8ui"
elif shader_dtype == "float32":
shader_params["FORMAT"] = "rgba32f"
# Assume float by default
else:
shader_params["FORMAT"] = self.env["FLOAT_IMAGE_FORMAT"]

return shader_params

def constructOutputMap(self) -> None:
Expand Down Expand Up @@ -732,9 +748,9 @@ def main(argv: List[str]) -> int:
)
options = parser.parse_args()

DEFAULT_ENV.update(TYPES_ENV)
DEFAULT_ENV.update(FUNCS_ENV)
env = DEFAULT_ENV
env.update(TYPE_MAPPINGS)
env.update(UTILITY_FNS)

for key, value in parse_arg_env(options.env).items():
env[key] = value
Expand Down
31 changes: 18 additions & 13 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@

#version 450 core

#include "broadcasting_utils.h"
#include "indexing_utils.h"

#define PRECISION ${PRECISION}

#define OP(X, Y, A) ${OPERATOR}
#define VEC4_T ${texel_type(DTYPE)}

#define to_tensor_idx to_tensor_idx_${PACKING}
#define to_texture_pos to_texture_pos_${PACKING}

#define op(X, Y, A) ${OPERATOR}

#include "broadcasting_utils.h"
#include "indexing_utils.h"

layout(std430) buffer;

Expand Down Expand Up @@ -50,22 +55,22 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec4 coord = POS_TO_COORD_${PACKING}(pos, out_sizes.data);
const ivec4 idx = to_tensor_idx(pos, out_sizes.data);

if (any(greaterThanEqual(coord, out_sizes.data))) {
if (any(greaterThanEqual(idx, out_sizes.data))) {
return;
}

ivec4 in_coord = out_coord_to_in_coord(coord, in_sizes.data);
${VEC4_T[DTYPE]} in_texel = ${VEC4_T[DTYPE]}(texelFetch(
ivec4 in_idx = broadcast_indices(idx, in_sizes.data);
VEC4_T in_texel = VEC4_T(texelFetch(
image_in,
COORD_TO_POS_${PACKING}(in_coord, in_sizes.data),
to_texture_pos(in_idx, in_sizes.data),
0));

ivec4 other_coord = out_coord_to_in_coord(coord, other_sizes.data);
${VEC4_T[DTYPE]} other_texel = ${VEC4_T[DTYPE]}(texelFetch(
ivec4 other_idx = broadcast_indices(idx, other_sizes.data);
VEC4_T other_texel = VEC4_T(texelFetch(
image_other,
COORD_TO_POS_${PACKING}(other_coord, other_sizes.data),
to_texture_pos(other_idx, other_sizes.data),
0));

// Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
Expand All @@ -76,5 +81,5 @@ void main() {
other_texel = other_texel.xxxx;
}

imageStore(image_out, pos, ${VEC4_T[DTYPE]}(OP(in_texel, other_texel, alpha.data)));
imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha.data)));
}
14 changes: 4 additions & 10 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,16 @@ binary_op:
OPERATOR: X + A * Y
NDIM: 3
DTYPE: float
PACKING: CHANNELS_PACKED
PACKING: C_packed
generate_variant_forall:
PACKING:
- VALUE: CHANNELS_PACKED
SUFFIX: C_packed
- VALUE: WIDTH_PACKED
SUFFIX: W_packed
- VALUE: HEIGHT_PACKED
SUFFIX: H_packed
- VALUE: C_packed
- VALUE: W_packed
- VALUE: H_packed
DTYPE:
- VALUE: half
SUFFIX: half
- VALUE: float
SUFFIX: float
- VALUE: int
SUFFIX: int
shader_variants:
- NAME: binary_add
- NAME: binary_sub
Expand Down
10 changes: 5 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
* LICENSE file in the root directory of this source tree.
*/

ivec4 out_coord_to_in_coord(const ivec4 out_coord, const ivec4 in_sizes) {
ivec4 in_coord = out_coord;
ivec4 broadcast_indices(const ivec4 out_idx, const ivec4 in_sizes) {
ivec4 in_idx = out_idx;
for (int i = 0; i < 4; ++i) {
if (out_coord[i] >= in_sizes[i]) {
in_coord[i] = 0;
if (out_idx[i] >= in_sizes[i]) {
in_idx[i] = 0;
}
}
return in_coord;
return in_idx;
}
6 changes: 4 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#include "indexing_utils.h"

layout(std430) buffer;
Expand Down Expand Up @@ -78,12 +80,12 @@ void main() {
kstart.y += pos.z * params.kernel_size.y;

// Perform the convolution by iterating over the overlay region.
${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
const int ic4 = extra_params.in_group_size / 4;
for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) {
for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) {
for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) {
const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
const VEC4_T in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
const ivec4 kxs = kx + ivec4(0, 1, 2, 3);

// To explain the calculation below, the contents of in_texel and the
Expand Down
2 changes: 0 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ conv2d:
generate_variant_forall:
DTYPE:
- VALUE: half
SUFFIX: half
- VALUE: float
SUFFIX: float
shader_variants:
- NAME: conv2d
6 changes: 4 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#include "indexing_utils.h"

layout(std430) buffer;
Expand Down Expand Up @@ -66,14 +68,14 @@ void main() {
const ivec2 start = ipos;
const ivec2 end = ipos + extra_params.overlay_region.xy;

${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
int kx = 0;
for (int y = start.y; y < end.y; y += params.dilation.y) {
for (int x = start.x; x < end.x; x += params.dilation.x) {
// The weight kernel was rearranged such that every NxN filter is
// flattened to fit in one row. Each filter was then stacked on top of
// each other vertically.
const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
const VEC4_T in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum);
++kx;
}
Expand Down
2 changes: 0 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ conv2d_dw:
generate_variant_forall:
DTYPE:
- VALUE: half
SUFFIX: half
- VALUE: float
SUFFIX: float
shader_variants:
- NAME: conv2d_dw
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#include "indexing_utils.h"

layout(std430) buffer;
Expand Down Expand Up @@ -66,7 +68,7 @@ void main() {
const ivec2 start = ipos;
const ivec2 end = ipos + extra_params.overlay_region.xy;

${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
int kx = 0;
for (int y = start.y, i = 0; i < ${TILE_SIZE}; y += params.dilation.y, i++) {
for (int x = start.x, j = 0; j < ${TILE_SIZE}; x += params.dilation.x, j++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ conv2d_dw_output_tile:
generate_variant_forall:
DTYPE:
- VALUE: half
SUFFIX: half
- VALUE: float
SUFFIX: float
shader_variants:
- NAME: conv2d_dw_output_tile_3x3
- NAME: conv2d_dw_output_tile_5x5
Expand Down
Loading