@@ -13,6 +13,7 @@ module Booster.Pattern.Rewrite (
13
13
RewriteConfig (.. ),
14
14
RewriteFailed (.. ),
15
15
RewriteResult (.. ),
16
+ RewriteBranchNextState (.. ),
16
17
RewriteTrace (.. ),
17
18
pattern CollectRewriteTraces ,
18
19
pattern NoCollectRewriteTraces ,
@@ -247,7 +248,14 @@ rewriteStep cutLabels terminalLabels pat = do
247
248
RewriteBranch base $
248
249
NE. fromList $
249
250
map
250
- ( \ (rule, RewriteRuleAppliedData {rewritten, rulePredicate, ruleSubstitution}) -> (ruleLabelOrLocT rule, uniqueId rule, rewritten, rulePredicate, ruleSubstitution)
251
+ ( \ (rule, RewriteRuleAppliedData {rewritten, rulePredicate, ruleSubstitution}) ->
252
+ RewriteBranchNextState
253
+ { ruleLabel = ruleLabelOrLocT rule
254
+ , ruleUniqueId = uniqueId rule
255
+ , rewrittenPat = rewritten
256
+ , mRulePredicate = rulePredicate
257
+ , ruleSubstitution
258
+ }
251
259
)
252
260
leafs
253
261
@@ -789,10 +797,20 @@ ruleLabelOrLoc rule =
789
797
fromMaybe " unknown rule" $
790
798
fmap pretty rule. attributes. ruleLabel <|> fmap pretty rule. attributes. location
791
799
800
+ data RewriteBranchNextState pat = RewriteBranchNextState
801
+ { ruleLabel :: Text
802
+ , ruleUniqueId :: UniqueId
803
+ , rewrittenPat :: pat
804
+ , mRulePredicate :: Maybe Predicate
805
+ , ruleSubstitution :: Substitution
806
+ }
807
+ deriving stock (Eq , Show )
808
+ deriving (Functor , Foldable , Traversable )
809
+
792
810
-- | Different rewrite results (returned from RPC execute endpoint)
793
811
data RewriteResult pat
794
812
= -- | branch point
795
- RewriteBranch pat (NonEmpty (Text , UniqueId , pat , Maybe Predicate , Substitution ))
813
+ RewriteBranch pat (NonEmpty (RewriteBranchNextState pat ))
796
814
| -- | no rules could be applied, config is stuck
797
815
RewriteStuck pat
798
816
| -- | cut point rule, return current (lhs) and single next state
@@ -1021,14 +1039,18 @@ performRewrite rewriteConfig pat = do
1021
1039
Nothing -> pure $ RewriteTrivial orig
1022
1040
Just p' -> do
1023
1041
-- simplify the 3rd component, i.e. the pattern
1024
- let simplifyP3rd (a, b, c, e, f) =
1025
- fmap (a,b,,e,f) <$> simplifyP c
1026
- nexts' <- catMaybes <$> mapM simplifyP3rd (toList nexts)
1042
+ let simplifyRewritten pattr@ RewriteBranchNextState {rewrittenPat} = do
1043
+ ( fmap @ Maybe
1044
+ ( \ rewrittenSimplified -> (pattr{rewrittenPat = rewrittenSimplified})
1045
+ )
1046
+ )
1047
+ <$> simplifyP rewrittenPat
1048
+ nexts' <- catMaybes <$> mapM simplifyRewritten (toList nexts)
1027
1049
pure $ case nexts' of
1028
1050
-- The `[]` case should be `Stuck` not `Trivial`, because `RewriteTrivial p'`
1029
1051
-- means the pattern `p'` is bottom, but we know that is not the case here.
1030
1052
[] -> RewriteStuck p'
1031
- [(lbl, uId, n, _rp, _rs) ] -> RewriteFinished (Just lbl ) (Just uId) n
1053
+ [RewriteBranchNextState {ruleLabel, ruleUniqueId, rewrittenPat} ] -> RewriteFinished (Just ruleLabel ) (Just ruleUniqueId) rewrittenPat
1032
1054
ns -> RewriteBranch p' $ NE. fromList ns
1033
1055
r@ RewriteStuck {} -> pure r
1034
1056
r@ RewriteTrivial {} -> pure r
@@ -1098,7 +1120,9 @@ performRewrite rewriteConfig pat = do
1098
1120
incrementCounter
1099
1121
doSteps False single
1100
1122
RewriteBranch pat'' branches -> withPatternContext pat' $ do
1101
- emitRewriteTrace $ RewriteBranchingStep pat'' $ fmap (\ (lbl, uid, _, _, _) -> (lbl, uid)) branches
1123
+ emitRewriteTrace $
1124
+ RewriteBranchingStep pat'' $
1125
+ fmap (\ RewriteBranchNextState {ruleLabel, ruleUniqueId} -> (ruleLabel, ruleUniqueId)) branches
1102
1126
pure simplified
1103
1127
_other -> withPatternContext pat' $ error " simplifyResult: Unexpected return value"
1104
1128
Right (cutPoint@ (RewriteCutPoint lbl _ _ _), _) -> withPatternContext pat' $ do
0 commit comments