Skip to content

Retry harder when stopping rewriting due to RuleConditionUnclear #4000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
31 changes: 21 additions & 10 deletions booster/library/Booster/Pattern/ApplyEquations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ module Booster.Pattern.ApplyEquations (
handleSimplificationEquation,
simplifyConstraint,
simplifyConstraints,
SimplifierCache,
SimplifierCache (..),
evaluateConstraints,
) where

Expand Down Expand Up @@ -421,7 +421,14 @@ evaluatePattern ::
SimplifierCache ->
Pattern ->
io (Either EquationFailure Pattern, SimplifierCache)
evaluatePattern def mLlvmLibrary smtSolver cache pat =
evaluatePattern def mLlvmLibrary smtSolver cache pat = do
logMessage
( "evaluatePattern assuming "
<> show (length pat.constraints)
<> "constraints and with cache of size "
<> show (Map.size cache.llvm, Map.size cache.equations)
)

runEquationT def mLlvmLibrary smtSolver cache pat.constraints . evaluatePattern' $ pat

-- version for internal nested evaluation
Expand Down Expand Up @@ -556,13 +563,12 @@ cached cacheTag cb t@(Term attributes _)
toCache cacheTag t simplified
pure simplified
Just cachedTerm -> do
when (t /= cachedTerm) $ do
setChanged
withTermContext t $
withContext CtxSuccess $
withContextFor cacheTag $
withTermContext cachedTerm $
pure ()
when (t /= cachedTerm) setChanged
withTermContext t $
withContext CtxSuccess $
withContextFor cacheTag $
withTermContext cachedTerm $
pure ()
pure cachedTerm

elseApply :: (Monad m, Eq b) => (b -> m b) -> (b -> m b) -> b -> m b
Expand Down Expand Up @@ -696,7 +702,10 @@ applyEquations theory handler term = do
concatMap snd . Map.toAscList . Map.unionsWith (<>) $
map equationsFor indexes

processEquations equations
withTermContext term $ do
logMessage $
"Trying equations: " <> show (map (.attributes.uniqueId) equations)
processEquations equations
where
-- process one equation at a time, until something has happened
processEquations ::
Expand Down Expand Up @@ -1010,6 +1019,8 @@ simplifyConstraint ::
Predicate ->
io (Either EquationFailure Predicate, SimplifierCache)
simplifyConstraint def mbApi mbSMT cache knownPredicates (Predicate p) = do
logMessage
("simplifyConstraint with cache of size " <> show (Map.size cache.llvm, Map.size cache.equations))
runEquationT def mbApi mbSMT cache knownPredicates $ (coerce <$>) . simplifyConstraint' True $ p

simplifyConstraints ::
Expand Down
63 changes: 58 additions & 5 deletions booster/library/Booster/Pattern/Rewrite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import Control.Monad.Trans.Reader (ReaderT (..), ask, asks, withReaderT)
import Control.Monad.Trans.State.Strict (StateT (runStateT), get, modify)
import Control.Monad.Trans.State.Strict (StateT (runStateT), get, modify, put)
import Data.Aeson (object, (.=))
import Data.Bifunctor (bimap)
import Data.Coerce (coerce)
Expand All @@ -36,10 +36,10 @@ import Data.List (intersperse, partition)
import Data.List.NonEmpty (NonEmpty (..), toList)
import Data.List.NonEmpty qualified as NE
import Data.Map qualified as Map
import Data.Maybe (catMaybes, fromMaybe)
import Data.Maybe (catMaybes, fromJust, fromMaybe)
import Data.Sequence (Seq, (|>))
import Data.Set qualified as Set
import Data.Text as Text (Text, pack)
import Data.Text as Text (Text, intercalate, pack)
import Numeric.Natural
import Prettyprinter

Expand All @@ -49,7 +49,7 @@ import Booster.LLVM as LLVM (API)
import Booster.Log
import Booster.Pattern.ApplyEquations (
EquationFailure (..),
SimplifierCache,
SimplifierCache (..),
evaluatePattern,
simplifyConstraint,
)
Expand All @@ -67,6 +67,8 @@ import Booster.Pattern.Pretty
import Booster.Pattern.Util
import Booster.Prettyprinter
import Booster.SMT.Interface qualified as SMT
import Booster.SMT.Runner (SMTContext (..))
import Booster.SMT.Runner qualified as SMT (evalSMT)
import Booster.Syntax.Json.Externalise (externaliseTerm)
import Booster.Util (Flag (..))

Expand Down Expand Up @@ -729,6 +731,16 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
pure (counter, traces, rr)
where
logDepth = withContext CtxDepth . logMessage
logDebugRewriter RewriteStepsState{counter, simplifierCache, smtSolver} =
withContext CtxDebugRewriter $ do
logMessage $
Text.intercalate
","
[ ("steps: " <> (pack . show $ counter))
, ("LLVM cache: " <> (pack . show . Map.size $ simplifierCache.llvm))
, ("Equations cache: " <> (pack . show . Map.size $ simplifierCache.equations))
, ("Solver: " <> (pack . show $ options <$> smtSolver))
]

depthReached n = maybe False (n >=) mbMaxDepth

Expand All @@ -744,6 +756,11 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL

updateCache simplifierCache = modify $ \rss -> rss{simplifierCache}

purgeCache = modify $ \rss -> rss{simplifierCache = mempty}

updateSolver :: SMT.SMTContext -> RewriteStepsState -> RewriteStepsState
updateSolver solver s = s{smtSolver = Just solver}

simplifyP :: Pattern -> StateT RewriteStepsState io (Maybe Pattern)
simplifyP p = withContext CtxSimplify $ do
st <- get
Expand Down Expand Up @@ -805,8 +822,9 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
doSteps ::
Bool -> Pattern -> StateT RewriteStepsState io (RewriteResult Pattern)
doSteps wasSimplified pat' = do
RewriteStepsState{counter, simplifierCache} <- get
rewriteStepsState@RewriteStepsState{counter, simplifierCache, smtSolver} <- get
logDepth $ showCounter counter
logDebugRewriter rewriteStepsState
if depthReached counter
then do
logDepth $ "Reached maximum depth of " <> maybe "?" showCounter mbMaxDepth
Expand Down Expand Up @@ -884,6 +902,13 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
emitRewriteTrace $ RewriteStepFailed failure
-- simplify remainders, substitute and rerun.
-- If failed, do the pattern-wide simplfication and rerun again
-- start new solver because it was already stopped permanently...
solver <- SMT.initSolver def SMT.defaultSMTOptions
let act = do
rss <- get
let rss' = updateSolver solver rss
put rss'
act
withSimplified pat' "Retrying with simplified pattern" (doSteps True)
| otherwise -> do
-- was already simplified, emit an abort log entry
Expand All @@ -902,6 +927,34 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
emitRewriteTrace $ RewriteStepFailed failure
logMessage $ "Aborted after " <> showCounter counter
pure (RewriteAborted failure pat')
-- if a rule condition was unclear and the pattern was
-- unsimplified, simplify and retry rewriting once
Left failure@(RuleConditionUnclear rule unclearCondition)
| not wasSimplified -> do
-- start new solver because it was already stopped permanently...
solver <- SMT.initSolver def SMT.defaultSMTOptions
let act = do
rss <- get
let rss' = updateSolver solver rss
put rss'
act
-- purge simplification cache
purgeCache
-- TODO: perform some sort of cache sanitation and log the difference.
-- For example: simplify all keys in the cache under the current conditions
-- and see if any produce different results
withSimplified pat' "Retrying with simplified pattern" (doSteps True)
| otherwise -> do
-- was already simplified, emit an abort log entry
getPrettyModifiers >>= \case
ModifiersRep (_ :: FromModifiersT mods => Proxy mods) ->
withRuleContext rule . withContext CtxAbort . logMessage $
WithJsonMessage (object ["conditions" .= (externaliseTerm . coerce $ unclearCondition)]) $
renderOneLineText $
"Uncertain about condition(s) in a rule:"
<+> (pretty' @mods unclearCondition)
emitRewriteTrace $ RewriteStepFailed failure
pure (RewriteAborted failure pat')
Left failure -> do
emitRewriteTrace $ RewriteStepFailed failure
let msg = "Aborted after " <> showCounter counter
Expand Down
1 change: 1 addition & 0 deletions kore-rpc-types/src/Kore/JsonRpc/Types/ContextLog.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ data SimpleContext
| CtxSubstitution
| CtxRemainder
| CtxDepth
| CtxDebugRewriter
| CtxTiming
| -- standard log levels
CtxError
Expand Down
Loading