Skip to content

Commit 458bbfc

Browse files
committed
Refactor getModel to be more like checkPredicates
1 parent b492d15 commit 458bbfc

File tree

1 file changed

+79
-77
lines changed

1 file changed

+79
-77
lines changed

booster/library/Booster/SMT/Interface.hs

Lines changed: 79 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -143,23 +143,13 @@ getModelFor ctxt ps subst
143143
| null ps && Map.null subst = Log.withContext "smt" $ do
144144
Log.logMessage ("No constraints or substitutions to check, returning Sat" :: Text)
145145
pure $ Right Map.empty
146-
| otherwise = runSMT ctxt $ do
146+
| Left errMsg <- translated = Log.withContext "smt" $ do
147+
smtTranslateError errMsg
148+
| Right (smtAsserts, transState) <- translated = Log.withContext "smt" $ runSMT ctxt $ do
147149
Log.logMessage $ "Checking, constraint count " <> pack (show $ Map.size subst + length ps)
148-
let translated =
149-
SMT.runTranslator $ do
150-
let mkSMTEquation v t =
151-
SMT.eq <$> SMT.translateTerm (Var v) <*> SMT.translateTerm t
152-
smtSubst <-
153-
mapM (\(v, t) -> Assert "Substitution" <$> mkSMTEquation v t) $ Map.assocs subst
154-
smtPs <-
155-
mapM (\(Predicate p) -> Assert (mkComment p) <$> SMT.translateTerm p) ps
156-
pure $ smtSubst <> smtPs
157-
freeVars =
150+
let freeVars =
158151
Set.unions $
159152
Map.keysSet subst : map ((.variables) . getAttributes . coerce) ps
160-
when (isLeft translated) $
161-
smtTranslateError (fromLeft' translated)
162-
let (smtAsserts, transState) = fromRight' translated
163153

164154
runCmd_ SMT.Push -- assuming the prelude has been run already,
165155

@@ -176,69 +166,81 @@ getModelFor ctxt ps subst
176166

177167
satResponse <- runCmd CheckSat
178168

179-
case satResponse of
180-
Error msg -> do
181-
runCmd_ SMT.Pop
182-
throwSMT' $ BS.unpack msg
183-
Unsat -> do
184-
runCmd_ SMT.Pop
185-
pure $ Left Unsat
186-
Unknown{} -> do
187-
res <- runCmd SMT.GetReasonUnknown
188-
runCmd_ SMT.Pop
189-
pure $ Left res
190-
r@ReasonUnknown{} ->
191-
pure $ Left r
192-
Values{} -> do
193-
runCmd_ SMT.Pop
194-
throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse
195-
Success -> do
196-
runCmd_ SMT.Pop
197-
throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse
198-
Sat -> do
199-
let freeVarsMap =
200-
Map.map Atom . Map.mapKeys getVar $
201-
Map.filterWithKey
202-
(const . (`Set.member` Set.map Var freeVars))
203-
transState.mappings
204-
getVar (Var v) = v
205-
getVar other =
206-
smtTranslateError . pack $
207-
"Solver returned non-var in translation state: " <> show other
208-
sortsToTranslate = Set.fromList [SortInt, SortBool]
209-
210-
(freeVarsToSExprs, untranslatableVars) =
211-
Map.partitionWithKey
212-
(const . ((`Set.member` sortsToTranslate) . (.variableSort)))
213-
freeVarsMap
214-
unless (Map.null untranslatableVars) $
215-
let vars = Pretty.renderText . hsep . map pretty $ Map.keys untranslatableVars
216-
in Log.logMessage ("Untranslatable variables in model: " <> vars)
217-
218-
response <-
219-
if Map.null freeVarsMap
220-
then pure $ Values []
221-
else runCmd $ GetValue (Map.elems freeVarsMap)
222-
runCmd_ SMT.Pop
223-
case response of
224-
Error msg ->
225-
throwSMT' $ BS.unpack msg
226-
Values pairs ->
227-
let (errors, values) =
228-
Map.partition isLeft
229-
. Map.map (valueToTerm transState)
230-
$ Map.compose (Map.fromList pairs) freeVarsToSExprs
231-
untranslated =
232-
Map.mapWithKey (const . Var) untranslatableVars
233-
in if null errors
234-
then pure $ Right $ Map.map fromRight' values <> untranslated
235-
else
236-
throwSMT . Text.unlines $
237-
( "SMT errors while converting results: "
238-
: map fromLeft' (Map.elems errors)
239-
)
240-
other ->
241-
throwSMT' $ "Unexpected SMT response to GetValue: " <> show other
169+
processSMTResult transState freeVars satResponse
170+
where
171+
translated =
172+
SMT.runTranslator $ do
173+
let mkSMTEquation v t =
174+
SMT.eq <$> SMT.translateTerm (Var v) <*> SMT.translateTerm t
175+
smtSubst <-
176+
mapM (\(v, t) -> Assert "Substitution" <$> mkSMTEquation v t) $ Map.assocs subst
177+
smtPs <-
178+
mapM (\(Predicate p) -> Assert (mkComment p) <$> SMT.translateTerm p) ps
179+
pure $ smtSubst <> smtPs
180+
181+
processSMTResult transState freeVars satResponse = case satResponse of
182+
Error msg -> do
183+
runCmd_ SMT.Pop
184+
throwSMT' $ BS.unpack msg
185+
Unsat -> do
186+
runCmd_ SMT.Pop
187+
pure $ Left Unsat
188+
Unknown{} -> do
189+
res <- runCmd SMT.GetReasonUnknown
190+
runCmd_ SMT.Pop
191+
pure $ Left res
192+
r@ReasonUnknown{} ->
193+
pure $ Left r
194+
Values{} -> do
195+
runCmd_ SMT.Pop
196+
throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse
197+
Success -> do
198+
runCmd_ SMT.Pop
199+
throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse
200+
Sat -> do
201+
let freeVarsMap =
202+
Map.map Atom . Map.mapKeys getVar $
203+
Map.filterWithKey
204+
(const . (`Set.member` Set.map Var freeVars))
205+
transState.mappings
206+
getVar (Var v) = v
207+
getVar other =
208+
smtTranslateError . pack $
209+
"Solver returned non-var in translation state: " <> show other
210+
sortsToTranslate = Set.fromList [SortInt, SortBool]
211+
212+
(freeVarsToSExprs, untranslatableVars) =
213+
Map.partitionWithKey
214+
(const . ((`Set.member` sortsToTranslate) . (.variableSort)))
215+
freeVarsMap
216+
unless (Map.null untranslatableVars) $
217+
let vars = Pretty.renderText . hsep . map pretty $ Map.keys untranslatableVars
218+
in Log.logMessage ("Untranslatable variables in model: " <> vars)
219+
220+
response <-
221+
if Map.null freeVarsMap
222+
then pure $ Values []
223+
else runCmd $ GetValue (Map.elems freeVarsMap)
224+
runCmd_ SMT.Pop
225+
case response of
226+
Error msg ->
227+
throwSMT' $ BS.unpack msg
228+
Values pairs ->
229+
let (errors, values) =
230+
Map.partition isLeft
231+
. Map.map (valueToTerm transState)
232+
$ Map.compose (Map.fromList pairs) freeVarsToSExprs
233+
untranslated =
234+
Map.mapWithKey (const . Var) untranslatableVars
235+
in if null errors
236+
then pure $ Right $ Map.map fromRight' values <> untranslated
237+
else
238+
throwSMT . Text.unlines $
239+
( "SMT errors while converting results: "
240+
: map fromLeft' (Map.elems errors)
241+
)
242+
other ->
243+
throwSMT' $ "Unexpected SMT response to GetValue: " <> show other
242244

243245
mkComment :: Pretty a => a -> BS.ByteString
244246
mkComment = BS.pack . Pretty.renderDefault . pretty

0 commit comments

Comments
 (0)