Skip to content

Commit 5b676c6

Browse files
committed
HOTFIX invalidate cache when modifying known predicates
The change in #3953 introduced additional arguments to evaluation functions which allow callers to supply some "known true" predicates for the simplification and evaluation. However, doing so means that the cache will get populated with associations that might only be true if this known truth does not change. The case in point was a predicate being cached as "true" because of an earlier evaluation, and then _removed_ from the path condition. This change removes returning the cache from any inteface functions that take known truth arguments, and runs internal computations that modify the predicates with a fresh empty cache.
1 parent 2a5be78 commit 5b676c6

File tree

4 files changed

+42
-27
lines changed

4 files changed

+42
-27
lines changed

booster/library/Booster/Pattern/ApplyEquations.hs

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,18 @@ countSteps = length . (.termStack) <$> getState
210210
pushTerm :: Monad io => Term -> EquationT io ()
211211
pushTerm t = eqState . modify $ \s -> s{termStack = t :<| s.termStack}
212212

213+
-- pushing new constraints means the cache is stale
213214
pushConstraints :: Monad io => Set Predicate -> EquationT io ()
214-
pushConstraints ps = eqState . modify $ \s -> s{predicates = s.predicates <> ps}
215+
pushConstraints ps = eqState . modify $ \s -> s{predicates = s.predicates <> ps, cache = mempty}
216+
217+
-- run an evaluation with a local empty cache (discarded afterwards)
218+
withoutCache :: Monad io => EquationT io a -> EquationT io a
219+
withoutCache action = do
220+
prior <- getState
221+
eqState $ put prior{cache = mempty}
222+
a <- action
223+
eqState $ modify $ \s -> s{cache = prior.cache}
224+
pure a
215225

216226
setChanged, resetChanged :: Monad io => EquationT io ()
217227
setChanged = eqState . modify $ \s -> s{changed = True}
@@ -333,6 +343,7 @@ runEquationT definition llvmApi smtSolver sCache known (EquationT m) = do
333343
, logger
334344
, prettyModifiers
335345
}
346+
-- NB the returned cache assumes the known predicates
336347
pure (res, endState.cache)
337348

338349
iterateEquations ::
@@ -432,9 +443,10 @@ evaluateTerm ::
432443
SMT.SMTContext ->
433444
Set Predicate ->
434445
Term ->
435-
io (Either EquationFailure Term, SimplifierCache)
446+
io (Either EquationFailure Term)
436447
evaluateTerm direction def llvmApi smtSolver knownPredicates =
437-
runEquationT def llvmApi smtSolver mempty knownPredicates
448+
fmap fst
449+
. runEquationT def llvmApi smtSolver mempty knownPredicates
438450
. evaluateTerm' direction
439451

440452
-- version for internal nested evaluation
@@ -447,6 +459,9 @@ evaluateTerm' direction = iterateEquations direction PreferFunctions
447459

448460
{- | Simplify a Pattern, processing its constraints independently.
449461
Returns either the first failure or the new pattern if no failure was encountered
462+
463+
The returned cache may only be reused if pat.constraints are known
464+
to remain true in the next usage context.
450465
-}
451466
evaluatePattern ::
452467
LoggerMIO io =>
@@ -476,7 +491,8 @@ evaluatePattern' pat@Pattern{term, ceilConditions} = withPatternContext pat $ do
476491
newTerm <- withTermContext term $ evaluateTerm' BottomUp term `catch_` keepTopLevelResults
477492
-- after evaluating the term, evaluate all (existing and
478493
-- newly-acquired) constraints, once
479-
traverse_ simplifyAssumedPredicate . predicates =<< getState
494+
-- this runs with a local empty cache because it manipulates the known constraints
495+
withoutCache (traverse_ simplifyAssumedPredicate . predicates =<< getState)
480496
-- this may yield additional new constraints, left unevaluated
481497
evaluatedConstraints <- predicates <$> getState
482498
-- break-up introduced symbolic _andBool_, filter-out trivial truth, de-duplicate
@@ -510,11 +526,12 @@ evaluatePattern' pat@Pattern{term, ceilConditions} = withPatternContext pat $ do
510526
err -> throw err
511527

512528
-- evaluate the given predicate assuming all others
529+
-- this manipulates the known predicates so it also resets the simplifier cache
513530
simplifyAssumedPredicate :: LoggerMIO io => Predicate -> EquationT io ()
514531
simplifyAssumedPredicate p = do
515532
allPs <- predicates <$> getState
516533
let otherPs = Set.delete p allPs
517-
eqState $ modify $ \s -> s{predicates = otherPs}
534+
eqState $ modify $ \s -> s{predicates = otherPs, cache = mempty}
518535
newP <- simplifyConstraint' True $ coerce p
519536
pushConstraints $ Set.singleton $ coerce newP
520537

@@ -525,16 +542,16 @@ evaluateConstraints ::
525542
SMT.SMTContext ->
526543
SimplifierCache ->
527544
Set Predicate ->
528-
io (Either EquationFailure (Set Predicate), SimplifierCache)
545+
io (Either EquationFailure (Set Predicate))
529546
evaluateConstraints def mLlvmLibrary smtSolver cache =
530-
runEquationT def mLlvmLibrary smtSolver cache mempty . evaluateConstraints'
547+
fmap fst . runEquationT def mLlvmLibrary smtSolver cache mempty . evaluateConstraints'
531548

532549
evaluateConstraints' ::
533550
LoggerMIO io =>
534551
Set Predicate ->
535552
EquationT io (Set Predicate)
536-
evaluateConstraints' constraints = do
537-
pushConstraints constraints
553+
evaluateConstraints' constraints = withoutCache $ do
554+
pushConstraints constraints -- invalidates the cache
538555
-- evaluate all existing constraints, once
539556
traverse_ simplifyAssumedPredicate . predicates =<< getState
540557
-- this may yield additional new constraints, left unevaluated
@@ -1074,10 +1091,6 @@ applyEquation term rule =
10741091
This is used during rewriting to simplify side conditions of rules
10751092
(to decide whether or not a rule can apply, not to retain the
10761093
ensured conditions).
1077-
1078-
If and as soon as this function is used inside equation
1079-
application, it needs to run within the same 'EquationT' context
1080-
so we can detect simplification loops and avoid monad nesting.
10811094
-}
10821095
simplifyConstraint ::
10831096
LoggerMIO io =>
@@ -1087,9 +1100,13 @@ simplifyConstraint ::
10871100
SimplifierCache ->
10881101
Set Predicate ->
10891102
Predicate ->
1090-
io (Either EquationFailure Predicate, SimplifierCache)
1091-
simplifyConstraint def mbApi smt cache knownPredicates (Predicate p) = do
1092-
runEquationT def mbApi smt cache knownPredicates $ (coerce <$>) . simplifyConstraint' True $ p
1103+
io (Either EquationFailure Predicate)
1104+
simplifyConstraint def mbApi smt cache knownPredicates =
1105+
fmap fst
1106+
. runEquationT def mbApi smt cache knownPredicates
1107+
. (coerce <$>)
1108+
. simplifyConstraint' True
1109+
. coerce
10931110

10941111
simplifyConstraints ::
10951112
LoggerMIO io =>

booster/library/Booster/Pattern/Implies.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ runImplies def mLlvmLibrary mSMTOptions antecedent consequent =
151151
else -- FIXME This is incomplete because patL.constraints are not assumed in the check.
152152

153153
ApplyEquations.evaluateConstraints def mLlvmLibrary solver mempty filteredConsequentPreds >>= \case
154-
(Right newPreds, _) ->
154+
Right newPreds ->
155155
if all (== Predicate TrueBool) newPreds
156156
then
157157
implies
@@ -161,7 +161,7 @@ runImplies def mLlvmLibrary mSMTOptions antecedent consequent =
161161
subst
162162
else -- here we conservatively abort (incomplete)
163163
pure . Left . RpcError.backendError $ RpcError.Aborted "unknown constraints"
164-
(Left other, _) ->
164+
Left other ->
165165
pure . Left . RpcError.backendError $ RpcError.Aborted (Text.pack . constructorName $ other)
166166

167167
case (internalised antecedent.term, internalised consequent.term) of

booster/library/Booster/Pattern/Rewrite.hs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -539,12 +539,10 @@ applyRule pat@Pattern{ceilConditions} rule =
539539
RewriteRuleAppT (RewriteT io) (Maybe a)
540540
checkConstraint onUnclear onBottom knownPredicates p = do
541541
RewriteConfig{definition, llvmApi, smtSolver} <- lift $ RewriteT ask
542-
RewriteState{cache = oldCache} <- lift . RewriteT . lift $ get
543-
(simplified, cache) <-
542+
RewriteState{cache} <- lift . RewriteT . lift $ get
543+
simplified <-
544544
withContext CtxConstraint $
545-
simplifyConstraint definition llvmApi smtSolver oldCache knownPredicates p
546-
-- update cache
547-
lift $ updateRewriterCache cache
545+
simplifyConstraint definition llvmApi smtSolver cache knownPredicates p
548546
case simplified of
549547
Right (Predicate FalseBool) -> onBottom
550548
Right (Predicate TrueBool) -> pure Nothing

booster/unit-tests/Test/Booster/Pattern/ApplyEquations.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ test_evaluateFunction =
9797
where
9898
eval direction t = do
9999
ns <- noSolver
100-
runNoLoggingT $ fst <$> evaluateTerm direction funDef Nothing ns mempty t
100+
runNoLoggingT $ evaluateTerm direction funDef Nothing ns mempty t
101101

102102
isTooManyIterations (Left (TooManyIterations _n _ _)) = pure ()
103103
isTooManyIterations (Left err) = assertFailure $ "Unexpected error " <> show err
@@ -126,7 +126,7 @@ test_simplify =
126126
where
127127
simpl direction t = do
128128
ns <- noSolver
129-
runNoLoggingT $ fst <$> evaluateTerm direction simplDef Nothing ns mempty t
129+
runNoLoggingT $ evaluateTerm direction simplDef Nothing ns mempty t
130130
a = var "A" someSort
131131

132132
test_simplifyPattern :: TestTree
@@ -223,7 +223,7 @@ test_simplifyConstraint =
223223
simpl t =
224224
do
225225
ns <- noSolver
226-
runNoLoggingT $ fst <$> simplifyConstraint testDefinition Nothing ns mempty mempty t
226+
runNoLoggingT $ simplifyConstraint testDefinition Nothing ns mempty mempty t
227227

228228
test_errors :: TestTree
229229
test_errors =
@@ -236,7 +236,7 @@ test_errors =
236236
loopTerms =
237237
[f $ app con1 [a], f $ app con2 [a], f $ app con3 [a, a], f $ app con1 [a]]
238238
ns <- noSolver
239-
isLoop loopTerms =<< (runNoLoggingT $ fst <$> evaluateTerm TopDown loopDef Nothing ns mempty subj)
239+
isLoop loopTerms =<< (runNoLoggingT $ evaluateTerm TopDown loopDef Nothing ns mempty subj)
240240
]
241241
where
242242
isLoop ts (Left (EquationLoop ts')) = ts @?= ts'

0 commit comments

Comments
 (0)