Skip to content

Commit 4ef3106

Browse files
committed
Bundle performRewrite options into RewriteConfig
1 parent d0c19f0 commit 4ef3106

File tree

3 files changed

+72
-60
lines changed

3 files changed

+72
-60
lines changed

booster/library/Booster/JsonRpc.hs

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ import Booster.Pattern.Bool (pattern TrueBool)
5454
import Booster.Pattern.Match (FailReason (..), MatchResult (..), MatchType (..), matchTerms)
5555
import Booster.Pattern.Pretty
5656
import Booster.Pattern.Rewrite (
57+
RewriteConfig (..),
5758
RewriteFailed (..),
5859
RewriteResult (..),
5960
RewriteTrace (..),
@@ -153,17 +154,26 @@ respond stateVar request =
153154
]
154155

155156
solver <- maybe (SMT.noSolver) (SMT.initSolver def) mSMTOptions
157+
158+
logger <- getLogger
159+
prettyModifiers <- getPrettyModifiers
160+
let rewriteConfig =
161+
RewriteConfig
162+
{ definition = def
163+
, llvmApi = mLlvmLibrary
164+
, smtSolver = solver
165+
, varsToAvoid = substVars
166+
, doTracing
167+
, logger
168+
, prettyModifiers
169+
, mbMaxDepth = mbDepth
170+
, mbSimplify = rewriteOpts.interimSimplification
171+
, cutLabels = cutPoints
172+
, terminalLabels = terminals
173+
}
156174
result <-
157175
performRewrite
158-
doTracing
159-
def
160-
mLlvmLibrary
161-
solver
162-
substVars
163-
mbDepth
164-
cutPoints
165-
terminals
166-
rewriteOpts.interimSimplification
176+
rewriteConfig
167177
substPat
168178
SMT.finaliseSolver solver
169179
stop <- liftIO $ getTime Monotonic

booster/library/Booster/Pattern/Rewrite.hs

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ License : BSD-3-Clause
1010
module Booster.Pattern.Rewrite (
1111
performRewrite,
1212
rewriteStep,
13+
RewriteConfig (..),
1314
RewriteFailed (..),
1415
RewriteResult (..),
1516
RewriteTrace (..),
@@ -84,6 +85,9 @@ data RewriteConfig = RewriteConfig
8485
, doTracing :: Flag "CollectRewriteTraces"
8586
, logger :: Logger LogMessage
8687
, prettyModifiers :: ModifiersRep
88+
, -- below: parameters used only in performRewrite
89+
mbMaxDepth, mbSimplify :: Maybe Natural
90+
, cutLabels, terminalLabels :: [Text]
8791
}
8892

8993
instance MonadIO io => LoggerMIO (RewriteT io) where
@@ -99,25 +103,12 @@ pattern NoCollectRewriteTraces :: Flag "CollectRewriteTraces"
99103
pattern NoCollectRewriteTraces = Flag False
100104

101105
runRewriteT ::
102-
LoggerMIO io =>
103-
Flag "CollectRewriteTraces" ->
104-
KoreDefinition ->
105-
Maybe LLVM.API ->
106-
SMT.SMTContext ->
107-
Set.Set Variable ->
106+
RewriteConfig ->
108107
SimplifierCache ->
109108
RewriteT io a ->
110109
io (Either (RewriteFailed "Rewrite") (a, SimplifierCache))
111-
runRewriteT doTracing definition llvmApi smtSolver varsToAvoid cache m = do
112-
logger <- getLogger
113-
prettyModifiers <- getPrettyModifiers
114-
runExceptT
115-
. flip runStateT cache
116-
. flip
117-
runReaderT
118-
RewriteConfig{definition, llvmApi, smtSolver, varsToAvoid, doTracing, logger, prettyModifiers}
119-
. unRewriteT
120-
$ m
110+
runRewriteT rewriteConfig cache =
111+
runExceptT . flip runStateT cache . flip runReaderT rewriteConfig . unRewriteT
121112

122113
throw :: LoggerMIO io => RewriteFailed "Rewrite" -> RewriteT io a
123114
throw = RewriteT . lift . lift . throwE
@@ -704,27 +695,25 @@ mkDiffTerms = \case
704695
performRewrite ::
705696
forall io.
706697
LoggerMIO io =>
707-
Flag "CollectRewriteTraces" ->
708-
KoreDefinition ->
709-
Maybe LLVM.API ->
710-
SMT.SMTContext ->
711-
-- | Variable names to avoid (for new existentials)
712-
Set.Set Variable ->
713-
-- | maximum depth
714-
Maybe Natural ->
715-
-- | cut point rule labels
716-
[Text] ->
717-
-- | terminal rule labels
718-
[Text] ->
719-
-- | interim-simplification frequency
720-
(Maybe Natural) ->
698+
RewriteConfig ->
721699
Pattern ->
722700
io (Natural, Seq (RewriteTrace ()), RewriteResult Pattern)
723-
performRewrite doTracing def mLlvmLibrary smtSolver varsToAvoid mbMaxDepth cutLabels terminalLabels mbSimplify pat = do
701+
performRewrite rewriteConfig pat = do
724702
(rr, RewriteStepsState{counter, traces}) <-
725703
flip runStateT rewriteStart $ doSteps False pat
726704
pure (counter, traces, rr)
727705
where
706+
RewriteConfig
707+
{ definition
708+
, llvmApi
709+
, smtSolver
710+
, doTracing
711+
, mbMaxDepth
712+
, mbSimplify
713+
, cutLabels
714+
, terminalLabels
715+
} = rewriteConfig
716+
728717
logDepth = withContext CtxDepth . logMessage
729718

730719
depthReached n = maybe False (n >=) mbMaxDepth
@@ -745,7 +734,7 @@ performRewrite doTracing def mLlvmLibrary smtSolver varsToAvoid mbMaxDepth cutLa
745734
simplifyP p = withContext CtxSimplify $ do
746735
st <- get
747736
let cache = st.simplifierCache
748-
evaluatePattern def mLlvmLibrary smtSolver cache p >>= \(res, newCache) -> do
737+
evaluatePattern definition llvmApi smtSolver cache p >>= \(res, newCache) -> do
749738
updateCache newCache
750739
case res of
751740
Right newPattern -> do
@@ -818,11 +807,7 @@ performRewrite doTracing def mLlvmLibrary smtSolver varsToAvoid mbMaxDepth cutLa
818807
Just newPat -> doSteps True newPat
819808
| otherwise ->
820809
runRewriteT
821-
doTracing
822-
def
823-
mLlvmLibrary
824-
smtSolver
825-
varsToAvoid
810+
rewriteConfig
826811
simplifierCache
827812
(withPatternContext pat' $ rewriteStep cutLabels terminalLabels pat')
828813
>>= \case

booster/unit-tests/Test/Booster/Pattern/Rewrite.hs

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,18 @@ import Data.Bifunctor (second)
1414
import Data.List.NonEmpty qualified as NE
1515
import Data.Map (Map)
1616
import Data.Map qualified as Map
17+
import Data.Proxy (Proxy (..))
1718
import Data.Text (Text)
1819
import Numeric.Natural
1920
import Test.Tasty
2021
import Test.Tasty.HUnit
2122

2223
import Booster.Definition.Attributes.Base
2324
import Booster.Definition.Base
25+
import Booster.Log (Logger (..))
2426
import Booster.Pattern.Base
2527
import Booster.Pattern.Index (CellIndex (..), TermIndex (..))
28+
import Booster.Pattern.Pretty (ModifiersRep (..))
2629
import Booster.Pattern.Rewrite
2730
import Booster.SMT.Interface (noSolver)
2831
import Booster.Syntax.Json.Internalise (trm)
@@ -171,6 +174,24 @@ mkTheory = Map.map mkPriorityGroups . Map.fromList
171174
d :: Term
172175
d = dv someSort "thing"
173176

177+
testConf :: IO RewriteConfig
178+
testConf = do
179+
smtSolver <- noSolver
180+
pure
181+
RewriteConfig
182+
{ definition = def
183+
, llvmApi = Nothing
184+
, smtSolver
185+
, varsToAvoid = mempty
186+
, doTracing = NoCollectRewriteTraces
187+
, logger = Logger $ const $ pure ()
188+
, prettyModifiers = ModifiersRep @'[] Proxy
189+
, mbMaxDepth = Nothing
190+
, mbSimplify = Nothing
191+
, cutLabels = []
192+
, terminalLabels = []
193+
}
194+
174195
----------------------------------------
175196
errorCases
176197
, rewriteSuccess
@@ -244,9 +265,8 @@ rulePriority =
244265
runWith :: Term -> IO (Either (RewriteFailed "Rewrite") (RewriteResult Pattern))
245266
runWith t =
246267
second fst <$> do
247-
ns <- noSolver
248-
runNoLoggingT $
249-
runRewriteT NoCollectRewriteTraces def Nothing ns mempty mempty (rewriteStep [] [] $ Pattern_ t)
268+
conf <- testConf
269+
runNoLoggingT $ runRewriteT conf mempty (rewriteStep [] [] $ Pattern_ t)
250270

251271
rewritesTo :: Term -> (Text, Term) -> IO ()
252272
t1 `rewritesTo` (lbl, t2) =
@@ -271,10 +291,10 @@ failsWith t err =
271291

272292
runRewrite :: Term -> IO (Natural, RewriteResult Term)
273293
runRewrite t = do
274-
ns <- noSolver
294+
conf <- testConf
275295
(counter, _, res) <-
276296
runNoLoggingT $
277-
performRewrite NoCollectRewriteTraces def Nothing ns mempty Nothing [] [] Nothing $
297+
performRewrite conf $
278298
Pattern_ t
279299
pure (counter, fmap (.term) res)
280300

@@ -416,11 +436,9 @@ supportsDepthControl =
416436
where
417437
rewritesToDepth :: MaxDepth -> Steps -> Term -> t -> (t -> RewriteResult Term) -> IO ()
418438
rewritesToDepth (MaxDepth depth) (Steps n) t t' f = do
419-
ns <- noSolver
439+
conf <- testConf
420440
(counter, _, res) <-
421-
runNoLoggingT $
422-
performRewrite NoCollectRewriteTraces def Nothing ns mempty (Just depth) [] [] Nothing $
423-
Pattern_ t
441+
runNoLoggingT $ performRewrite conf{mbMaxDepth = Just depth} $ Pattern_ t
424442
(counter, fmap (.term) res) @?= (n, f t')
425443

426444
supportsCutPoints :: TestTree
@@ -471,10 +489,10 @@ supportsCutPoints =
471489
where
472490
rewritesToCutPoint :: Text -> Steps -> Term -> t -> (t -> RewriteResult Term) -> IO ()
473491
rewritesToCutPoint lbl (Steps n) t t' f = do
474-
ns <- noSolver
492+
conf <- testConf
475493
(counter, _, res) <-
476494
runNoLoggingT $
477-
performRewrite NoCollectRewriteTraces def Nothing ns mempty Nothing [lbl] [] Nothing $
495+
performRewrite conf{cutLabels = [lbl]} $
478496
Pattern_ t
479497
(counter, fmap (.term) res) @?= (n, f t')
480498

@@ -504,8 +522,7 @@ supportsTerminalRules =
504522
where
505523
rewritesToTerminal :: Text -> Steps -> Term -> t -> (t -> RewriteResult Term) -> IO ()
506524
rewritesToTerminal lbl (Steps n) t t' f = do
525+
conf <- testConf
507526
(counter, _, res) <-
508-
runNoLoggingT $ do
509-
ns <- noSolver
510-
performRewrite NoCollectRewriteTraces def Nothing ns mempty Nothing [] [lbl] Nothing $ Pattern_ t
527+
runNoLoggingT $ performRewrite conf{terminalLabels = [lbl]} $ Pattern_ t
511528
(counter, fmap (.term) res) @?= (n, f t')

0 commit comments

Comments
 (0)