Skip to content

Commit e7a6838

Browse files
committed
[ET-VK][ez] Improvements to GLSL codegen script
Pull Request resolved: #10605
## Changes
Fixed a bug in the caching mechanism in the `gen_vulkan_spv.py` script where changing the YAML file did not trigger a recompile.
Allowed the Python codegen to be applied to files other than GLSL compute shaders. One example application of this is generating header files (`.glslh` files), which can be shared among multiple `.glsl` files. This enables better code reuse and organization. For a simple application of this, see the `reference_matmul_common.glslh` file that was added to the test compute shader library.
ghstack-source-id: 282026609 @exported-using-ghexport
Differential Revision: [D74008572](https://our.internmc.facebook.com/intern/diff/D74008572/)
1 parent cd3b53d commit e7a6838

File tree

4 files changed

+124
-51
lines changed

4 files changed

+124
-51
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 87 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,10 @@ def extract_filename(path: str, keep_ext: bool = True) -> Any:
402402
return os.path.basename(path).split(".")[0]
403403

404404

405+
def extract_extension(path: str) -> str:
406+
return os.path.splitext(extract_filename(path))[1][1:]
407+
408+
405409
############################
406410
# SPIR-V Code Generation #
407411
############################
@@ -561,26 +565,26 @@ def __init__(
561565
self.glslc_flags_no_opt.remove("-Os")
562566
self.replace_u16vecn = replace_u16vecn
563567

564-
self.glsl_src_files: Dict[str, str] = {}
568+
self.src_files: Dict[str, str] = {}
565569
self.template_yaml_files: List[str] = []
566570

567571
self.addSrcAndYamlFiles(self.src_dir_paths)
568572
self.shader_template_params: Dict[Any, Any] = {}
569573
for yaml_file in self.template_yaml_files:
570574
self.parseTemplateYaml(yaml_file)
571575

572-
self.output_shader_map: Dict[str, Tuple[str, Dict[str, str]]] = {}
576+
self.output_file_map: Dict[str, Tuple[str, Dict[str, str]]] = {}
573577
self.constructOutputMap()
574578

575579
def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
576580
for src_path in src_dir_paths:
577581
# Collect glsl source files
578-
glsl_files = glob.glob(
582+
src_files_list = glob.glob(
579583
os.path.join(src_path, "**", "*.glsl*"), recursive=True
580584
)
581-
for file in glsl_files:
585+
for file in src_files_list:
582586
if len(file) > 1:
583-
self.glsl_src_files[extract_filename(file, keep_ext=False)] = file
587+
self.src_files[extract_filename(file, keep_ext=False)] = file
584588
# Collect template yaml files
585589
yaml_files = glob.glob(
586590
os.path.join(src_path, "**", "*.yaml"), recursive=True
@@ -636,6 +640,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
636640
raise KeyError(f"{template_name} params file is defined twice")
637641

638642
default_params = params_dict["parameter_names_with_default_values"]
643+
default_params["YAML_SRC_FULLPATH"] = yaml_file
639644
params_names = set(default_params.keys()).union({"NAME"})
640645

641646
self.shader_template_params[template_name] = []
@@ -700,19 +705,19 @@ def create_shader_params(
700705
return shader_params
701706

702707
def constructOutputMap(self) -> None:
703-
for shader_name, params in self.shader_template_params.items():
708+
for src_name, params in self.shader_template_params.items():
704709
for variant in params:
705-
source_glsl = self.glsl_src_files[shader_name]
710+
src_file_fullpath = self.src_files[src_name]
706711

707-
self.output_shader_map[variant["NAME"]] = (
708-
source_glsl,
712+
self.output_file_map[variant["NAME"]] = (
713+
src_file_fullpath,
709714
self.create_shader_params(variant),
710715
)
711716

712-
for shader_name, source_glsl in self.glsl_src_files.items():
713-
if shader_name not in self.shader_template_params:
714-
self.output_shader_map[shader_name] = (
715-
source_glsl,
717+
for src_name, src_file_fullpath in self.src_files.items():
718+
if src_name not in self.shader_template_params:
719+
self.output_file_map[src_name] = (
720+
src_file_fullpath,
716721
self.create_shader_params(),
717722
)
718723

@@ -763,56 +768,88 @@ def generateSPV( # noqa: C901
763768
output_file_map = {}
764769

765770
def process_shader(shader_paths_pair):
766-
shader_name = shader_paths_pair[0]
771+
src_file_name = shader_paths_pair[0]
772+
773+
src_file_fullpath = shader_paths_pair[1][0]
774+
codegen_params = shader_paths_pair[1][1]
767775

768-
source_glsl = shader_paths_pair[1][0]
769-
shader_params = shader_paths_pair[1][1]
776+
requires_codegen = True
777+
if "YAML_SRC_FULLPATH" not in codegen_params:
778+
requires_codegen = False
770779

771-
glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
772-
spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")
780+
src_file_ext = extract_extension(src_file_fullpath)
781+
out_file_ext = src_file_ext
782+
compile_spv = False
783+
784+
if out_file_ext == "glsl":
785+
compile_spv = True
786+
787+
gen_out_path = os.path.join(output_dir, f"{src_file_name}.{out_file_ext}")
788+
spv_out_path = None
789+
if compile_spv:
790+
spv_out_path = os.path.join(output_dir, f"{src_file_name}.spv")
773791

774792
if cache_dir is not None:
775-
cached_source_glsl = os.path.join(
776-
cache_dir, os.path.basename(source_glsl) + ".t"
793+
cached_src_file_fullpath = os.path.join(
794+
cache_dir, os.path.basename(src_file_fullpath) + ".t"
795+
)
796+
cached_codegen_yaml = os.path.join(cache_dir, f"{src_file_name}.yaml")
797+
cached_gen_out_path = os.path.join(
798+
cache_dir, f"{src_file_name}.{out_file_ext}"
777799
)
778-
cached_glsl_out_path = os.path.join(cache_dir, f"{shader_name}.glsl")
779-
cached_spv_out_path = os.path.join(cache_dir, f"{shader_name}.spv")
800+
cached_spv_out_path = os.path.join(cache_dir, f"{src_file_name}.spv")
780801
if (
781802
not force_rebuild
782-
and os.path.exists(cached_source_glsl)
783-
and os.path.exists(cached_glsl_out_path)
784-
and os.path.exists(cached_spv_out_path)
803+
and os.path.exists(cached_src_file_fullpath)
804+
and os.path.exists(cached_gen_out_path)
805+
and (not requires_codegen or os.path.exists(cached_codegen_yaml))
806+
and (not compile_spv or os.path.exists(cached_spv_out_path))
785807
):
786-
current_checksum = self.get_md5_checksum(source_glsl)
787-
cached_checksum = self.get_md5_checksum(cached_source_glsl)
808+
current_checksum = self.get_md5_checksum(src_file_fullpath)
809+
cached_checksum = self.get_md5_checksum(cached_src_file_fullpath)
810+
yaml_unchanged = True
811+
if requires_codegen:
812+
yaml_file_fullpath = codegen_params["YAML_SRC_FULLPATH"]
813+
current_yaml_checksum = self.get_md5_checksum(
814+
yaml_file_fullpath
815+
)
816+
cached_yaml_checksum = self.get_md5_checksum(
817+
cached_codegen_yaml
818+
)
819+
yaml_unchanged = current_yaml_checksum == cached_yaml_checksum
788820
# If the cached source template and (when codegen is used) the cached YAML
789821
# match the current files, then assume that the generated output and SPIR-V
790822
# will not have changed. In that case, just copy over the generated file and
791823
# the SPIR-V (if one is compiled) from the cache.
792-
if current_checksum == cached_checksum:
793-
shutil.copyfile(cached_spv_out_path, spv_out_path)
794-
shutil.copyfile(cached_glsl_out_path, glsl_out_path)
795-
return (spv_out_path, glsl_out_path)
824+
if yaml_unchanged and current_checksum == cached_checksum:
825+
shutil.copyfile(cached_gen_out_path, gen_out_path)
826+
if compile_spv:
827+
shutil.copyfile(cached_spv_out_path, spv_out_path)
828+
return (spv_out_path, gen_out_path)
796829

797-
with codecs.open(source_glsl, "r", encoding="utf-8") as input_file:
830+
with codecs.open(src_file_fullpath, "r", encoding="utf-8") as input_file:
798831
input_text = input_file.read()
799832
input_text = self.maybe_replace_u16vecn(input_text)
800-
output_text = preprocess(input_text, shader_params)
833+
output_text = preprocess(input_text, codegen_params)
801834

802-
with codecs.open(glsl_out_path, "w", encoding="utf-8") as output_file:
835+
with codecs.open(gen_out_path, "w", encoding="utf-8") as output_file:
803836
output_file.write(output_text)
804837

805838
if cache_dir is not None:
806839
# Otherwise, store the generated GLSL files in the cache
807-
shutil.copyfile(glsl_out_path, cached_glsl_out_path)
808-
809-
# If no GLSL compiler is specified, then only write out the generated GLSL shaders.
810-
# This is mainly for testing purposes.
811-
if self.glslc_path is not None:
840+
shutil.copyfile(gen_out_path, cached_gen_out_path)
841+
# If a YAML file was used to configure codegen, cache it as well
842+
if requires_codegen:
843+
yaml_file_fullpath = codegen_params["YAML_SRC_FULLPATH"]
844+
shutil.copyfile(yaml_file_fullpath, cached_codegen_yaml)
845+
846+
# If no GLSL compiler is specified, or the source file is not a GLSL shader,
847+
# then only write out the generated source file; no SPIR-V is compiled.
848+
if compile_spv and self.glslc_path is not None:
812849
cmd_base = [
813850
self.glslc_path,
814851
"-fshader-stage=compute",
815-
glsl_out_path,
852+
gen_out_path,
816853
"-o",
817854
spv_out_path,
818855
"--target-env=vulkan1.1",
@@ -828,7 +865,7 @@ def process_shader(shader_paths_pair):
828865
subprocess.run(cmd, check=True, capture_output=True, text=True)
829866
except subprocess.CalledProcessError as e:
830867
opt_fail = "compilation succeeded but failed to optimize"
831-
err_msg_base = f"Failed to compile {os.getcwd()}/{glsl_out_path}: "
868+
err_msg_base = f"Failed to compile {os.getcwd()}/{gen_out_path}: "
832869
if opt_fail in e.stderr or opt_fail in e.stdout:
833870
cmd_no_opt = cmd_base + self.glslc_flags_no_opt
834871
try:
@@ -844,23 +881,23 @@ def process_shader(shader_paths_pair):
844881
if cache_dir is not None:
845882
shutil.copyfile(spv_out_path, cached_spv_out_path)
846883

847-
return (spv_out_path, glsl_out_path)
884+
return (spv_out_path, gen_out_path)
848885

849886
# Parallelize shader compilation as much as possible to optimize build time.
850887
with ThreadPool(os.cpu_count()) as pool:
851888
for spv_out_path, glsl_out_path in pool.map(
852-
process_shader, self.output_shader_map.items()
889+
process_shader, self.output_file_map.items()
853890
):
854891
output_file_map[spv_out_path] = glsl_out_path
855892

856893
# Save all source GLSL files to the cache. Only do this at the very end since
857894
# multiple variants may use the same source file.
858895
if cache_dir is not None:
859-
for _, source_glsl in self.glsl_src_files.items():
860-
cached_source_glsl = os.path.join(
861-
cache_dir, os.path.basename(source_glsl) + ".t"
896+
for _, src_file_fullpath in self.src_files.items():
897+
cached_src_file = os.path.join(
898+
cache_dir, os.path.basename(src_file_fullpath) + ".t"
862899
)
863-
shutil.copyfile(source_glsl, cached_source_glsl)
900+
shutil.copyfile(src_file_fullpath, cached_src_file)
864901

865902
return output_file_map
866903

@@ -1100,6 +1137,9 @@ def genCppFiles(
11001137
shader_registry_strs = []
11011138

11021139
for spvPath, srcPath in spv_files.items():
1140+
if spvPath is None:
1141+
continue
1142+
11031143
name = getName(spvPath).replace("_spv", "")
11041144

11051145
sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)

backends/vulkan/test/glsl/reference_matmul.glsl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ ${layout_declare_ubo(8, "ivec4", "mat2_strides")}
2424

2525
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2626

27+
#include "reference_matmul_common_buffer.glslh"
28+
2729
void main() {
2830
const ivec2 out_idx = ivec2(gl_GlobalInvocationID.x, gl_GlobalInvocationID.y);
2931
if (any(greaterThanEqual(out_idx, out_sizes.xy))) {
@@ -37,10 +39,7 @@ void main() {
3739

3840
float sum = 0.0;
3941
for (int i = 0; i < mat1_sizes.x; ++i) {
40-
sum += t_mat1[mat1_id] * t_mat2[mat2_id];
41-
42-
mat1_id += mat1_strides.x;
43-
mat2_id += mat2_strides.y;
42+
sum += perform_dot_product(out_idx.y, out_idx.x, i);
4443
}
4544

4645
const int out_id = out_idx.x * out_strides.x + out_idx.y * out_strides.y;
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef MATMUL_COMMON_${STORAGE}_H
2+
#define MATMUL_COMMON_${STORAGE}_H
3+
4+
$if STORAGE == "buffer":
5+
float perform_dot_product(
6+
const uint out_row,
7+
const uint out_col,
8+
const uint k) {
9+
const uint mat1_bufi = out_row * mat1_strides.y + k * mat1_strides.x;
10+
const uint mat2_bufi = k * mat2_strides.y + out_col * mat2_strides.x;
11+
12+
return t_mat1[mat1_bufi] * t_mat2[mat2_bufi];
13+
}
14+
$else:
15+
vec4 perform_dot_product(
16+
const uint out_row,
17+
const uint out_col,
18+
const uint k) {
19+
vec4 mat1_tex = texelFetch(t_mat1, ivec3(k, out_row, 0), 0);
20+
vec4 mat2_tex = texelFetch(t_mat2, ivec3(out_col, k, 0), 0);
21+
22+
return dot(mat1_tex, mat2_tex);
23+
}
24+
25+
#endif
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
reference_matmul_common:
2+
parameter_names_with_default_values:
3+
STORAGE: buffer
4+
generate_variant_forall:
5+
STORAGE:
6+
- VALUE: buffer
7+
- VALUE: texture3d
8+
shader_variants:
9+
- NAME: reference_matmul_common

0 commit comments

Comments
 (0)