Skip to content

Commit 481db72

Browse files
authored
Make the handling of ControlOp targets more generic (#10920)
Add a `targets` property and a `set_target` method to `ControlOp` to enable generic processing. Update CFG extraction, CFG cleanup, and (most importantly) refcount insertion to operate generically. This will make it easier to add new control ops, such as switch statements.
1 parent d991d19 commit 481db72

File tree

3 files changed

+59
-53
lines changed

3 files changed

+59
-53
lines changed

mypyc/analysis/dataflow.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,8 @@ def get_cfg(blocks: List[BasicBlock]) -> CFG:
5353
assert not any(isinstance(op, ControlOp) for op in block.ops[:-1]), (
5454
"Control-flow ops must be at the end of blocks")
5555

56-
last = block.ops[-1]
57-
if isinstance(last, Branch):
58-
succ = [last.true, last.false]
59-
elif isinstance(last, Goto):
60-
succ = [last.label]
61-
else:
62-
succ = []
56+
succ = list(block.terminator.targets())
57+
if not succ:
6358
exits.add(block)
6459

6560
# Errors can occur anywhere inside a block, which means that
@@ -104,12 +99,8 @@ def cleanup_cfg(blocks: List[BasicBlock]) -> None:
10499
while changed:
105100
# First collapse any jumps to basic block that only contain a goto
106101
for block in blocks:
107-
term = block.ops[-1]
108-
if isinstance(term, Goto):
109-
term.label = get_real_target(term.label)
110-
elif isinstance(term, Branch):
111-
term.true = get_real_target(term.true)
112-
term.false = get_real_target(term.false)
102+
for i, tgt in enumerate(block.terminator.targets()):
103+
block.terminator.set_target(i, get_real_target(tgt))
113104

114105
# Then delete any blocks that have no predecessors
115106
changed = False

mypyc/ir/ops.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ def terminated(self) -> bool:
7575
"""
7676
return bool(self.ops) and isinstance(self.ops[-1], ControlOp)
7777

78+
@property
79+
def terminator(self) -> 'ControlOp':
80+
"""The terminator operation of the block."""
81+
assert bool(self.ops) and isinstance(self.ops[-1], ControlOp)
82+
return self.ops[-1]
83+
7884

7985
# Never generates an exception
8086
ERR_NEVER: Final = 0
@@ -260,12 +266,15 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:
260266

261267

262268
class ControlOp(Op):
263-
"""Control flow operation.
269+
"""Control flow operation."""
264270

265-
This is Basically just for class hierarchy organization.
271+
def targets(self) -> Sequence[BasicBlock]:
272+
"""Get all basic block targets of the control operation."""
273+
return ()
266274

267-
We could plausibly have a targets() method if we wanted.
268-
"""
275+
def set_target(self, i: int, new: BasicBlock) -> None:
276+
"""Update a basic block target."""
277+
raise AssertionError("Invalid set_target({}, {})".format(self, i))
269278

270279

271280
class Goto(ControlOp):
@@ -277,6 +286,13 @@ def __init__(self, label: BasicBlock, line: int = -1) -> None:
277286
super().__init__(line)
278287
self.label = label
279288

289+
def targets(self) -> Sequence[BasicBlock]:
290+
return (self.label,)
291+
292+
def set_target(self, i: int, new: BasicBlock) -> None:
293+
assert i == 0
294+
self.label = new
295+
280296
def __repr__(self) -> str:
281297
return '<Goto %s>' % self.label.label
282298

@@ -327,6 +343,16 @@ def __init__(self,
327343
# If True, the condition is expected to be usually False (for optimization purposes)
328344
self.rare = rare
329345

346+
def targets(self) -> Sequence[BasicBlock]:
347+
return (self.true, self.false)
348+
349+
def set_target(self, i: int, new: BasicBlock) -> None:
350+
assert i == 0 or i == 1
351+
if i == 0:
352+
self.true = new
353+
elif i == 1:
354+
self.false = new
355+
330356
def sources(self) -> List[Value]:
331357
return [self.value]
332358

mypyc/transform/refcount.py

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@
3333
from mypyc.ir.func_ir import FuncIR, all_values
3434

3535

36-
DecIncs = Tuple[Tuple[Tuple[Value, bool], ...], Tuple[Value, ...]]
36+
Decs = Tuple[Tuple[Value, bool], ...]
37+
Incs = Tuple[Value, ...]
3738

38-
# A of basic blocks that decrement and increment specific values and
39-
# then jump to some target block. This lets us cut down on how much
40-
# code we generate in some circumstances.
41-
BlockCache = Dict[Tuple[BasicBlock, DecIncs], BasicBlock]
39+
# A cache of basic blocks that decrement and increment specific values
40+
# and then jump to some target block. This lets us cut down on how
41+
# much code we generate in some circumstances.
42+
BlockCache = Dict[Tuple[BasicBlock, Decs, Incs], BasicBlock]
4243

4344

4445
def insert_ref_count_opcodes(ir: FuncIR) -> None:
@@ -161,36 +162,25 @@ def f(a: int) -> None
161162
source_live_regs = pre_live[prev_key]
162163
source_borrowed = post_borrow[prev_key]
163164
source_defined = post_must_defined[prev_key]
164-
if isinstance(block.ops[-1], Branch):
165-
branch = block.ops[-1]
165+
166+
term = block.terminator
167+
for i, target in enumerate(term.targets()):
166168
# HAX: After we've checked against an error value the value we must not touch the
167169
# refcount since it will be a null pointer. The correct way to do this would be
168170
# to perform data flow analysis on whether a value can be null (or is always
169171
# null).
170-
if branch.op == Branch.IS_ERROR:
171-
omitted = {branch.value}
172+
omitted: Iterable[Value]
173+
if isinstance(term, Branch) and term.op == Branch.IS_ERROR and i == 0:
174+
omitted = (term.value,)
172175
else:
173-
omitted = set()
174-
true_decincs = (
175-
after_branch_decrefs(
176-
branch.true, pre_live, source_defined,
177-
source_borrowed, source_live_regs, ordering, omitted),
178-
after_branch_increfs(
179-
branch.true, pre_live, pre_borrow, source_borrowed, ordering))
180-
branch.true = add_block(true_decincs, cache, blocks, branch.true)
181-
182-
false_decincs = (
183-
after_branch_decrefs(
184-
branch.false, pre_live, source_defined, source_borrowed, source_live_regs,
185-
ordering),
186-
after_branch_increfs(
187-
branch.false, pre_live, pre_borrow, source_borrowed, ordering))
188-
branch.false = add_block(false_decincs, cache, blocks, branch.false)
189-
elif isinstance(block.ops[-1], Goto):
190-
goto = block.ops[-1]
191-
new_decincs = ((), after_branch_increfs(
192-
goto.label, pre_live, pre_borrow, source_borrowed, ordering))
193-
goto.label = add_block(new_decincs, cache, blocks, goto.label)
176+
omitted = ()
177+
178+
decs = after_branch_decrefs(
179+
target, pre_live, source_defined,
180+
source_borrowed, source_live_regs, ordering, omitted)
181+
incs = after_branch_increfs(
182+
target, pre_live, pre_borrow, source_borrowed, ordering)
183+
term.set_target(i, add_block(decs, incs, cache, blocks, target))
194184

195185

196186
def after_branch_decrefs(label: BasicBlock,
@@ -199,7 +189,7 @@ def after_branch_decrefs(label: BasicBlock,
199189
source_borrowed: Set[Value],
200190
source_live_regs: Set[Value],
201191
ordering: Dict[Value, int],
202-
omitted: Iterable[Value] = ()) -> Tuple[Tuple[Value, bool], ...]:
192+
omitted: Iterable[Value]) -> Tuple[Tuple[Value, bool], ...]:
203193
target_pre_live = pre_live[label, 0]
204194
decref = source_live_regs - target_pre_live - source_borrowed
205195
if decref:
@@ -224,22 +214,21 @@ def after_branch_increfs(label: BasicBlock,
224214
return ()
225215

226216

227-
def add_block(decincs: DecIncs, cache: BlockCache,
217+
def add_block(decs: Decs, incs: Incs, cache: BlockCache,
228218
blocks: List[BasicBlock], label: BasicBlock) -> BasicBlock:
229-
decs, incs = decincs
230219
if not decs and not incs:
231220
return label
232221

233222
# TODO: be able to share *partial* results
234-
if (label, decincs) in cache:
235-
return cache[label, decincs]
223+
if (label, decs, incs) in cache:
224+
return cache[label, decs, incs]
236225

237226
block = BasicBlock()
238227
blocks.append(block)
239228
block.ops.extend(DecRef(reg, is_xdec=xdec) for reg, xdec in decs)
240229
block.ops.extend(IncRef(reg) for reg in incs)
241230
block.ops.append(Goto(label))
242-
cache[label, decincs] = block
231+
cache[label, decs, incs] = block
243232
return block
244233

245234

0 commit comments

Comments
 (0)