Skip to content

Commit 9a0a41f

Browse files
committed
kevm-pyk/__main__: fall back to manual branch extraction using haskell backend
1 parent bc746fd commit 9a0a41f

File tree

1 file changed

+37
-19
lines changed

1 file changed

+37
-19
lines changed

kevm-pyk/src/kevm_pyk/__main__.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
from pyk.cli_utils import dir_path, file_path
1010
from pyk.cterm import CTerm
1111
from pyk.kast import KApply, KAtt, KDefinition, KFlatModule, KImport, KInner, KRequire, KRewrite, KRule, KToken
12-
from pyk.kastManip import minimize_term, push_down_rewrites
12+
from pyk.kastManip import flatten_label, minimize_term, push_down_rewrites
1313
from pyk.kcfg import KCFG
1414
from pyk.ktool.kit import KIT
1515
from pyk.ktool.krun import _krun
16-
from pyk.prelude.ml import mlTop
16+
from pyk.prelude.ml import mlAnd, mlTop
1717
from pyk.utils import shorten_hashes
1818

1919
from .gst_to_kore import gst_to_kore
@@ -378,7 +378,7 @@ def prove_it(id_and_cfg: Tuple[str, Tuple[KCFG, Path]]) -> bool:
378378
_LOGGER.info(f'Advancing proof from node {cfgid}: {shorten_hashes(curr_node.id)}')
379379
edge = KCFG.Edge(curr_node, target_node, mlTop(), -1)
380380
claim = edge.to_claim()
381-
claim_id = f'gen-{curr_node.id}-to-{target_node.id}'
381+
claim_id = f'gen-block-{curr_node.id}-to-{target_node.id}'
382382
depth, branching, result = foundry.get_claim_basic_block(
383383
claim_id, claim, lemmas=lemma_rules, max_depth=max_depth
384384
)
@@ -393,31 +393,49 @@ def prove_it(id_and_cfg: Tuple[str, Tuple[KCFG, Path]]) -> bool:
393393
next_state = CTerm(sanitize_config(foundry.definition, result))
394394
next_node = cfg.get_or_create_node(next_state)
395395
if next_node != curr_node:
396-
_LOGGER.info(f'Found basic block at depth {depth} for {cfgid}: {shorten_hashes((curr_node.id, next_node.id))}.')
396+
_LOGGER.info(
397+
f'Found basic block at depth {depth} for {cfgid}: {shorten_hashes((curr_node.id, next_node.id))}.'
398+
)
397399
cfg.create_edge(curr_node.id, next_node.id, mlTop(), depth)
398400

399401
if KEVM.is_terminal(next_node.cterm):
400402
cfg.add_expanded(next_node.id)
401403
_LOGGER.info(f'Terminal node {cfgid}: {shorten_hashes((curr_node.id))}.')
402404

403405
elif branching:
406+
cfg.add_expanded(next_node.id)
404407
branches = KEVM.extract_branches(next_state)
405-
if not branches:
406-
raise ValueError(
407-
f'Could not extract branch condition {cfgid}:\n{foundry.pretty_print(minimize_term(result))}'
408+
if len(list(branches)) > 0:
409+
_LOGGER.info(
410+
f'Found {len(list(branches))} branches at depth {depth} for {cfgid}: {[foundry.pretty_print(b) for b in branches]}'
408411
)
409-
cfg.add_expanded(next_node.id)
410-
_LOGGER.info(
411-
f'Found {len(list(branches))} branches at depth {depth} for {cfgid}: {[foundry.pretty_print(b) for b in branches]}'
412-
)
413-
for branch in branches:
414-
branch_cterm = next_state.add_constraint(branch)
415-
branch_node = cfg.get_or_create_node(branch_cterm)
416-
cfg.create_edge(next_node.id, branch_node.id, branch, 0)
417-
_LOGGER.info(f'Made split for {cfgid}: {shorten_hashes((next_node.id, branch_node.id))}')
418-
# TODO: have to store case splits as rewrites because of how frontier is handled for covers
419-
# cfg.create_cover(branch_node.id, next_node.id)
420-
# _LOGGER.info(f'Made cover: {shorten_hashes((branch_node.id, next_node.id))}')
412+
for branch in branches:
413+
branch_cterm = next_state.add_constraint(branch)
414+
branch_node = cfg.get_or_create_node(branch_cterm)
415+
cfg.create_edge(next_node.id, branch_node.id, branch, 0)
416+
_LOGGER.info(f'Made split for {cfgid}: {shorten_hashes((next_node.id, branch_node.id))}')
417+
# TODO: have to store case splits as rewrites because of how frontier is handled for covers
418+
# cfg.create_cover(branch_node.id, next_node.id)
419+
# _LOGGER.info(f'Made cover: {shorten_hashes((branch_node.id, next_node.id))}')
420+
else:
421+
_LOGGER.warning(
422+
f'Falling back to running backend for branch extraction {cfgid}:\n{foundry.pretty_print(minimize_term(result))}'
423+
)
424+
edge = KCFG.Edge(next_node, target_node, mlTop(), -1)
425+
claim = edge.to_claim()
426+
claim_id = f'gen-branch-{curr_node.id}-to-{target_node.id}'
427+
result = foundry.prove_claim(claim, claim_id, lemmas=lemma_rules, args=['--depth', '1'])
428+
branch_cterms = [CTerm(r) for r in flatten_label('#Or', result)]
429+
old_constraints = next_state.constraints
430+
new_constraints = [
431+
[c for c in s.constraints if c not in old_constraints] for s in branch_cterms
432+
]
433+
_LOGGER.info(
434+
f'Found {len(list(branch_cterms))} branches manually ad depth 1 for {cfgid}: {[foundry.pretty_print(mlAnd(nc)) for nc in new_constraints]}'
435+
)
436+
for ns, nc in zip(branch_cterms, new_constraints):
437+
branch_node = cfg.get_or_create_node(ns)
438+
cfg.create_edge(next_node.id, branch_node.id, mlAnd(nc), 1)
421439

422440
_write_cfg(cfg, cfgpath)
423441

0 commit comments

Comments
 (0)