3
3
from __future__ import annotations
4
4
from dataclasses import dataclass
5
5
from instructions import *
6
- from typing import Any ,Iterable ,Callable ,Optional ,Tuple
6
+ from typing import Any , Iterable , Callable , Optional , Tuple
7
7
import argparse
8
8
import fileinput
9
9
import inspect
@@ -54,8 +54,10 @@ def parseInstruction(i):
54
54
# - if 2 subsequent delimiters will mean 2 pieces. One with only the first delimiter, and the second
55
55
# with the delimiter and following instructions.
56
56
# - if the first instruction is a delimiter, the first piece will begin with this delimiter.
57
- def splitInstructions (splitType : type , instructions : Iterable [Instruction ]) -> list [list [Instruction ]]:
58
- blocks : list [list [Instruction ]] = [[]]
57
+ def splitInstructions (
58
+ splitType : type , instructions : Iterable [Instruction ]
59
+ ) -> list [list [Instruction ]]:
60
+ blocks : list [list [Instruction ]] = [[]]
59
61
for instruction in instructions :
60
62
if isinstance (instruction , splitType ) and len (blocks [- 1 ]) > 0 :
61
63
blocks .append ([])
@@ -171,6 +173,7 @@ def __add__(self, value: int):
171
173
self .function , self .basic_block , self .instruction_index + value
172
174
)
173
175
176
+
174
177
# Defines a Lane in this simulator.
175
178
class Lane :
176
179
# The registers known by this lane.
@@ -185,12 +188,12 @@ class Lane:
185
188
# The first element is the IP the function will return to.
186
189
# The second element is the callback to call to store the return value
187
190
# into the correct register.
188
- _callstack : list [Tuple [InstructionPointer , Callable [[Any ], None ] ]]
191
+ _callstack : list [Tuple [InstructionPointer , Callable [[Any ], None ]]]
189
192
190
- _previous_bb : Optional [BasicBlock ]
191
- _current_bb : Optional [BasicBlock ]
193
+ _previous_bb : Optional [BasicBlock ]
194
+ _current_bb : Optional [BasicBlock ]
192
195
193
- def __init__ (self , wave : Wave , tid : int ) -> None :
196
+ def __init__ (self , wave : Wave , tid : int ) -> None :
194
197
self ._registers = dict ()
195
198
self ._ip = None
196
199
self ._running = True
@@ -213,7 +216,7 @@ def is_first_active_lane(self) -> bool:
213
216
return self ._tid == self ._wave .get_first_active_lane_index ()
214
217
215
218
# Broadcast value into the registers of all active lanes.
216
- def broadcast_register (self , register : str , value : Any ) -> None :
219
+ def broadcast_register (self , register : str , value : Any ) -> None :
217
220
self ._wave .broadcast_register (register , value )
218
221
219
222
# Returns the IP this lane is currently at.
@@ -227,17 +230,17 @@ def running(self) -> bool:
227
230
return self ._running
228
231
229
232
# Set the register at "name" to "value" in this lane.
230
- def set_register (self , name : str , value : Any ) -> None :
233
+ def set_register (self , name : str , value : Any ) -> None :
231
234
self ._registers [name ] = value
232
235
233
236
# Get the value in register "name" in this lane.
234
237
# if allow_undef is true, fetching an unknown register won't fail.
235
- def get_register (self , name : str , allow_undef : bool = False ) -> Optional [Any ]:
238
+ def get_register (self , name : str , allow_undef : bool = False ) -> Optional [Any ]:
236
239
if allow_undef and name not in self ._registers :
237
240
return None
238
241
return self ._registers [name ]
239
242
240
- def set_ip (self , ip : InstructionPointer ) -> None :
243
+ def set_ip (self , ip : InstructionPointer ) -> None :
241
244
if ip .bb () != self ._current_bb :
242
245
self ._previous_bb = self ._current_bb
243
246
self ._current_bb = ip .bb ()
@@ -269,11 +272,11 @@ def do_return(self, value):
269
272
270
273
# Represents the SPIR-V module in the simulator.
271
274
class Module :
272
- _functions : dict [str , Function ]
273
- _prolog : list [Instruction ]
274
- _globals : list [Instruction ]
275
- _name2reg : dict [str , str ]
276
- _reg2name : dict [str , str ]
275
+ _functions : dict [str , Function ]
276
+ _prolog : list [Instruction ]
277
+ _globals : list [Instruction ]
278
+ _name2reg : dict [str , str ]
279
+ _reg2name : dict [str , str ]
277
280
278
281
def __init__ (self , instructions ) -> None :
279
282
chunks = splitInstructions (OpFunction , instructions )
@@ -388,21 +391,23 @@ class ConvergenceRequirement:
388
391
continueTarget : Optional [InstructionPointer ]
389
392
impactedLanes : set [int ]
390
393
394
+
391
395
Task = dict [InstructionPointer , list [Lane ]]
392
396
397
+
393
398
# Defines a Lane group/Wave in the simulator.
394
399
class Wave :
395
400
# The module this wave will execute.
396
- _module : Module
401
+ _module : Module
397
402
# The lanes this wave will be composed of.
398
- _lanes : list [Lane ]
403
+ _lanes : list [Lane ]
399
404
# The instructions scheduled for execution.
400
- _tasks : Task
405
+ _tasks : Task
401
406
# The actual requirements to comply with when executing instructions.
402
407
# e.g: the set of lanes required to merge before executing the merge block.
403
- _convergence_requirements : list [ConvergenceRequirement ]
408
+ _convergence_requirements : list [ConvergenceRequirement ]
404
409
# The indices of the active lanes for the current executing instruction.
405
- _active_lane_indices : set [int ]
410
+ _active_lane_indices : set [int ]
406
411
407
412
def __init__ (self , module , wave_size : int ) -> None :
408
413
assert wave_size > 0
@@ -419,7 +424,7 @@ def __init__(self, module, wave_size: int) -> None:
419
424
420
425
# Returns True if the given IP can be executed for the given list of lanes.
421
426
def _is_task_candidate (self , ip : InstructionPointer , lanes : list [Lane ]):
422
- merged_lanes : set [int ] = set ()
427
+ merged_lanes : set [int ] = set ()
423
428
for lane in self ._lanes :
424
429
if not lane .running ():
425
430
merged_lanes .add (lane .tid ())
@@ -498,7 +503,7 @@ def get_first_active_lane_index(self) -> int:
498
503
return min (self ._active_lane_indices )
499
504
500
505
# Broadcast the given value to all active lane registers'.
501
- def broadcast_register (self , register : str , value : Any ) -> None :
506
+ def broadcast_register (self , register : str , value : Any ) -> None :
502
507
for tid in self ._active_lane_indices :
503
508
self ._lanes [tid ].set_register (register , value )
504
509
@@ -512,7 +517,7 @@ def _get_function_entry_from_name(self, name: str) -> InstructionPointer:
512
517
# Run the wave on the function 'function_name' until all lanes are dead.
513
518
# If verbose is True, execution trace is printed.
514
519
# Returns the value returned by the function for each lane.
515
- def run (self , function_name : str , verbose : bool = False ) -> list [Any ]:
520
+ def run (self , function_name : str , verbose : bool = False ) -> list [Any ]:
516
521
for t in self ._lanes :
517
522
self ._module .initialize (t )
518
523
@@ -551,7 +556,7 @@ def run(self, function_name : str, verbose: bool = False) -> list[Any]:
551
556
output .append (lane .get_register ("__shader_output__" ))
552
557
return output
553
558
554
- def dump_register (self , register : str ) -> None :
559
+ def dump_register (self , register : str ) -> None :
555
560
for lane in self ._lanes :
556
561
print (
557
562
f" Lane { lane .tid ():2} | { register :3} = { lane .get_register (register )} "
@@ -575,7 +580,7 @@ def dump_register(self, register : str) -> None:
575
580
args = parser .parse_args ()
576
581
577
582
578
- def load_instructions (filename : str ):
583
+ def load_instructions (filename : str ):
579
584
if filename is None :
580
585
return []
581
586
0 commit comments