Skip to content

Commit cc77d6a

Browse files
committed
Support typed stack effects
This was more convoluted than I expected.
1 parent 74ea0a2 commit cc77d6a

File tree

4 files changed

+122
-91
lines changed

4 files changed

+122
-91
lines changed

Python/bytecodes.c

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ static PyObject *value, *value1, *value2, *left, *right, *res, *sum, *prod, *sub
8383
static PyObject *container, *start, *stop, *v, *lhs, *rhs;
8484
static PyObject *list, *tuple, *dict;
8585
static PyObject *exit_func, *lasti, *val;
86-
static PyObject *jump;
87-
// Dummy variables for stack effects
86+
static size_t jump;
87+
// Dummy variables for cache effects
8888
static _Py_CODEUNIT when_to_jump_mask, invert;
8989
// Dummy opcode names for 'op' opcodes
9090
#define _BINARY_OP_INPLACE_ADD_UNICODE_PART_1 1001
@@ -2087,7 +2087,7 @@ dummy_func(
20872087
}
20882088

20892089
// The result is an int disguised as an object pointer.
2090-
op(_COMPARE_OP_FLOAT, (unused/1, left, right, when_to_jump_mask/1 -- jump)) {
2090+
op(_COMPARE_OP_FLOAT, (unused/1, left, right, when_to_jump_mask/1 -- jump: size_t)) {
20912091
assert(cframe.use_tracing == 0);
20922092
// Combined: COMPARE_OP (float ? float) + POP_JUMP_IF_(true/false)
20932093
DEOPT_IF(!PyFloat_CheckExact(left), COMPARE_OP);
@@ -2101,10 +2101,10 @@ dummy_func(
21012101
STAT_INC(COMPARE_OP, hit);
21022102
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
21032103
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
2104-
jump = (PyObject *)(size_t)(sign_ish & when_to_jump_mask);
2104+
jump = sign_ish & when_to_jump_mask;
21052105
}
21062106
// The input is an int disguised as an object pointer!
2107-
op(_JUMP_ON_SIGN, (jump --)) {
2107+
op(_JUMP_ON_SIGN, (jump: size_t --)) {
21082108
assert(opcode == POP_JUMP_IF_FALSE || opcode == POP_JUMP_IF_TRUE);
21092109
if (jump) {
21102110
JUMPBY(oparg);
@@ -2114,7 +2114,7 @@ dummy_func(
21142114
super(COMPARE_OP_FLOAT_JUMP) = _COMPARE_OP_FLOAT + _JUMP_ON_SIGN;
21152115

21162116
// Similar to COMPARE_OP_FLOAT
2117-
op(_COMPARE_OP_INT, (unused/1, left, right, when_to_jump_mask/1 -- jump)) {
2117+
op(_COMPARE_OP_INT, (unused/1, left, right, when_to_jump_mask/1 -- jump: size_t)) {
21182118
assert(cframe.use_tracing == 0);
21192119
// Combined: COMPARE_OP (int ? int) + POP_JUMP_IF_(true/false)
21202120
DEOPT_IF(!PyLong_CheckExact(left), COMPARE_OP);
@@ -2129,12 +2129,12 @@ dummy_func(
21292129
int sign_ish = 2*(ileft > iright) + 2 - (ileft < iright);
21302130
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
21312131
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
2132-
jump = (PyObject *)(size_t)(sign_ish & when_to_jump_mask);
2132+
jump = sign_ish & when_to_jump_mask;
21332133
}
21342134
super(COMPARE_OP_INT_JUMP) = _COMPARE_OP_INT + _JUMP_ON_SIGN;
21352135

21362136
// Similar to COMPARE_OP_FLOAT, but for ==, != only
2137-
op(_COMPARE_OP_STR, (unused/1, left, right, invert/1 -- jump)) {
2137+
op(_COMPARE_OP_STR, (unused/1, left, right, invert/1 -- jump: size_t)) {
21382138
assert(cframe.use_tracing == 0);
21392139
// Combined: COMPARE_OP (str == str or str != str) + POP_JUMP_IF_(true/false)
21402140
DEOPT_IF(!PyUnicode_CheckExact(left), COMPARE_OP);
@@ -2146,7 +2146,7 @@ dummy_func(
21462146
_Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc);
21472147
assert(res == 0 || res == 1);
21482148
assert(invert == 0 || invert == 1);
2149-
jump = (PyObject *)(size_t)(res ^ invert);
2149+
jump = res ^ invert;
21502150
}
21512151
super(COMPARE_OP_STR_JUMP) = _COMPARE_OP_STR + _JUMP_ON_SIGN;
21522152

Python/generated_cases.c.h

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Tools/cases_generator/generate_cases.py

Lines changed: 74 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import typing
1414

1515
import parser
16+
from parser import StackEffect
1617

1718
DEFAULT_INPUT = os.path.relpath(
1819
os.path.join(os.path.dirname(__file__), "../../Python/bytecodes.c")
@@ -73,6 +74,34 @@ def block(self, head: str):
7374
yield
7475
self.emit("}")
7576

77+
def stack_adjust(self, diff: int):
78+
if diff > 0:
79+
self.emit(f"STACK_GROW({diff});")
80+
elif diff < 0:
81+
self.emit(f"STACK_SHRINK({-diff});")
82+
83+
def declare(self, dst: StackEffect, src: StackEffect | None):
84+
if dst.name == UNUSED:
85+
return
86+
type = f"{dst.type} " if dst.type else "PyObject *"
87+
init = ""
88+
if src:
89+
cast = self.cast(dst, src)
90+
init = f" = {cast}{src.name}"
91+
self.emit(f"{type}{dst.name}{init};")
92+
93+
def assign(self, dst: StackEffect, src: StackEffect):
94+
if src.name == UNUSED:
95+
return
96+
cast = self.cast(dst, src)
97+
if m := re.match(r"^PEEK\((\d+)\)$", dst.name):
98+
self.emit(f"POKE({m.group(1)}, {cast}{src.name});")
99+
else:
100+
self.emit(f"{dst.name} = {cast}{src.name};")
101+
102+
def cast(self, dst: StackEffect, src: StackEffect) -> str:
103+
return f"({dst.type or 'PyObject *'})" if src.type != dst.type else ""
104+
76105

77106
@dataclasses.dataclass
78107
class Instruction:
@@ -88,8 +117,8 @@ class Instruction:
88117
always_exits: bool
89118
cache_offset: int
90119
cache_effects: list[parser.CacheEffect]
91-
input_effects: list[parser.StackEffect]
92-
output_effects: list[parser.StackEffect]
120+
input_effects: list[StackEffect]
121+
output_effects: list[StackEffect]
93122

94123
# Set later
95124
family: parser.Family | None = None
@@ -106,7 +135,7 @@ def __init__(self, inst: parser.InstDef):
106135
]
107136
self.cache_offset = sum(c.size for c in self.cache_effects)
108137
self.input_effects = [
109-
effect for effect in inst.inputs if isinstance(effect, parser.StackEffect)
138+
effect for effect in inst.inputs if isinstance(effect, StackEffect)
110139
]
111140
self.output_effects = inst.outputs # For consistency/completeness
112141

@@ -122,16 +151,15 @@ def write(self, out: Formatter) -> None:
122151
)
123152

124153
# Write input stack effect variable declarations and initializations
125-
for i, seffect in enumerate(reversed(self.input_effects), 1):
126-
if seffect.name != UNUSED:
127-
out.emit(f"PyObject *{seffect.name} = PEEK({i});")
154+
for i, ieffect in enumerate(reversed(self.input_effects), 1):
155+
src = StackEffect(f"PEEK({i})", "")
156+
out.declare(ieffect, src)
128157

129158
# Write output stack effect variable declarations
130-
input_names = {seffect.name for seffect in self.input_effects}
131-
input_names.add(UNUSED)
132-
for seffect in self.output_effects:
133-
if seffect.name not in input_names:
134-
out.emit(f"PyObject *{seffect.name};")
159+
input_names = {ieffect.name for ieffect in self.input_effects}
160+
for oeffect in self.output_effects:
161+
if oeffect.name not in input_names:
162+
out.declare(oeffect, None)
135163

136164
self.write_body(out, 0)
137165

@@ -141,19 +169,17 @@ def write(self, out: Formatter) -> None:
141169

142170
# Write net stack growth/shrinkage
143171
diff = len(self.output_effects) - len(self.input_effects)
144-
if diff > 0:
145-
out.emit(f"STACK_GROW({diff});")
146-
elif diff < 0:
147-
out.emit(f"STACK_SHRINK({-diff});")
172+
out.stack_adjust(diff)
148173

149174
# Write output stack effect assignments
150-
unmoved_names = {UNUSED}
175+
unmoved_names = set()
151176
for ieffect, oeffect in zip(self.input_effects, self.output_effects):
152177
if ieffect.name == oeffect.name:
153178
unmoved_names.add(ieffect.name)
154-
for i, seffect in enumerate(reversed(self.output_effects)):
155-
if seffect.name not in unmoved_names:
156-
out.emit(f"POKE({i+1}, {seffect.name});")
179+
for i, oeffect in enumerate(reversed(self.output_effects), 1):
180+
if oeffect.name not in unmoved_names:
181+
dst = StackEffect(f"PEEK({i})", "")
182+
out.assign(dst, oeffect)
157183

158184
# Write cache effect
159185
if self.cache_offset:
@@ -223,23 +249,26 @@ def write_body(self, out: Formatter, dedent: int, cache_adjust: int = 0) -> None
223249

224250

225251
InstructionOrCacheEffect = Instruction | parser.CacheEffect
252+
StackEffectMapping = list[tuple[StackEffect, StackEffect]]
226253

227254

228255
@dataclasses.dataclass
229256
class Component:
230257
instr: Instruction
231-
input_mapping: dict[str, parser.StackEffect]
232-
output_mapping: dict[str, parser.StackEffect]
258+
input_mapping: StackEffectMapping
259+
output_mapping: StackEffectMapping
233260

234261
def write_body(self, out: Formatter, cache_adjust: int) -> None:
235262
with out.block(""):
236-
for var, ieffect in self.input_mapping.items():
237-
out.emit(f"PyObject *{ieffect.name} = {var};")
238-
for oeffect in self.output_mapping.values():
239-
out.emit(f"PyObject *{oeffect.name};")
263+
for var, ieffect in self.input_mapping:
264+
out.declare(ieffect, var)
265+
for _, oeffect in self.output_mapping:
266+
out.declare(oeffect, None)
267+
240268
self.instr.write_body(out, dedent=-4, cache_adjust=cache_adjust)
241-
for var, oeffect in self.output_mapping.items():
242-
out.emit(f"{var} = {oeffect.name};")
269+
270+
for var, oeffect in self.output_mapping:
271+
out.assign(var, oeffect)
243272

244273

245274
# TODO: Use a common base class for {Super,Macro}Instruction
@@ -250,7 +279,7 @@ class SuperOrMacroInstruction:
250279
"""Common fields for super- and macro instructions."""
251280

252281
name: str
253-
stack: list[str]
282+
stack: list[StackEffect]
254283
initial_sp: int
255284
final_sp: int
256285

@@ -445,15 +474,13 @@ def analyze_super(self, super: parser.Super) -> SuperInstruction:
445474
case parser.CacheEffect() as ceffect:
446475
parts.append(ceffect)
447476
case Instruction() as instr:
448-
input_mapping = {}
477+
input_mapping: StackEffectMapping = []
449478
for ieffect in reversed(instr.input_effects):
450479
sp -= 1
451-
if ieffect.name != UNUSED:
452-
input_mapping[stack[sp]] = ieffect
453-
output_mapping = {}
480+
input_mapping.append((stack[sp], ieffect))
481+
output_mapping: StackEffectMapping = []
454482
for oeffect in instr.output_effects:
455-
if oeffect.name != UNUSED:
456-
output_mapping[stack[sp]] = oeffect
483+
output_mapping.append((stack[sp], oeffect))
457484
sp += 1
458485
parts.append(Component(instr, input_mapping, output_mapping))
459486
case _:
@@ -471,15 +498,13 @@ def analyze_macro(self, macro: parser.Macro) -> MacroInstruction:
471498
case parser.CacheEffect() as ceffect:
472499
parts.append(ceffect)
473500
case Instruction() as instr:
474-
input_mapping = {}
501+
input_mapping: StackEffectMapping = []
475502
for ieffect in reversed(instr.input_effects):
476503
sp -= 1
477-
if ieffect.name != UNUSED:
478-
input_mapping[stack[sp]] = ieffect
479-
output_mapping = {}
504+
input_mapping.append((stack[sp], ieffect))
505+
output_mapping: StackEffectMapping = []
480506
for oeffect in instr.output_effects:
481-
if oeffect.name != UNUSED:
482-
output_mapping[stack[sp]] = oeffect
507+
output_mapping.append((stack[sp], oeffect))
483508
sp += 1
484509
parts.append(Component(instr, input_mapping, output_mapping))
485510
case _:
@@ -514,7 +539,7 @@ def check_macro_components(
514539

515540
def stack_analysis(
516541
self, components: typing.Iterable[InstructionOrCacheEffect]
517-
) -> tuple[list[str], int]:
542+
) -> tuple[list[StackEffect], int]:
518543
"""Analyze a super-instruction or macro.
519544
520545
Print an error if there's a cache effect (which we don't support yet).
@@ -536,7 +561,8 @@ def stack_analysis(
536561
# At this point, 'current' is the net stack effect,
537562
# and 'lowest' and 'highest' are the extremes.
538563
# Note that 'lowest' may be negative.
539-
stack = [f"_tmp_{i+1}" for i in range(highest - lowest)]
564+
# TODO: Reverse the numbering.
565+
stack = [StackEffect(f"_tmp_{i+1}", "") for i in range(highest - lowest)]
540566
return stack, -lowest
541567

542568
def write_instructions(self) -> None:
@@ -616,19 +642,17 @@ def wrap_super_or_macro(self, up: SuperOrMacroInstruction):
616642
self.out.emit("")
617643
with self.out.block(f"TARGET({up.name})"):
618644
for i, var in enumerate(up.stack):
645+
src = None
619646
if i < up.initial_sp:
620-
self.out.emit(f"PyObject *{var} = PEEK({up.initial_sp - i});")
621-
else:
622-
self.out.emit(f"PyObject *{var};")
647+
src = StackEffect(f"PEEK({up.initial_sp - i})", "")
648+
self.out.declare(var, src)
623649

624650
yield
625651

626-
if up.final_sp > up.initial_sp:
627-
self.out.emit(f"STACK_GROW({up.final_sp - up.initial_sp});")
628-
elif up.final_sp < up.initial_sp:
629-
self.out.emit(f"STACK_SHRINK({up.initial_sp - up.final_sp});")
652+
self.out.stack_adjust(up.final_sp - up.initial_sp)
630653
for i, var in enumerate(reversed(up.stack[: up.final_sp]), 1):
631-
self.out.emit(f"POKE({i}, {var});")
654+
dst = StackEffect(f"PEEK({i})", "")
655+
self.out.assign(dst, var)
632656

633657
self.out.emit(f"DISPATCH();")
634658

0 commit comments

Comments
 (0)