Skip to content

Commit 2e5ecd5

Browse files
committed
Refactor applyRule in preparation for remainders
1 parent eec6d2b commit 2e5ecd5

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

booster/library/Booster/Pattern/Bool.hs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ module Booster.Pattern.Bool (
1010
negateBool,
1111
splitBoolPredicates,
1212
splitAndBools,
13+
collapseAndBools,
1314
-- patterns
1415
pattern TrueBool,
1516
pattern FalseBool,
@@ -206,3 +207,7 @@ splitAndBools :: Predicate -> [Predicate]
206207
splitAndBools p@(Predicate t)
207208
| AndBool l r <- t = concatMap (splitAndBools . Predicate) [l, r]
208209
| otherwise = [p]
210+
211+
-- | Inverse of splitAndBools
212+
collapseAndBools :: [Predicate] -> Predicate
213+
collapseAndBools = Predicate . foldAndBool . map (\(Predicate p) -> p)

booster/library/Booster/Pattern/Rewrite.hs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ rewriteStep cutLabels terminalLabels pat = do
207207
RewriteStuck{} -> pure $ RewriteTrivial pat
208208
other -> pure other
209209
-- all branches but one were either not applied or trivial
210-
[(r, x)]
210+
[(r, (x, _remainder, _subst))]
211211
| labelOf r `elem` cutLabels ->
212212
pure $ RewriteCutPoint (labelOf r) (uniqueId r) pat x
213213
| labelOf r `elem` terminalLabels ->
@@ -220,7 +220,7 @@ rewriteStep cutLabels terminalLabels pat = do
220220
pure $
221221
RewriteBranch pat $
222222
NE.fromList $
223-
map (\(r, p) -> (ruleLabelOrLocT r, uniqueId r, p)) rxs
223+
map (\(r, (p, _remainder, _subst)) -> (ruleLabelOrLocT r, uniqueId r, p)) rxs
224224

225225
-- | Rewrite rule application transformer: may throw exceptions on non-applicable or trivial rule applications
226226
type RewriteRuleAppT m a = ExceptT RewriteRuleAppException m a
@@ -260,7 +260,7 @@ applyRule ::
260260
LoggerMIO io =>
261261
Pattern ->
262262
RewriteRule "Rewrite" ->
263-
RewriteT io (RewriteRuleAppResult Pattern)
263+
RewriteT io (RewriteRuleAppResult (Pattern, Predicate, Substitution))
264264
applyRule pat@Pattern{ceilConditions} rule =
265265
withRuleContext rule $
266266
runRewriteRuleAppT $
@@ -326,12 +326,6 @@ applyRule pat@Pattern{ceilConditions} rule =
326326
-- check required constraints from lhs: Stop if any is false,
327327
-- add as remainders if indeterminate.
328328
unclearRequiresAfterSmt <- checkRequires subst
329-
-- when unclearRequiresAfterSmt is non-empty, we need to add it as a rule remainder.
330-
-- To maintain the old behaviour, we fail hard here
331-
unless (null unclearRequiresAfterSmt) $
332-
failRewrite $
333-
RuleConditionUnclear rule . coerce . foldl1 AndTerm $
334-
map coerce unclearRequiresAfterSmt
335329

336330
-- check ensures constraints (new) from rhs: stop and return `Trivial` if
337331
-- any are false, remove all that are trivially true, return the rest
@@ -366,9 +360,21 @@ applyRule pat@Pattern{ceilConditions} rule =
366360
<> (Set.fromList $ map (coerce . substituteInTerm existentialSubst . coerce) newConstraints)
367361
)
368362
ceilConditions
369-
withContext CtxSuccess $
370-
withPatternContext rewritten $
371-
return rewritten
363+
withContext CtxSuccess $ do
364+
case unclearRequiresAfterSmt of
365+
[] -> withPatternContext rewritten $ pure (rewritten, Predicate FalseBool, subst)
366+
_ -> do
367+
failRewrite $
368+
RuleConditionUnclear rule . coerce . foldl1 AndTerm $
369+
map coerce unclearRequiresAfterSmt
370+
-- TODO the following code is intentionally dead and should be enabled to get rewrite rule remainders
371+
-- when unclearRequiresAfterSmt is non-empty, we need to add it as a rule remainder predicate, which means:
372+
-- - the resulting patten will have it conjoined to its constraints TODO is this right?
373+
-- - its negation, i.e. the remainder predicate, will be returned as the second component of the result
374+
let rewritten' = rewritten{constraints = rewritten.constraints <> Set.fromList unclearRequiresAfterSmt}
375+
in withPatternContext rewritten' $
376+
pure
377+
(rewritten', Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt, subst)
372378
where
373379
filterOutKnownConstraints :: Set.Set Predicate -> [Predicate] -> RewriteT io [Predicate]
374380
filterOutKnownConstraints priorKnowledge constraitns = do

0 commit comments

Comments
 (0)