Skip to content

Commit b148a1f

Browse files
committed
Introduce RewriteBranchNextState
Fix unit tests
1 parent 920ad62 commit b148a1f

File tree

3 files changed

+80
-47
lines changed

3 files changed

+80
-47
lines changed

booster/library/Booster/JsonRpc.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ import Booster.Pattern.Base qualified as Pattern
4949
import Booster.Pattern.Implies (runImplies)
5050
import Booster.Pattern.Pretty
5151
import Booster.Pattern.Rewrite (
52+
RewriteBranchNextState (..),
5253
RewriteConfig (..),
5354
RewriteFailed (..),
5455
RewriteResult (..),
@@ -485,7 +486,7 @@ execResponse req (d, traces, rr) unsupported = case rr of
485486
, nextStates =
486487
Just
487488
$ map
488-
( \(_, muid, p', mrulePred, ruleSubst) -> toExecState p' unsupported (Just muid) mrulePred (Just ruleSubst)
489+
( \(RewriteBranchNextState{ruleUniqueId, rewrittenPat, mRulePredicate, ruleSubstitution}) -> toExecState rewrittenPat unsupported (Just ruleUniqueId) mRulePredicate (Just ruleSubstitution)
489490
)
490491
$ toList nexts
491492
, rule = Nothing

booster/library/Booster/Pattern/Rewrite.hs

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ module Booster.Pattern.Rewrite (
1313
RewriteConfig (..),
1414
RewriteFailed (..),
1515
RewriteResult (..),
16+
RewriteBranchNextState (..),
1617
RewriteTrace (..),
1718
pattern CollectRewriteTraces,
1819
pattern NoCollectRewriteTraces,
@@ -247,7 +248,14 @@ rewriteStep cutLabels terminalLabels pat = do
247248
RewriteBranch base $
248249
NE.fromList $
249250
map
250-
( \(rule, RewriteRuleAppliedData{rewritten, rulePredicate, ruleSubstitution}) -> (ruleLabelOrLocT rule, uniqueId rule, rewritten, rulePredicate, ruleSubstitution)
251+
( \(rule, RewriteRuleAppliedData{rewritten, rulePredicate, ruleSubstitution}) ->
252+
RewriteBranchNextState
253+
{ ruleLabel = ruleLabelOrLocT rule
254+
, ruleUniqueId = uniqueId rule
255+
, rewrittenPat = rewritten
256+
, mRulePredicate = rulePredicate
257+
, ruleSubstitution
258+
}
251259
)
252260
leafs
253261

@@ -789,10 +797,20 @@ ruleLabelOrLoc rule =
789797
fromMaybe "unknown rule" $
790798
fmap pretty rule.attributes.ruleLabel <|> fmap pretty rule.attributes.location
791799

800+
data RewriteBranchNextState pat = RewriteBranchNextState
801+
{ ruleLabel :: Text
802+
, ruleUniqueId :: UniqueId
803+
, rewrittenPat :: pat
804+
, mRulePredicate :: Maybe Predicate
805+
, ruleSubstitution :: Substitution
806+
}
807+
deriving stock (Eq, Show)
808+
deriving (Functor, Foldable, Traversable)
809+
792810
-- | Different rewrite results (returned from RPC execute endpoint)
793811
data RewriteResult pat
794812
= -- | branch point
795-
RewriteBranch pat (NonEmpty (Text, UniqueId, pat, Maybe Predicate, Substitution))
813+
RewriteBranch pat (NonEmpty (RewriteBranchNextState pat))
796814
| -- | no rules could be applied, config is stuck
797815
RewriteStuck pat
798816
| -- | cut point rule, return current (lhs) and single next state
@@ -1020,15 +1038,19 @@ performRewrite rewriteConfig pat = do
10201038
simplifyP p >>= \case
10211039
Nothing -> pure $ RewriteTrivial orig
10221040
Just p' -> do
1023-
-- simplify the 3rd component, i.e. the pattern
1024-
let simplifyP3rd (a, b, c, e, f) =
1025-
fmap (a,b,,e,f) <$> simplifyP c
1026-
nexts' <- catMaybes <$> mapM simplifyP3rd (toList nexts)
1041+
-- simplify the next-state pattern inside a branch payload
1042+
let simplifyRewritten pattr@RewriteBranchNextState{rewrittenPat} = do
1043+
( fmap @Maybe
1044+
( \rewrittenSimplified -> (pattr{rewrittenPat = rewrittenSimplified})
1045+
)
1046+
)
1047+
<$> simplifyP rewrittenPat
1048+
nexts' <- catMaybes <$> mapM simplifyRewritten (toList nexts)
10271049
pure $ case nexts' of
10281050
-- The `[]` case should be `Stuck` not `Trivial`, because `RewriteTrivial p'`
10291051
-- means the pattern `p'` is bottom, but we know that is not the case here.
10301052
[] -> RewriteStuck p'
1031-
[(lbl, uId, n, _rp, _rs)] -> RewriteFinished (Just lbl) (Just uId) n
1053+
[RewriteBranchNextState{ruleLabel, ruleUniqueId, rewrittenPat}] -> RewriteFinished (Just ruleLabel) (Just ruleUniqueId) rewrittenPat
10321054
ns -> RewriteBranch p' $ NE.fromList ns
10331055
r@RewriteStuck{} -> pure r
10341056
r@RewriteTrivial{} -> pure r
@@ -1098,7 +1120,9 @@ performRewrite rewriteConfig pat = do
10981120
incrementCounter
10991121
doSteps False single
11001122
RewriteBranch pat'' branches -> withPatternContext pat' $ do
1101-
emitRewriteTrace $ RewriteBranchingStep pat'' $ fmap (\(lbl, uid, _, _, _) -> (lbl, uid)) branches
1123+
emitRewriteTrace $
1124+
RewriteBranchingStep pat'' $
1125+
fmap (\RewriteBranchNextState{ruleLabel, ruleUniqueId} -> (ruleLabel, ruleUniqueId)) branches
11021126
pure simplified
11031127
_other -> withPatternContext pat' $ error "simplifyResult: Unexpected return value"
11041128
Right (cutPoint@(RewriteCutPoint lbl _ _ _), _) -> withPatternContext pat' $ do

booster/unit-tests/Test/Booster/Pattern/Rewrite.hs

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,13 @@ testConf = do
175175
ignoreRulePredicateAndSubst :: RewriteResult Pattern -> RewriteResult Pattern
176176
ignoreRulePredicateAndSubst =
177177
\case
178-
RewriteBranch pre posts -> RewriteBranch pre (NE.map (\(lbl, uid, p, _, _) -> (lbl, uid, p, Nothing, mempty)) posts)
178+
RewriteBranch pre posts ->
179+
RewriteBranch
180+
pre
181+
( NE.map
182+
(\nextState -> nextState{mRulePredicate = Nothing, ruleSubstitution = mempty})
183+
posts
184+
)
179185
other -> other
180186

181187
----------------------------------------
@@ -268,7 +274,7 @@ t `branchesTo` ts =
268274
@?>>= Right
269275
( RewriteBranch (Pattern_ t) $
270276
NE.fromList $
271-
map (\(lbl, t') -> (lbl, mockUniqueId, Pattern_ t', Nothing, mempty)) ts
277+
map (\(lbl, t') -> RewriteBranchNextState lbl mockUniqueId (Pattern_ t') Nothing mempty) ts
272278
)
273279

274280
failsWith :: Term -> RewriteFailed "Rewrite" -> IO ()
@@ -312,19 +318,19 @@ canRewrite =
312318
RewriteStuck
313319
, testCase "Rewrites con3 twice, branching on con1" $ do
314320
let branch1 =
315-
( "con1-f2"
316-
, mockUniqueId
317-
, [trm| kCell{}( kseq{}( inj{AnotherSort{}, SortKItem{}}( con4{}( \dv{SomeSort{}}("somethingElse"), \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
318-
, Nothing
319-
, mempty
320-
)
321+
RewriteBranchNextState
322+
"con1-f2"
323+
mockUniqueId
324+
[trm| kCell{}( kseq{}( inj{AnotherSort{}, SortKItem{}}( con4{}( \dv{SomeSort{}}("somethingElse"), \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
325+
Nothing
326+
mempty
321327
branch2 =
322-
( "con1-f1'"
323-
, mockUniqueId
324-
, [trm| kCell{}( kseq{}( inj{SomeSort{}, SortKItem{}}( f1{}( \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
325-
, Nothing
326-
, mempty
327-
)
328+
RewriteBranchNextState
329+
"con1-f1'"
330+
mockUniqueId
331+
[trm| kCell{}( kseq{}( inj{SomeSort{}, SortKItem{}}( f1{}( \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
332+
Nothing
333+
mempty
328334

329335
rewrites
330336
(Steps 1)
@@ -409,19 +415,20 @@ supportsDepthControl =
409415
(RewriteFinished Nothing Nothing)
410416
, testCase "prefers reporting branches to stopping at depth" $ do
411417
let branch1 =
412-
( "con1-f2"
413-
, mockUniqueId
414-
, [trm| kCell{}( kseq{}( inj{AnotherSort{}, SortKItem{}}( con4{}( \dv{SomeSort{}}("somethingElse"), \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
415-
, Nothing
416-
, mempty
417-
)
418+
RewriteBranchNextState
419+
"con1-f2"
420+
mockUniqueId
421+
[trm| kCell{}( kseq{}( inj{AnotherSort{}, SortKItem{}}( con4{}( \dv{SomeSort{}}("somethingElse"), \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
422+
Nothing
423+
mempty
424+
418425
branch2 =
419-
( "con1-f1'"
420-
, mockUniqueId
421-
, [trm| kCell{}( kseq{}( inj{SomeSort{}, SortKItem{}}( f1{}( \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
422-
, Nothing
423-
, mempty
424-
)
426+
RewriteBranchNextState
427+
"con1-f1'"
428+
mockUniqueId
429+
[trm| kCell{}( kseq{}( inj{SomeSort{}, SortKItem{}}( f1{}( \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
430+
Nothing
431+
mempty
425432

426433
rewritesToDepth
427434
(MaxDepth 2)
@@ -466,19 +473,20 @@ supportsCutPoints =
466473
RewriteStuck
467474
, testCase "prefers reporting branches to stopping at label in one branch" $ do
468475
let branch1 =
469-
( "con1-f2"
470-
, mockUniqueId
471-
, [trm| kCell{}( kseq{}( inj{AnotherSort{}, SortKItem{}}( con4{}( \dv{SomeSort{}}("somethingElse"), \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
472-
, Nothing
473-
, mempty
474-
)
476+
RewriteBranchNextState
477+
"con1-f2"
478+
mockUniqueId
479+
[trm| kCell{}( kseq{}( inj{AnotherSort{}, SortKItem{}}( con4{}( \dv{SomeSort{}}("somethingElse"), \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
480+
Nothing
481+
mempty
482+
475483
branch2 =
476-
( "con1-f1'"
477-
, mockUniqueId
478-
, [trm| kCell{}( kseq{}( inj{SomeSort{}, SortKItem{}}( f1{}( \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
479-
, Nothing
480-
, mempty
481-
)
484+
RewriteBranchNextState
485+
"con1-f1'"
486+
mockUniqueId
487+
[trm| kCell{}( kseq{}( inj{SomeSort{}, SortKItem{}}( f1{}( \dv{SomeSort{}}("somethingElse") ) ), C:SortK{}) ) |]
488+
Nothing
489+
mempty
482490

483491
rewritesToCutPoint
484492
"con1-f2"

0 commit comments

Comments
 (0)