Skip to content

Commit ab1c8aa

Browse files
SS-JIA and facebook-github-bot
authored and committed
Fix zero size tensors (#3702)
Summary: Pull Request resolved: #3702

## Context

Dispatching a command buffer with a work group size that contains 0 is undefined behaviour. On some devices, this can cause the device to be lost.

Fix this by setting the work group size to `{1, 1, 1}` right before dispatching a command buffer if the work group size contains a 0.

Reviewed By: yipjustin

Differential Revision: D57655257

fbshipit-source-id: 6209668c960f0e0afb0de0ab8b09c285e2de56b9
1 parent 705ac96 commit ab1c8aa

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -91,16 +91,25 @@ void Context::register_shader_dispatch(
9191
const ShaderInfo& shader_descriptor,
9292
const utils::uvec3& global_workgroup_size) {
9393
// Adjust the global workgroup size based on the output tile size
94+
uint32_t global_wg_w = utils::div_up(
95+
global_workgroup_size.data[0u], shader_descriptor.out_tile_size.data[0u]);
96+
uint32_t global_wg_h = utils::div_up(
97+
global_workgroup_size.data[1u], shader_descriptor.out_tile_size.data[1u]);
98+
uint32_t global_wg_d = utils::div_up(
99+
global_workgroup_size.data[2u], shader_descriptor.out_tile_size.data[2u]);
100+
101+
// Submitting a global work group size of 0 is undefined behaviour. If this is
102+
// detected then submit a single workgroup instead.
103+
if (global_wg_w == 0u || global_wg_h == 0u || global_wg_d == 0u) {
104+
global_wg_w = 1u;
105+
global_wg_h = 1u;
106+
global_wg_d = 1u;
107+
}
108+
94109
const utils::uvec3 effective_global_wg = {
95-
utils::div_up(
96-
global_workgroup_size.data[0u],
97-
shader_descriptor.out_tile_size.data[0u]),
98-
utils::div_up(
99-
global_workgroup_size.data[1u],
100-
shader_descriptor.out_tile_size.data[1u]),
101-
utils::div_up(
102-
global_workgroup_size.data[2u],
103-
shader_descriptor.out_tile_size.data[2u]),
110+
global_wg_w,
111+
global_wg_h,
112+
global_wg_d,
104113
};
105114

106115
cmd_.bind_descriptors(descriptors.get_bind_handle());

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -531,8 +531,11 @@ def gen_graph_build_code(self) -> str:
531531

532532
return graph_build
533533

534-
def gen_graph_exec_code(self) -> str:
534+
def gen_graph_exec_code(self, loop_range: int = 1) -> str:
535535
graph_exec = ""
536+
if loop_range > 1:
537+
graph_exec += f"for (int i = 0; i < {loop_range} ; ++i) "
538+
graph_exec += "{\n"
536539
for aten_arg in self.args:
537540
ref = self.refs[aten_arg.name]
538541
if ref.is_in:
@@ -544,6 +547,8 @@ def gen_graph_exec_code(self) -> str:
544547

545548
graph_exec += self.declare_vk_out_for(self.refs["out"])
546549
graph_exec += self.copy_from_staging(self.refs["out"])
550+
graph_exec += self.check_graph_out(self.refs["out"])
551+
graph_exec += "}\n"
547552

548553
return graph_exec
549554

@@ -564,7 +569,6 @@ def gen_op_check_fn(self) -> str:
564569
op_check_fn_body += self.gen_conditional_skips()
565570
op_check_fn_body += self.gen_graph_build_code()
566571
op_check_fn_body += self.gen_graph_exec_code()
567-
op_check_fn_body += self.check_graph_out(self.refs["out"])
568572

569573
# Add two level of indent for readability
570574
op_check_fn_body = re.sub(r"^", " ", op_check_fn_body, flags=re.M)

0 commit comments

Comments
 (0)