Skip to content

Commit b5558cd

Browse files
authored
Refactor code generators a bit (GH-128920)
Refactor code generators a bit to avoid passing stack property around all over the place
1 parent d66c08a commit b5558cd

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

Tools/cases_generator/optimizer_generator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def write_uop(
126126
try:
127127
out.start_line()
128128
if override:
129-
code_list, storage = Storage.for_uop(stack, prototype, extract_bits=False)
129+
code_list, storage = Storage.for_uop(stack, prototype)
130130
for code in code_list:
131131
out.emit(code)
132132
if debug:
@@ -151,11 +151,11 @@ def write_uop(
151151
var.defined = False
152152
storage = emitter.emit_tokens(override, storage, None)
153153
out.start_line()
154-
storage.flush(out, cast_type="_Py_UopsSymbol *", extract_bits=False)
154+
storage.flush(out, cast_type="_Py_UopsSymbol *")
155155
else:
156156
emit_default(out, uop, stack)
157157
out.start_line()
158-
stack.flush(out, cast_type="_Py_UopsSymbol *", extract_bits=False)
158+
stack.flush(out, cast_type="_Py_UopsSymbol *")
159159
except StackError as ex:
160160
raise analysis_error(ex.args[0], prototype.body[0]) # from None
161161

@@ -198,7 +198,7 @@ def generate_abstract_interpreter(
198198
declare_variables(override, out, skip_inputs=False)
199199
else:
200200
declare_variables(uop, out, skip_inputs=True)
201-
stack = Stack()
201+
stack = Stack(False)
202202
write_uop(override, uop, out, stack, debug, skip_inputs=(override is None))
203203
out.start_line()
204204
out.emit("break;\n")

Tools/cases_generator/stack.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,14 @@ def array_or_scalar(var: StackItem | Local) -> str:
224224
return "array" if var.is_array() else "scalar"
225225

226226
class Stack:
227-
def __init__(self) -> None:
227+
def __init__(self, extract_bits: bool=True) -> None:
228228
self.top_offset = StackOffset.empty()
229229
self.base_offset = StackOffset.empty()
230230
self.variables: list[Local] = []
231231
self.defined: set[str] = set()
232+
self.extract_bits = extract_bits
232233

233-
def pop(self, var: StackItem, extract_bits: bool = True) -> tuple[str, Local]:
234+
def pop(self, var: StackItem) -> tuple[str, Local]:
234235
self.top_offset.pop(var)
235236
indirect = "&" if var.is_array() else ""
236237
if self.variables:
@@ -272,7 +273,7 @@ def pop(self, var: StackItem, extract_bits: bool = True) -> tuple[str, Local]:
272273
return "", Local.unused(var)
273274
self.defined.add(var.name)
274275
cast = f"({var.type})" if (not indirect and var.type) else ""
275-
bits = ".bits" if cast and extract_bits else ""
276+
bits = ".bits" if cast and self.extract_bits else ""
276277
assign = f"{var.name} = {cast}{indirect}stack_pointer[{self.base_offset.to_c()}]{bits};"
277278
if var.condition:
278279
if var.condition == "1":
@@ -315,7 +316,7 @@ def _adjust_stack_pointer(self, out: CWriter, number: str) -> None:
315316
out.emit("assert(WITHIN_STACK_BOUNDS());\n")
316317

317318
def flush(
318-
self, out: CWriter, cast_type: str = "uintptr_t", extract_bits: bool = True
319+
self, out: CWriter, cast_type: str = "uintptr_t"
319320
) -> None:
320321
out.start_line()
321322
var_offset = self.base_offset.copy()
@@ -324,7 +325,7 @@ def flush(
324325
var.defined and
325326
not var.in_memory
326327
):
327-
Stack._do_emit(out, var.item, var_offset, cast_type, extract_bits)
328+
Stack._do_emit(out, var.item, var_offset, cast_type, self.extract_bits)
328329
var.in_memory = True
329330
var_offset.push(var.item)
330331
number = self.top_offset.to_c()
@@ -346,7 +347,7 @@ def as_comment(self) -> str:
346347
)
347348

348349
def copy(self) -> "Stack":
349-
other = Stack()
350+
other = Stack(self.extract_bits)
350351
other.top_offset = self.top_offset.copy()
351352
other.base_offset = self.base_offset.copy()
352353
other.variables = [var.copy() for var in self.variables]
@@ -507,10 +508,10 @@ def locals_cached(self) -> bool:
507508
return True
508509
return False
509510

510-
def flush(self, out: CWriter, cast_type: str = "uintptr_t", extract_bits: bool = True) -> None:
511+
def flush(self, out: CWriter, cast_type: str = "uintptr_t") -> None:
511512
self.clear_dead_inputs()
512513
self._push_defined_outputs()
513-
self.stack.flush(out, cast_type, extract_bits)
514+
self.stack.flush(out, cast_type)
514515

515516
def save(self, out: CWriter) -> None:
516517
assert self.spilled >= 0
@@ -530,12 +531,12 @@ def reload(self, out: CWriter) -> None:
530531
out.emit("stack_pointer = _PyFrame_GetStackPointer(frame);\n")
531532

532533
@staticmethod
533-
def for_uop(stack: Stack, uop: Uop, extract_bits: bool = True) -> tuple[list[str], "Storage"]:
534+
def for_uop(stack: Stack, uop: Uop) -> tuple[list[str], "Storage"]:
534535
code_list: list[str] = []
535536
inputs: list[Local] = []
536537
peeks: list[Local] = []
537538
for input in reversed(uop.stack.inputs):
538-
code, local = stack.pop(input, extract_bits)
539+
code, local = stack.pop(input)
539540
code_list.append(code)
540541
if input.peek:
541542
peeks.append(local)

0 commit comments

Comments
 (0)