Skip to content

[ET-VK][ez] Improvements to GLSL codegen script #10682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 87 additions & 47 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,10 @@ def extract_filename(path: str, keep_ext: bool = True) -> Any:
return os.path.basename(path).split(".")[0]


def extract_extension(path: str) -> str:
return os.path.splitext(extract_filename(path))[1][1:]


############################
# SPIR-V Code Generation #
############################
Expand Down Expand Up @@ -561,26 +565,26 @@ def __init__(
self.glslc_flags_no_opt.remove("-Os")
self.replace_u16vecn = replace_u16vecn

self.glsl_src_files: Dict[str, str] = {}
self.src_files: Dict[str, str] = {}
self.template_yaml_files: List[str] = []

self.addSrcAndYamlFiles(self.src_dir_paths)
self.shader_template_params: Dict[Any, Any] = {}
for yaml_file in self.template_yaml_files:
self.parseTemplateYaml(yaml_file)

self.output_shader_map: Dict[str, Tuple[str, Dict[str, str]]] = {}
self.output_file_map: Dict[str, Tuple[str, Dict[str, str]]] = {}
self.constructOutputMap()

def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
for src_path in src_dir_paths:
# Collect glsl source files
glsl_files = glob.glob(
src_files_list = glob.glob(
os.path.join(src_path, "**", "*.glsl*"), recursive=True
)
for file in glsl_files:
for file in src_files_list:
if len(file) > 1:
self.glsl_src_files[extract_filename(file, keep_ext=False)] = file
self.src_files[extract_filename(file, keep_ext=False)] = file
# Collect template yaml files
yaml_files = glob.glob(
os.path.join(src_path, "**", "*.yaml"), recursive=True
Expand Down Expand Up @@ -636,6 +640,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
raise KeyError(f"{template_name} params file is defined twice")

default_params = params_dict["parameter_names_with_default_values"]
default_params["YAML_SRC_FULLPATH"] = yaml_file
params_names = set(default_params.keys()).union({"NAME"})

self.shader_template_params[template_name] = []
Expand Down Expand Up @@ -700,19 +705,19 @@ def create_shader_params(
return shader_params

def constructOutputMap(self) -> None:
for shader_name, params in self.shader_template_params.items():
for src_name, params in self.shader_template_params.items():
for variant in params:
source_glsl = self.glsl_src_files[shader_name]
src_file_fullpath = self.src_files[src_name]

self.output_shader_map[variant["NAME"]] = (
source_glsl,
self.output_file_map[variant["NAME"]] = (
src_file_fullpath,
self.create_shader_params(variant),
)

for shader_name, source_glsl in self.glsl_src_files.items():
if shader_name not in self.shader_template_params:
self.output_shader_map[shader_name] = (
source_glsl,
for src_name, src_file_fullpath in self.src_files.items():
if src_name not in self.shader_template_params:
self.output_file_map[src_name] = (
src_file_fullpath,
self.create_shader_params(),
)

Expand Down Expand Up @@ -763,56 +768,88 @@ def generateSPV( # noqa: C901
output_file_map = {}

def process_shader(shader_paths_pair):
shader_name = shader_paths_pair[0]
src_file_name = shader_paths_pair[0]

src_file_fullpath = shader_paths_pair[1][0]
codegen_params = shader_paths_pair[1][1]

source_glsl = shader_paths_pair[1][0]
shader_params = shader_paths_pair[1][1]
requires_codegen = True
if "YAML_SRC_FULLPATH" not in codegen_params:
requires_codegen = False

glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")
src_file_ext = extract_extension(src_file_fullpath)
out_file_ext = src_file_ext
compile_spv = False

if out_file_ext == "glsl":
compile_spv = True

gen_out_path = os.path.join(output_dir, f"{src_file_name}.{out_file_ext}")
spv_out_path = None
if compile_spv:
spv_out_path = os.path.join(output_dir, f"{src_file_name}.spv")

if cache_dir is not None:
cached_source_glsl = os.path.join(
cache_dir, os.path.basename(source_glsl) + ".t"
cached_src_file_fullpath = os.path.join(
cache_dir, os.path.basename(src_file_fullpath) + ".t"
)
cached_codegen_yaml = os.path.join(cache_dir, f"{src_file_name}.yaml")
cached_gen_out_path = os.path.join(
cache_dir, f"{src_file_name}.{out_file_ext}"
)
cached_glsl_out_path = os.path.join(cache_dir, f"{shader_name}.glsl")
cached_spv_out_path = os.path.join(cache_dir, f"{shader_name}.spv")
cached_spv_out_path = os.path.join(cache_dir, f"{src_file_name}.spv")
if (
not force_rebuild
and os.path.exists(cached_source_glsl)
and os.path.exists(cached_glsl_out_path)
and os.path.exists(cached_spv_out_path)
and os.path.exists(cached_src_file_fullpath)
and os.path.exists(cached_gen_out_path)
and (not requires_codegen or os.path.exists(cached_codegen_yaml))
and (not compile_spv or os.path.exists(cached_spv_out_path))
):
current_checksum = self.get_md5_checksum(source_glsl)
cached_checksum = self.get_md5_checksum(cached_source_glsl)
current_checksum = self.get_md5_checksum(src_file_fullpath)
cached_checksum = self.get_md5_checksum(cached_src_file_fullpath)
yaml_unchanged = True
if requires_codegen:
yaml_file_fullpath = codegen_params["YAML_SRC_FULLPATH"]
current_yaml_checksum = self.get_md5_checksum(
yaml_file_fullpath
)
cached_yaml_checksum = self.get_md5_checksum(
cached_codegen_yaml
)
yaml_unchanged = current_yaml_checksum == cached_yaml_checksum
# If the cached source GLSL template is the same as the current GLSL
# source file, then assume that the generated GLSL and SPIR-V will
# not have changed. In that case, just copy over the GLSL and SPIR-V
# files from the cache.
if current_checksum == cached_checksum:
shutil.copyfile(cached_spv_out_path, spv_out_path)
shutil.copyfile(cached_glsl_out_path, glsl_out_path)
return (spv_out_path, glsl_out_path)
if yaml_unchanged and current_checksum == cached_checksum:
shutil.copyfile(cached_gen_out_path, gen_out_path)
if compile_spv:
shutil.copyfile(cached_spv_out_path, spv_out_path)
return (spv_out_path, gen_out_path)

with codecs.open(source_glsl, "r", encoding="utf-8") as input_file:
with codecs.open(src_file_fullpath, "r", encoding="utf-8") as input_file:
input_text = input_file.read()
input_text = self.maybe_replace_u16vecn(input_text)
output_text = preprocess(input_text, shader_params)
output_text = preprocess(input_text, codegen_params)

with codecs.open(glsl_out_path, "w", encoding="utf-8") as output_file:
with codecs.open(gen_out_path, "w", encoding="utf-8") as output_file:
output_file.write(output_text)

if cache_dir is not None:
# Otherwise, store the generated GLSL files in the cache
shutil.copyfile(glsl_out_path, cached_glsl_out_path)

# If no GLSL compiler is specified, then only write out the generated GLSL shaders.
# This is mainly for testing purposes.
if self.glslc_path is not None:
shutil.copyfile(gen_out_path, cached_gen_out_path)
# If a YAML file was used to configure codegen, cache it as well
if requires_codegen:
yaml_file_fullpath = codegen_params["YAML_SRC_FULLPATH"]
shutil.copyfile(yaml_file_fullpath, cached_codegen_yaml)

# If no GLSL compiler is specified, or the source file is not a GLSL shader
# then only write out the generated GLSL shaders.
if compile_spv and self.glslc_path is not None:
cmd_base = [
self.glslc_path,
"-fshader-stage=compute",
glsl_out_path,
gen_out_path,
"-o",
spv_out_path,
"--target-env=vulkan1.1",
Expand All @@ -828,7 +865,7 @@ def process_shader(shader_paths_pair):
subprocess.run(cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
opt_fail = "compilation succeeded but failed to optimize"
err_msg_base = f"Failed to compile {os.getcwd()}/{glsl_out_path}: "
err_msg_base = f"Failed to compile {os.getcwd()}/{gen_out_path}: "
if opt_fail in e.stderr or opt_fail in e.stdout:
cmd_no_opt = cmd_base + self.glslc_flags_no_opt
try:
Expand All @@ -844,23 +881,23 @@ def process_shader(shader_paths_pair):
if cache_dir is not None:
shutil.copyfile(spv_out_path, cached_spv_out_path)

return (spv_out_path, glsl_out_path)
return (spv_out_path, gen_out_path)

# Parallelize shader compilation as much as possible to optimize build time.
with ThreadPool(os.cpu_count()) as pool:
for spv_out_path, glsl_out_path in pool.map(
process_shader, self.output_shader_map.items()
process_shader, self.output_file_map.items()
):
output_file_map[spv_out_path] = glsl_out_path

# Save all source GLSL files to the cache. Only do this at the very end since
# multiple variants may use the same source file.
if cache_dir is not None:
for _, source_glsl in self.glsl_src_files.items():
cached_source_glsl = os.path.join(
cache_dir, os.path.basename(source_glsl) + ".t"
for _, src_file_fullpath in self.src_files.items():
cached_src_file = os.path.join(
cache_dir, os.path.basename(src_file_fullpath) + ".t"
)
shutil.copyfile(source_glsl, cached_source_glsl)
shutil.copyfile(src_file_fullpath, cached_src_file)

return output_file_map

Expand Down Expand Up @@ -1100,6 +1137,9 @@ def genCppFiles(
shader_registry_strs = []

for spvPath, srcPath in spv_files.items():
if spvPath is None:
continue

name = getName(spvPath).replace("_spv", "")

sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
Expand Down
7 changes: 3 additions & 4 deletions backends/vulkan/test/glsl/reference_matmul.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ ${layout_declare_ubo(8, "ivec4", "mat2_strides")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#include "reference_matmul_common_buffer.glslh"

void main() {
const ivec2 out_idx = ivec2(gl_GlobalInvocationID.x, gl_GlobalInvocationID.y);
if (any(greaterThanEqual(out_idx, out_sizes.xy))) {
Expand All @@ -37,10 +39,7 @@ void main() {

float sum = 0.0;
for (int i = 0; i < mat1_sizes.x; ++i) {
sum += t_mat1[mat1_id] * t_mat2[mat2_id];

mat1_id += mat1_strides.x;
mat2_id += mat2_strides.y;
sum += perform_dot_product(out_idx.y, out_idx.x, i);
}

const int out_id = out_idx.x * out_strides.x + out_idx.y * out_strides.y;
Expand Down
25 changes: 25 additions & 0 deletions backends/vulkan/test/glsl/reference_matmul_common.glslh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef MATMUL_COMMON_${STORAGE}_H
#define MATMUL_COMMON_${STORAGE}_H

$if STORAGE == "buffer":
float perform_dot_product(
const uint out_row,
const uint out_col,
const uint k) {
const uint mat1_bufi = out_row * mat1_strides.y + k * mat1_strides.x;
const uint mat2_bufi = k * mat2_strides.y + out_col * mat2_strides.x;

return t_mat1[mat1_bufi] * t_mat2[mat2_bufi];
}
$else:
vec4 perform_dot_product(
const uint out_row,
const uint out_col,
const uint k) {
vec4 mat1_tex = texelFetch(t_mat1, ivec3(k, out_row, 0), 0);
vec4 mat2_tex = texelFetch(t_mat2, ivec3(out_col, k, 0), 0);

return dot(mat1_tex, mat2_tex);
}

#endif
9 changes: 9 additions & 0 deletions backends/vulkan/test/glsl/reference_matmul_common.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
reference_matmul_common:
parameter_names_with_default_values:
STORAGE: buffer
generate_variant_forall:
STORAGE:
- VALUE: buffer
- VALUE: texture3d
shader_variants:
- NAME: reference_matmul_common
Loading