Skip to content

Commit d9f1318

Browse files
committed
Refactor getModel to be more like checkPredicates
1 parent 35538d6 commit d9f1318

File tree

1 file changed

+83
-77
lines changed

1 file changed

+83
-77
lines changed

booster/library/Booster/SMT/Interface.hs

Lines changed: 83 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ Returns either 'Unsat' or 'Unknown' otherwise, depending on whether
134134
the solver could determine 'Unsat'.
135135
-}
136136
getModelFor ::
137+
forall io.
137138
Log.LoggerMIO io =>
139+
MonadLoggerIO io =>
138140
SMT.SMTContext ->
139141
[Predicate] ->
140142
Map Variable Term -> -- supplied substitution
@@ -143,23 +145,15 @@ getModelFor ctxt ps subst
143145
| null ps && Map.null subst = Log.withContext "smt" $ do
144146
Log.logMessage ("No constraints or substitutions to check, returning Sat" :: Text)
145147
pure $ Right Map.empty
146-
| otherwise = runSMT ctxt $ do
148+
| Left errMsg <- translated = Log.withContext "smt" $ do
149+
Log.logErrorNS "booster" $ "SMT translation error: " <> errMsg
150+
Log.logMessage $ "SMT translation error: " <> errMsg
151+
smtTranslateError errMsg
152+
| Right (smtAsserts, transState) <- translated = Log.withContext "smt" $ runSMT ctxt $ do
147153
Log.logMessage $ "Checking, constraint count " <> pack (show $ Map.size subst + length ps)
148-
let translated =
149-
SMT.runTranslator $ do
150-
let mkSMTEquation v t =
151-
SMT.eq <$> SMT.translateTerm (Var v) <*> SMT.translateTerm t
152-
smtSubst <-
153-
mapM (\(v, t) -> Assert "Substitution" <$> mkSMTEquation v t) $ Map.assocs subst
154-
smtPs <-
155-
mapM (\(Predicate p) -> Assert (mkComment p) <$> SMT.translateTerm p) ps
156-
pure $ smtSubst <> smtPs
157-
freeVars =
154+
let freeVars =
158155
Set.unions $
159156
Map.keysSet subst : map ((.variables) . getAttributes . coerce) ps
160-
when (isLeft translated) $
161-
smtTranslateError (fromLeft' translated)
162-
let (smtAsserts, transState) = fromRight' translated
163157

164158
runCmd_ SMT.Push -- assuming the prelude has been run already,
165159

@@ -176,69 +170,81 @@ getModelFor ctxt ps subst
176170

177171
satResponse <- runCmd CheckSat
178172

179-
case satResponse of
180-
Error msg -> do
181-
runCmd_ SMT.Pop
182-
throwSMT' $ BS.unpack msg
183-
Unsat -> do
184-
runCmd_ SMT.Pop
185-
pure $ Left Unsat
186-
Unknown{} -> do
187-
res <- runCmd SMT.GetReasonUnknown
188-
runCmd_ SMT.Pop
189-
pure $ Left res
190-
r@ReasonUnknown{} ->
191-
pure $ Left r
192-
Values{} -> do
193-
runCmd_ SMT.Pop
194-
throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse
195-
Success -> do
196-
runCmd_ SMT.Pop
197-
throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse
198-
Sat -> do
199-
let freeVarsMap =
200-
Map.map Atom . Map.mapKeys getVar $
201-
Map.filterWithKey
202-
(const . (`Set.member` Set.map Var freeVars))
203-
transState.mappings
204-
getVar (Var v) = v
205-
getVar other =
206-
smtTranslateError . pack $
207-
"Solver returned non-var in translation state: " <> show other
208-
sortsToTranslate = Set.fromList [SortInt, SortBool]
209-
210-
(freeVarsToSExprs, untranslatableVars) =
211-
Map.partitionWithKey
212-
(const . ((`Set.member` sortsToTranslate) . (.variableSort)))
213-
freeVarsMap
214-
unless (Map.null untranslatableVars) $
215-
let vars = Pretty.renderText . hsep . map pretty $ Map.keys untranslatableVars
216-
in Log.logMessage ("Untranslatable variables in model: " <> vars)
217-
218-
response <-
219-
if Map.null freeVarsMap
220-
then pure $ Values []
221-
else runCmd $ GetValue (Map.elems freeVarsMap)
222-
runCmd_ SMT.Pop
223-
case response of
224-
Error msg ->
225-
throwSMT' $ BS.unpack msg
226-
Values pairs ->
227-
let (errors, values) =
228-
Map.partition isLeft
229-
. Map.map (valueToTerm transState)
230-
$ Map.compose (Map.fromList pairs) freeVarsToSExprs
231-
untranslated =
232-
Map.mapWithKey (const . Var) untranslatableVars
233-
in if null errors
234-
then pure $ Right $ Map.map fromRight' values <> untranslated
235-
else
236-
throwSMT . Text.unlines $
237-
( "SMT errors while converting results: "
238-
: map fromLeft' (Map.elems errors)
239-
)
240-
other ->
241-
throwSMT' $ "Unexpected SMT response to GetValue: " <> show other
173+
processSMTResult transState freeVars satResponse
174+
where
175+
translated =
176+
SMT.runTranslator $ do
177+
let mkSMTEquation v t =
178+
SMT.eq <$> SMT.translateTerm (Var v) <*> SMT.translateTerm t
179+
smtSubst <-
180+
mapM (\(v, t) -> Assert "Substitution" <$> mkSMTEquation v t) $ Map.assocs subst
181+
smtPs <-
182+
mapM (\(Predicate p) -> Assert (mkComment p) <$> SMT.translateTerm p) ps
183+
pure $ smtSubst <> smtPs
184+
185+
processSMTResult transState freeVars satResponse = case satResponse of
186+
Error msg -> do
187+
runCmd_ SMT.Pop
188+
throwSMT' $ BS.unpack msg
189+
Unsat -> do
190+
runCmd_ SMT.Pop
191+
pure $ Left Unsat
192+
Unknown{} -> do
193+
res <- runCmd SMT.GetReasonUnknown
194+
runCmd_ SMT.Pop
195+
pure $ Left res
196+
r@ReasonUnknown{} ->
197+
pure $ Left r
198+
Values{} -> do
199+
runCmd_ SMT.Pop
200+
throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse
201+
Success -> do
202+
runCmd_ SMT.Pop
203+
throwSMT' $ "Unexpected SMT response to CheckSat: " <> show satResponse
204+
Sat -> do
205+
let freeVarsMap =
206+
Map.map Atom . Map.mapKeys getVar $
207+
Map.filterWithKey
208+
(const . (`Set.member` Set.map Var freeVars))
209+
transState.mappings
210+
getVar (Var v) = v
211+
getVar other =
212+
smtTranslateError . pack $
213+
"Solver returned non-var in translation state: " <> show other
214+
sortsToTranslate = Set.fromList [SortInt, SortBool]
215+
216+
(freeVarsToSExprs, untranslatableVars) =
217+
Map.partitionWithKey
218+
(const . ((`Set.member` sortsToTranslate) . (.variableSort)))
219+
freeVarsMap
220+
unless (Map.null untranslatableVars) $
221+
let vars = Pretty.renderText . hsep . map pretty $ Map.keys untranslatableVars
222+
in Log.logMessage ("Untranslatable variables in model: " <> vars)
223+
224+
response <-
225+
if Map.null freeVarsMap
226+
then pure $ Values []
227+
else runCmd $ GetValue (Map.elems freeVarsMap)
228+
runCmd_ SMT.Pop
229+
case response of
230+
Error msg ->
231+
throwSMT' $ BS.unpack msg
232+
Values pairs ->
233+
let (errors, values) =
234+
Map.partition isLeft
235+
. Map.map (valueToTerm transState)
236+
$ Map.compose (Map.fromList pairs) freeVarsToSExprs
237+
untranslated =
238+
Map.mapWithKey (const . Var) untranslatableVars
239+
in if null errors
240+
then pure $ Right $ Map.map fromRight' values <> untranslated
241+
else
242+
throwSMT . Text.unlines $
243+
( "SMT errors while converting results: "
244+
: map fromLeft' (Map.elems errors)
245+
)
246+
other ->
247+
throwSMT' $ "Unexpected SMT response to GetValue: " <> show other
242248

243249
mkComment :: Pretty a => a -> BS.ByteString
244250
mkComment = BS.pack . Pretty.renderDefault . pretty

0 commit comments

Comments
 (0)