Skip to content

Commit dbabf23

Browse files
committed
Experiment: optional booster simplification after N rewrite steps (default: disabled)
1 parent 61258dd commit dbabf23

File tree

6 files changed

+93
-58
lines changed

6 files changed

+93
-58
lines changed

booster/library/Booster/CLOptions.hs

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module Booster.CLOptions (
77
CLOptions (..),
88
EquationOptions (..),
99
LogFormat (..),
10+
RewriteOptions (..),
1011
TimestampFormat (..),
1112
clOptionsParser,
1213
adjustLogLevels,
@@ -25,6 +26,7 @@ import Data.Maybe (fromMaybe)
2526
import Data.Text (Text, pack)
2627
import Data.Text.Encoding (decodeASCII)
2728
import Data.Version (Version (..), showVersion)
29+
import Numeric.Natural (Natural)
2830
import Options.Applicative
2931

3032
import Booster.GlobalState (EquationOptions (..))
@@ -49,11 +51,18 @@ data CLOptions = CLOptions
4951
, logFile :: Maybe FilePath
5052
, smtOptions :: Maybe SMTOptions
5153
, equationOptions :: EquationOptions
52-
, indexCells :: [Text]
54+
, rewriteOptions :: RewriteOptions
5355
, prettyPrintOptions :: [ModifierT]
5456
}
5557
deriving (Show)
5658

59+
60+
data RewriteOptions = RewriteOptions
61+
{ indexCells :: [Text]
62+
, interimSimplification :: Maybe Natural
63+
}
64+
deriving stock (Show, Eq)
65+
5766
data LogFormat
5867
= Standard
5968
| OneLine
@@ -156,13 +165,7 @@ clOptionsParser =
156165
)
157166
<*> parseSMTOptions
158167
<*> parseEquationOptions
159-
<*> option
160-
(eitherReader $ mapM (readCellName . trim) . splitOn ",")
161-
( metavar "CELL-NAME[,CELL-NAME]"
162-
<> long "index-cells"
163-
<> help "Names of configuration cells to index rewrite rules with (default: 'k')"
164-
<> value []
165-
)
168+
<*> parseRewriteOptions
166169
<*> option
167170
(eitherReader $ mapM (readModifierT . trim) . splitOn ",")
168171
( metavar "PRETTY_PRINT"
@@ -202,18 +205,6 @@ clOptionsParser =
202205
"nanoseconds" -> Right Nanoseconds
203206
other -> Left $ other <> ": Unsupported timestamp format"
204207

205-
readCellName :: String -> Either String Text
206-
readCellName input
207-
| null input =
208-
Left "Empty cell name"
209-
| all isAscii input
210-
, all isPrint input =
211-
Right $ "Lbl'-LT-'" <> enquote input <> "'-GT-'"
212-
| otherwise =
213-
Left $ "Illegal non-ascii characters in `" <> input <> "'"
214-
215-
enquote = decodeASCII . encodeLabel . BS.pack
216-
217208
-- custom log levels that can be selected
218209
allowedLogLevels :: [(String, String)]
219210
allowedLogLevels =
@@ -364,13 +355,6 @@ parseSMTOptions =
364355
where
365356
smtDefaults = defaultSMTOptions
366357

367-
nonnegativeInt :: ReadM Int
368-
nonnegativeInt =
369-
auto >>= \case
370-
i
371-
| i < 0 -> readerError "must be a non-negative integer."
372-
| otherwise -> pure i
373-
374358
readTactic =
375359
either (readerError . ("Invalid s-expression. " <>)) pure . SMT.parseSExpr . BS.pack =<< str
376360

@@ -397,12 +381,46 @@ parseEquationOptions =
397381
defaultMaxIterations = 100
398382
defaultMaxRecursion = 5
399383

400-
nonnegativeInt :: ReadM Int
401-
nonnegativeInt =
402-
auto >>= \case
403-
i
404-
| i < 0 -> readerError "must be a non-negative integer."
405-
| otherwise -> pure i
384+
nonnegativeInt :: Integral i => ReadM i
385+
nonnegativeInt =
386+
auto @Integer >>= \case
387+
i
388+
| i < 0 -> readerError "must be a non-negative integer."
389+
| otherwise -> pure (fromIntegral i)
390+
391+
parseRewriteOptions :: Parser RewriteOptions
392+
parseRewriteOptions =
393+
RewriteOptions
394+
<$> option
395+
(eitherReader $ mapM (readCellName . trim) . splitOn ",")
396+
( metavar "CELL-NAME[,CELL-NAME]"
397+
<> long "index-cells"
398+
<> help "Names of configuration cells to index rewrite rules with (default: 'k')"
399+
<> value []
400+
)
401+
<*> optional
402+
( option
403+
nonnegativeInt
404+
( metavar "DEPTH"
405+
<> long "booster-interim-simplification"
406+
<> help "If given: Simplify the term each time the given rewrite depth is reached"
407+
)
408+
)
409+
where
410+
readCellName :: String -> Either String Text
411+
readCellName input
412+
| null input =
413+
Left "Empty cell name"
414+
| all isAscii input
415+
, all isPrint input =
416+
Right $ "Lbl'-LT-'" <> enquote input <> "'-GT-'"
417+
| otherwise =
418+
Left $ "Illegal non-ascii characters in `" <> input <> "'"
419+
420+
enquote = decodeASCII . encodeLabel . BS.pack
421+
422+
423+
406424

407425
versionInfoParser :: Parser (a -> a)
408426
versionInfoParser =

booster/library/Booster/JsonRpc.hs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import Numeric.Natural
4242
import Prettyprinter (comma, hsep, punctuate, (<+>))
4343
import System.Clock (Clock (Monotonic), diffTimeSpec, getTime, toNanoSecs)
4444

45+
import Booster.CLOptions (RewriteOptions (..))
4546
import Booster.Definition.Attributes.Base (UniqueId, getUniqueId, uniqueId)
4647
import Booster.Definition.Base (KoreDefinition (..))
4748
import Booster.Definition.Base qualified as Definition (RewriteRule (..))
@@ -108,7 +109,7 @@ respond stateVar request =
108109
| isJust req.stepTimeout -> pure $ Left $ RpcError.unsupportedOption ("step-timeout" :: String)
109110
| isJust req.movingAverageStepTimeout ->
110111
pure $ Left $ RpcError.unsupportedOption ("moving-average-step-timeout" :: String)
111-
RpcTypes.Execute req -> withModule req._module $ \(def, mLlvmLibrary, mSMTOptions) -> Booster.Log.withContext CtxExecute $ do
112+
RpcTypes.Execute req -> withModule req._module $ \(def, mLlvmLibrary, mSMTOptions, rewriteOpts) -> Booster.Log.withContext CtxExecute $ do
112113
start <- liftIO $ getTime Monotonic
113114
-- internalise given constrained term
114115
let internalised = runExcept $ internalisePattern DisallowAlias CheckSubsorts Nothing def req.state.term
@@ -147,7 +148,7 @@ respond stateVar request =
147148

148149
solver <- traverse (SMT.initSolver def) mSMTOptions
149150
result <-
150-
performRewrite doTracing def mLlvmLibrary solver mbDepth cutPoints terminals substPat
151+
performRewrite doTracing def mLlvmLibrary solver mbDepth cutPoints terminals rewriteOpts.interimSimplification substPat
151152
whenJust solver SMT.finaliseSolver
152153
stop <- liftIO $ getTime Monotonic
153154
let duration =
@@ -218,7 +219,7 @@ respond stateVar request =
218219
Booster.Log.logMessage $
219220
"Added a new module. Now in scope: " <> Text.intercalate ", " (Map.keys newDefinitions)
220221
pure $ RpcTypes.AddModule $ RpcTypes.AddModuleResult moduleHash
221-
RpcTypes.Simplify req -> withModule req._module $ \(def, mLlvmLibrary, mSMTOptions) -> Booster.Log.withContext CtxSimplify $ do
222+
RpcTypes.Simplify req -> withModule req._module $ \(def, mLlvmLibrary, mSMTOptions, _) -> Booster.Log.withContext CtxSimplify $ do
222223
start <- liftIO $ getTime Monotonic
223224
let internalised =
224225
runExcept $ internaliseTermOrPredicate DisallowAlias CheckSubsorts Nothing def req.state.term
@@ -309,11 +310,11 @@ respond stateVar request =
309310
RpcTypes.SimplifyResult{state, logs = mkTraces duration}
310311
pure $ second mkSimplifyResponse result
311312
RpcTypes.GetModel req -> withModule req._module $ \case
312-
(_, _, Nothing) -> do
313+
(_, _, Nothing, _) -> do
313314
withContext CtxGetModel $
314315
logMessage' ("get-model request, not supported without SMT solver" :: Text)
315316
pure $ Left RpcError.notImplemented
316-
(def, _, Just smtOptions) -> do
317+
(def, _, Just smtOptions, _) -> do
317318
let internalised =
318319
runExcept $
319320
internaliseTermOrPredicate DisallowAlias CheckSubsorts Nothing def req.state.term
@@ -418,7 +419,7 @@ respond stateVar request =
418419
{ satisfiable = RpcTypes.Sat
419420
, substitution
420421
}
421-
RpcTypes.Implies req -> withModule req._module $ \(def, mLlvmLibrary, mSMTOptions) -> Booster.Log.withContext CtxImplies $ do
422+
RpcTypes.Implies req -> withModule req._module $ \(def, mLlvmLibrary, mSMTOptions, _) -> Booster.Log.withContext CtxImplies $ do
422423
-- internalise given constrained term
423424
let internalised =
424425
runExcept . internalisePattern DisallowAlias CheckSubsorts Nothing def . fst . extractExistentials
@@ -503,7 +504,7 @@ respond stateVar request =
503504
where
504505
withModule ::
505506
Maybe Text ->
506-
( (KoreDefinition, Maybe LLVM.API, Maybe SMT.SMTOptions) ->
507+
( (KoreDefinition, Maybe LLVM.API, Maybe SMT.SMTOptions, RewriteOptions) ->
507508
m (Either ErrorObj (RpcTypes.API 'RpcTypes.Res))
508509
) ->
509510
m (Either ErrorObj (RpcTypes.API 'RpcTypes.Res))
@@ -512,7 +513,7 @@ respond stateVar request =
512513
let mainName = fromMaybe state.defaultMain mbMainModule
513514
case Map.lookup mainName state.definitions of
514515
Nothing -> pure $ Left $ RpcError.backendError $ RpcError.CouldNotFindModule mainName
515-
Just d -> action (d, state.mLlvmLibrary, state.mSMTOptions)
516+
Just d -> action (d, state.mLlvmLibrary, state.mSMTOptions, state.rewriteOptions)
516517

517518
doesNotImply s l r =
518519
pure $
@@ -567,9 +568,11 @@ data ServerState = ServerState
567568
, defaultMain :: Text
568569
-- ^ default main module (initially from command line, could be changed later)
569570
, mLlvmLibrary :: Maybe LLVM.API
570-
-- ^ optional LLVM simplification library
571+
-- ^ Read-only: optional LLVM simplification library
571572
, mSMTOptions :: Maybe SMT.SMTOptions
572-
-- ^ (optional) SMT solver options
573+
-- ^ Read-only: (optional) SMT solver options
574+
, rewriteOptions :: RewriteOptions
575+
-- ^ Read-only: configuration related to booster rewriting
573576
, addedModules :: Map Text Text
574577
-- ^ map of raw modules added via add-module
575578
}

booster/library/Booster/Pattern/Rewrite.hs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -721,9 +721,11 @@ performRewrite ::
721721
[Text] ->
722722
-- | terminal rule labels
723723
[Text] ->
724+
-- | interim-simplification frequency
725+
(Maybe Natural) ->
724726
Pattern ->
725727
io (Natural, Seq (RewriteTrace ()), RewriteResult Pattern)
726-
performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalLabels pat = do
728+
performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalLabels mbSimplify pat = do
727729
(rr, RewriteStepsState{counter, traces}) <-
728730
flip runStateT rewriteStart $ doSteps False pat
729731
pure (counter, traces, rr)
@@ -807,11 +809,19 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
807809
doSteps wasSimplified pat' = do
808810
RewriteStepsState{counter, simplifierCache} <- get
809811
logDepth $ showCounter counter
810-
if depthReached counter
811-
then do
812-
logDepth $ "Reached maximum depth of " <> maybe "?" showCounter mbMaxDepth
813-
(if wasSimplified then pure else simplifyResult pat') $ RewriteFinished Nothing Nothing pat'
814-
else
812+
813+
case counter of
814+
c | depthReached c -> do
815+
logDepth $ "Reached maximum depth of " <> maybe "?" showCounter mbMaxDepth
816+
(if wasSimplified then pure else simplifyResult pat') $ RewriteFinished Nothing Nothing pat'
817+
| counter > 0
818+
, not wasSimplified
819+
, maybe False ((== 0) . (counter `mod`)) mbSimplify -> do
820+
logDepth $ "Interim simplification after " <> maybe "??" showCounter mbSimplify
821+
simplifyP pat' >>= \case
822+
Nothing -> pure $ RewriteTrivial pat'
823+
Just newPat -> doSteps True newPat
824+
| otherwise ->
815825
runRewriteT
816826
doTracing
817827
def

booster/tools/booster/Server.hs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ main = do
134134
, logFile
135135
, smtOptions
136136
, equationOptions
137-
, indexCells
137+
, rewriteOptions
138138
, prettyPrintOptions
139139
}
140140
, proxyOptions =
@@ -233,7 +233,7 @@ main = do
233233
mLlvmLibrary <- maybe (pure Nothing) (fmap Just . mkAPI) mdl
234234
definitionsWithCeilSummaries <-
235235
liftIO $
236-
loadDefinition indexCells definitionFile
236+
loadDefinition rewriteOptions.indexCells definitionFile
237237
>>= mapM (mapM (runNoLoggingT . computeCeilsDefinition mLlvmLibrary))
238238
>>= evaluate . force . either (error . show) id
239239
unless (isJust $ Map.lookup mainModuleName definitionsWithCeilSummaries) $ do
@@ -301,6 +301,7 @@ main = do
301301
, defaultMain = mainModuleName
302302
, mLlvmLibrary
303303
, mSMTOptions = if boosterSMT then smtOptions else Nothing
304+
, rewriteOptions
304305
, addedModules = mempty
305306
}
306307
statsVar <- Stats.newStats

booster/unit-tests/Test/Booster/Pattern/Rewrite.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ failsWith t err =
272272
runRewrite :: Term -> IO (Natural, RewriteResult Term)
273273
runRewrite t = do
274274
(counter, _, res) <-
275-
runNoLoggingT $ performRewrite NoCollectRewriteTraces def Nothing Nothing Nothing [] [] $ Pattern_ t
275+
runNoLoggingT $ performRewrite NoCollectRewriteTraces def Nothing Nothing Nothing [] [] Nothing $ Pattern_ t
276276
pure (counter, fmap (.term) res)
277277

278278
aborts :: RewriteFailed "Rewrite" -> Term -> IO ()
@@ -415,7 +415,7 @@ supportsDepthControl =
415415
rewritesToDepth (MaxDepth depth) (Steps n) t t' f = do
416416
(counter, _, res) <-
417417
runNoLoggingT $
418-
performRewrite NoCollectRewriteTraces def Nothing Nothing (Just depth) [] [] $
418+
performRewrite NoCollectRewriteTraces def Nothing Nothing (Just depth) [] [] Nothing $
419419
Pattern_ t
420420
(counter, fmap (.term) res) @?= (n, f t')
421421

@@ -469,7 +469,7 @@ supportsCutPoints =
469469
rewritesToCutPoint lbl (Steps n) t t' f = do
470470
(counter, _, res) <-
471471
runNoLoggingT $
472-
performRewrite NoCollectRewriteTraces def Nothing Nothing Nothing [lbl] [] $
472+
performRewrite NoCollectRewriteTraces def Nothing Nothing Nothing [lbl] [] Nothing $
473473
Pattern_ t
474474
(counter, fmap (.term) res) @?= (n, f t')
475475

@@ -501,5 +501,5 @@ supportsTerminalRules =
501501
rewritesToTerminal lbl (Steps n) t t' f = do
502502
(counter, _, res) <-
503503
runNoLoggingT $ do
504-
performRewrite NoCollectRewriteTraces def Nothing Nothing Nothing [] [lbl] $ Pattern_ t
504+
performRewrite NoCollectRewriteTraces def Nothing Nothing Nothing [] [lbl] Nothing $ Pattern_ t
505505
(counter, fmap (.term) res) @?= (n, f t')

dev-tools/booster-dev/Server.hs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ main = do
5959
, llvmLibraryFile
6060
, smtOptions
6161
, equationOptions
62-
, indexCells
62+
, rewriteOptions
6363
, prettyPrintOptions
6464
, logFile
6565
} = options
@@ -72,7 +72,7 @@ main = do
7272

7373
withLlvmLib llvmLibraryFile $ \mLlvmLibrary -> do
7474
definitionMap <-
75-
loadDefinition indexCells definitionFile
75+
loadDefinition rewriteOptions.indexCells definitionFile
7676
>>= mapM (mapM ((fst <$>) . runNoLoggingT . computeCeilsDefinition mLlvmLibrary))
7777
>>= evaluate . force . either (error . show) id
7878
-- ensure the (default) main module is present in the definition
@@ -88,6 +88,7 @@ main = do
8888
definitionMap
8989
mainModuleName
9090
mLlvmLibrary
91+
rewriteOptions
9192
logFile
9293
smtOptions
9394
(adjustLogLevels logLevels)
@@ -119,6 +120,7 @@ runServer ::
119120
Map Text KoreDefinition ->
120121
Text ->
121122
Maybe LLVM.API ->
123+
RewriteOptions ->
122124
Maybe FilePath ->
123125
Maybe SMT.SMTOptions ->
124126
(LogLevel, [LogLevel]) ->
@@ -128,7 +130,7 @@ runServer ::
128130
LogFormat ->
129131
[ModifierT] ->
130132
IO ()
131-
runServer port definitions defaultMain mLlvmLibrary logFile mSMTOptions (_logLevel, customLevels) logContexts logTimeStamps timeStampsFormat logFormat prettyPrintOptions =
133+
runServer port definitions defaultMain mLlvmLibrary rewriteOpts logFile mSMTOptions (_logLevel, customLevels) logContexts logTimeStamps timeStampsFormat logFormat prettyPrintOptions =
132134
do
133135
let timestampFlag = case timeStampsFormat of
134136
Pretty -> PrettyTimestamps
@@ -153,6 +155,7 @@ runServer port definitions defaultMain mLlvmLibrary logFile mSMTOptions (_logLev
153155
, defaultMain
154156
, mLlvmLibrary
155157
, mSMTOptions
158+
, rewriteOptions = rewriteOpts
156159
, addedModules = mempty
157160
}
158161
jsonRpcServer

0 commit comments

Comments
 (0)