38
38
# Basic configuration settings for shaders
39
39
DEFAULT_ENV : Dict [str , Any ] = {
40
40
"PRECISION" : "highp" ,
41
+ # B is shorthand for "binding". This is used to automatically increment the
42
+ # layout binding index when declaring layout bindings. Note that a container
43
+ # type is used because integers are immutable in Python.
44
+ "B" : [0 ],
41
45
}
42
46
43
47
# Establishes relationships between different tensor types and different GLSL types
@@ -179,8 +183,14 @@ def get_access_qualifier(access_type: Optional[str]) -> str:
179
183
raise AssertionError (f"Invalid access type: { access_type } " )
180
184
181
185
186
+ def get_slot_val (slot : Union [int , List [int ]]) -> int :
187
+ if isinstance (slot , list ):
188
+ return slot [0 ]
189
+ return slot
190
+
191
+
182
192
def layout_declare_buffer (
183
- slot : int ,
193
+ slot : Union [ int , List [ int ]] ,
184
194
access_type : str ,
185
195
var_name : str ,
186
196
dtype : str ,
@@ -192,15 +202,18 @@ def layout_declare_buffer(
192
202
array_type = buffer_scalar_type (dtype )
193
203
194
204
out_str = f"""
195
- layout(set = 0, binding = { slot } ) buffer { precision } restrict { get_access_qualifier (access_type )} { var_name } Buffer {{
205
+ layout(set = 0, binding = { get_slot_val ( slot ) } ) buffer { precision } restrict { get_access_qualifier (access_type )} { var_name } Buffer {{
196
206
{ array_type } { var_name } [];
197
207
}};
198
208
"""
209
+
210
+ if isinstance (slot , list ):
211
+ slot [0 ] = slot [0 ] + 1
199
212
return out_str
200
213
201
214
202
215
def layout_declare_image (
203
- slot : int ,
216
+ slot : Union [ int , List [ int ]] ,
204
217
access_type : str ,
205
218
var_name : str ,
206
219
dtype : str ,
@@ -209,11 +222,16 @@ def layout_declare_image(
209
222
) -> str :
210
223
image_format = TYPE_MAPPINGS ["IMAGE_FORMAT" ][dtype ]
211
224
image_type = TYPE_MAPPINGS ["IMAGE_T" ][image_ndim ][dtype ]
212
- return f"layout(set = 0, binding = { slot } , { image_format } ) uniform { precision } restrict { get_access_qualifier (access_type )} { image_type } { var_name } ;"
225
+
226
+ ret_str = f"layout(set = 0, binding = { get_slot_val (slot )} , { image_format } ) uniform { precision } restrict { get_access_qualifier (access_type )} { image_type } { var_name } ;"
227
+
228
+ if isinstance (slot , list ):
229
+ slot [0 ] = slot [0 ] + 1
230
+ return ret_str
213
231
214
232
215
233
def layout_declare_sampler (
216
- slot : int ,
234
+ slot : Union [ int , List [ int ]] ,
217
235
access_type : str ,
218
236
var_name : str ,
219
237
dtype : str ,
@@ -222,11 +240,16 @@ def layout_declare_sampler(
222
240
image_ndim : int = 3 ,
223
241
) -> str :
224
242
sampler_type = TYPE_MAPPINGS ["SAMPLER_T" ][image_ndim ][dtype ]
225
- return f"layout(set = 0, binding = { slot } ) uniform { precision } { sampler_type } { var_name } ;"
243
+
244
+ ret_str = f"layout(set = 0, binding = { get_slot_val (slot )} ) uniform { precision } { sampler_type } { var_name } ;"
245
+
246
+ if isinstance (slot , list ):
247
+ slot [0 ] = slot [0 ] + 1
248
+ return ret_str
226
249
227
250
228
251
def layout_declare_tensor (
229
- slot : int ,
252
+ slot : Union [ int , List [ int ]] ,
230
253
access_type : str ,
231
254
var_name : str ,
232
255
dtype : str ,
@@ -262,7 +285,9 @@ def layout_declare_tensor(
262
285
)
263
286
264
287
265
- def layout_declare_ubo (slot : int , * args , precision : str = "PRECISION" ) -> str :
288
+ def layout_declare_ubo (
289
+ slot : Union [int , List [int ]], * args , precision : str = "PRECISION"
290
+ ) -> str :
266
291
assert len (args ) % 2 == 0
267
292
268
293
var_list = list (zip (args [::2 ], args [1 ::2 ]))
@@ -272,12 +297,14 @@ def layout_declare_ubo(slot: int, *args, precision: str = "PRECISION") -> str:
272
297
ubo_name += var_name + "_"
273
298
274
299
out_str = f"""
275
- layout(set = 0, binding = { slot } ) uniform { precision } restrict readonly { ubo_name } UBO {{
300
+ layout(set = 0, binding = { get_slot_val ( slot ) } ) uniform { precision } restrict readonly { ubo_name } UBO {{
276
301
"""
277
302
for type_name , var_name in var_list :
278
303
out_str += f"{ type_name } { var_name } ;\n "
279
304
out_str += "};"
280
305
306
+ if isinstance (slot , list ):
307
+ slot [0 ] = slot [0 ] + 1
281
308
return out_str
282
309
283
310
0 commit comments