Skip to content

Commit d23548b

Browse files
authored
[ET-VK][BE][ez] Enable automatic layout slot index incrementing
Differential Revision: D62210119 Pull Request resolved: #5091
1 parent e119d51 commit d23548b

File tree

1 file changed

+36
-9
lines changed

1 file changed

+36
-9
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
# Basic configuration settings for shaders
3939
DEFAULT_ENV: Dict[str, Any] = {
4040
"PRECISION": "highp",
41+
# B is shorthand for "binding". This is used to automatically increment the
42+
# layout binding index when declaring layout bindings. Note that a container
43+
# type is used because integers are immutable in Python.
44+
"B": [0],
4145
}
4246

4347
# Establishes relationships between different tensor types and different GLSL types
@@ -179,8 +183,14 @@ def get_access_qualifier(access_type: Optional[str]) -> str:
179183
raise AssertionError(f"Invalid access type: {access_type}")
180184

181185

186+
def get_slot_val(slot: Union[int, List[int]]) -> int:
187+
if isinstance(slot, list):
188+
return slot[0]
189+
return slot
190+
191+
182192
def layout_declare_buffer(
183-
slot: int,
193+
slot: Union[int, List[int]],
184194
access_type: str,
185195
var_name: str,
186196
dtype: str,
@@ -192,15 +202,18 @@ def layout_declare_buffer(
192202
array_type = buffer_scalar_type(dtype)
193203

194204
out_str = f"""
195-
layout(set = 0, binding = {slot}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{
205+
layout(set = 0, binding = {get_slot_val(slot)}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{
196206
{array_type} {var_name}[];
197207
}};
198208
"""
209+
210+
if isinstance(slot, list):
211+
slot[0] = slot[0] + 1
199212
return out_str
200213

201214

202215
def layout_declare_image(
203-
slot: int,
216+
slot: Union[int, List[int]],
204217
access_type: str,
205218
var_name: str,
206219
dtype: str,
@@ -209,11 +222,16 @@ def layout_declare_image(
209222
) -> str:
210223
image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
211224
image_type = TYPE_MAPPINGS["IMAGE_T"][image_ndim][dtype]
212-
return f"layout(set = 0, binding = {slot}, {image_format}) uniform {precision} restrict {get_access_qualifier(access_type)} {image_type} {var_name};"
225+
226+
ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}, {image_format}) uniform {precision} restrict {get_access_qualifier(access_type)} {image_type} {var_name};"
227+
228+
if isinstance(slot, list):
229+
slot[0] = slot[0] + 1
230+
return ret_str
213231

214232

215233
def layout_declare_sampler(
216-
slot: int,
234+
slot: Union[int, List[int]],
217235
access_type: str,
218236
var_name: str,
219237
dtype: str,
@@ -222,11 +240,16 @@ def layout_declare_sampler(
222240
image_ndim: int = 3,
223241
) -> str:
224242
sampler_type = TYPE_MAPPINGS["SAMPLER_T"][image_ndim][dtype]
225-
return f"layout(set = 0, binding = {slot}) uniform {precision} {sampler_type} {var_name};"
243+
244+
ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} {sampler_type} {var_name};"
245+
246+
if isinstance(slot, list):
247+
slot[0] = slot[0] + 1
248+
return ret_str
226249

227250

228251
def layout_declare_tensor(
229-
slot: int,
252+
slot: Union[int, List[int]],
230253
access_type: str,
231254
var_name: str,
232255
dtype: str,
@@ -262,7 +285,9 @@ def layout_declare_tensor(
262285
)
263286

264287

265-
def layout_declare_ubo(slot: int, *args, precision: str = "PRECISION") -> str:
288+
def layout_declare_ubo(
289+
slot: Union[int, List[int]], *args, precision: str = "PRECISION"
290+
) -> str:
266291
assert len(args) % 2 == 0
267292

268293
var_list = list(zip(args[::2], args[1::2]))
@@ -272,12 +297,14 @@ def layout_declare_ubo(slot: int, *args, precision: str = "PRECISION") -> str:
272297
ubo_name += var_name + "_"
273298

274299
out_str = f"""
275-
layout(set = 0, binding = {slot}) uniform {precision} restrict readonly {ubo_name}UBO {{
300+
layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} restrict readonly {ubo_name}UBO {{
276301
"""
277302
for type_name, var_name in var_list:
278303
out_str += f"{type_name} {var_name};\n"
279304
out_str += "};"
280305

306+
if isinstance(slot, list):
307+
slot[0] = slot[0] + 1
281308
return out_str
282309

283310

0 commit comments

Comments
 (0)