Skip to content

Commit c703225

Browse files
committed
format
Signed-off-by: Nathan Gauër <[email protected]>
1 parent 6ab5f43 commit c703225

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

llvm/utils/spirv-sim/spirv-sim.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44
from dataclasses import dataclass
55
from instructions import *
6-
from typing import Any,Iterable,Callable,Optional,Tuple
6+
from typing import Any, Iterable, Callable, Optional, Tuple
77
import argparse
88
import fileinput
99
import inspect
@@ -54,8 +54,10 @@ def parseInstruction(i):
5454
# - if 2 subsequent delimiters will mean 2 pieces. One with only the first delimiter, and the second
5555
# with the delimiter and following instructions.
5656
# - if the first instruction is a delimiter, the first piece will begin with this delimiter.
57-
def splitInstructions(splitType: type, instructions: Iterable[Instruction]) -> list[list[Instruction]]:
58-
blocks : list[list[Instruction]] = [[]]
57+
def splitInstructions(
58+
splitType: type, instructions: Iterable[Instruction]
59+
) -> list[list[Instruction]]:
60+
blocks: list[list[Instruction]] = [[]]
5961
for instruction in instructions:
6062
if isinstance(instruction, splitType) and len(blocks[-1]) > 0:
6163
blocks.append([])
@@ -171,6 +173,7 @@ def __add__(self, value: int):
171173
self.function, self.basic_block, self.instruction_index + value
172174
)
173175

176+
174177
# Defines a Lane in this simulator.
175178
class Lane:
176179
# The registers known by this lane.
@@ -185,12 +188,12 @@ class Lane:
185188
# The first element is the IP the function will return to.
186189
# The second element is the callback to call to store the return value
187190
# into the correct register.
188-
_callstack: list[Tuple[InstructionPointer, Callable[[Any], None] ]]
191+
_callstack: list[Tuple[InstructionPointer, Callable[[Any], None]]]
189192

190-
_previous_bb : Optional[BasicBlock]
191-
_current_bb : Optional[BasicBlock]
193+
_previous_bb: Optional[BasicBlock]
194+
_current_bb: Optional[BasicBlock]
192195

193-
def __init__(self, wave : Wave, tid : int) -> None:
196+
def __init__(self, wave: Wave, tid: int) -> None:
194197
self._registers = dict()
195198
self._ip = None
196199
self._running = True
@@ -213,7 +216,7 @@ def is_first_active_lane(self) -> bool:
213216
return self._tid == self._wave.get_first_active_lane_index()
214217

215218
# Broadcast value into the registers of all active lanes.
216-
def broadcast_register(self, register : str, value : Any) -> None:
219+
def broadcast_register(self, register: str, value: Any) -> None:
217220
self._wave.broadcast_register(register, value)
218221

219222
# Returns the IP this lane is currently at.
@@ -227,17 +230,17 @@ def running(self) -> bool:
227230
return self._running
228231

229232
# Set the register at "name" to "value" in this lane.
230-
def set_register(self, name : str, value : Any) -> None:
233+
def set_register(self, name: str, value: Any) -> None:
231234
self._registers[name] = value
232235

233236
# Get the value in register "name" in this lane.
234237
# if allow_undef is true, fetching an unknown register won't fail.
235-
def get_register(self, name : str, allow_undef : bool = False) -> Optional[Any]:
238+
def get_register(self, name: str, allow_undef: bool = False) -> Optional[Any]:
236239
if allow_undef and name not in self._registers:
237240
return None
238241
return self._registers[name]
239242

240-
def set_ip(self, ip : InstructionPointer) -> None:
243+
def set_ip(self, ip: InstructionPointer) -> None:
241244
if ip.bb() != self._current_bb:
242245
self._previous_bb = self._current_bb
243246
self._current_bb = ip.bb()
@@ -269,11 +272,11 @@ def do_return(self, value):
269272

270273
# Represents the SPIR-V module in the simulator.
271274
class Module:
272-
_functions : dict[str, Function]
273-
_prolog : list[Instruction]
274-
_globals : list[Instruction]
275-
_name2reg : dict[str, str]
276-
_reg2name : dict[str, str]
275+
_functions: dict[str, Function]
276+
_prolog: list[Instruction]
277+
_globals: list[Instruction]
278+
_name2reg: dict[str, str]
279+
_reg2name: dict[str, str]
277280

278281
def __init__(self, instructions) -> None:
279282
chunks = splitInstructions(OpFunction, instructions)
@@ -388,21 +391,23 @@ class ConvergenceRequirement:
388391
continueTarget: Optional[InstructionPointer]
389392
impactedLanes: set[int]
390393

394+
391395
Task = dict[InstructionPointer, list[Lane]]
392396

397+
393398
# Defines a Lane group/Wave in the simulator.
394399
class Wave:
395400
# The module this wave will execute.
396-
_module : Module
401+
_module: Module
397402
# The lanes this wave will be composed of.
398-
_lanes : list[Lane]
403+
_lanes: list[Lane]
399404
# The instructions scheduled for execution.
400-
_tasks : Task
405+
_tasks: Task
401406
# The actual requirements to comply with when executing instructions.
402407
# e.g: the set of lanes required to merge before executing the merge block.
403-
_convergence_requirements : list[ConvergenceRequirement]
408+
_convergence_requirements: list[ConvergenceRequirement]
404409
# The indices of the active lanes for the current executing instruction.
405-
_active_lane_indices : set[int]
410+
_active_lane_indices: set[int]
406411

407412
def __init__(self, module, wave_size: int) -> None:
408413
assert wave_size > 0
@@ -419,7 +424,7 @@ def __init__(self, module, wave_size: int) -> None:
419424

420425
# Returns True if the given IP can be executed for the given list of lanes.
421426
def _is_task_candidate(self, ip: InstructionPointer, lanes: list[Lane]):
422-
merged_lanes : set[int] = set()
427+
merged_lanes: set[int] = set()
423428
for lane in self._lanes:
424429
if not lane.running():
425430
merged_lanes.add(lane.tid())
@@ -498,7 +503,7 @@ def get_first_active_lane_index(self) -> int:
498503
return min(self._active_lane_indices)
499504

500505
# Broadcast the given value to all active lane registers'.
501-
def broadcast_register(self, register : str, value : Any) -> None:
506+
def broadcast_register(self, register: str, value: Any) -> None:
502507
for tid in self._active_lane_indices:
503508
self._lanes[tid].set_register(register, value)
504509

@@ -512,7 +517,7 @@ def _get_function_entry_from_name(self, name: str) -> InstructionPointer:
512517
# Run the wave on the function 'function_name' until all lanes are dead.
513518
# If verbose is True, execution trace is printed.
514519
# Returns the value returned by the function for each lane.
515-
def run(self, function_name : str, verbose: bool = False) -> list[Any]:
520+
def run(self, function_name: str, verbose: bool = False) -> list[Any]:
516521
for t in self._lanes:
517522
self._module.initialize(t)
518523

@@ -551,7 +556,7 @@ def run(self, function_name : str, verbose: bool = False) -> list[Any]:
551556
output.append(lane.get_register("__shader_output__"))
552557
return output
553558

554-
def dump_register(self, register : str) -> None:
559+
def dump_register(self, register: str) -> None:
555560
for lane in self._lanes:
556561
print(
557562
f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
@@ -575,7 +580,7 @@ def dump_register(self, register : str) -> None:
575580
args = parser.parse_args()
576581

577582

578-
def load_instructions(filename : str):
583+
def load_instructions(filename: str):
579584
if filename is None:
580585
return []
581586

0 commit comments

Comments
 (0)