Skip to content

Commit eec6d2b

Browse files
committed
Simplify RewriteRuleAppT monad transformer
1 parent 45d89cf commit eec6d2b

File tree

1 file changed

+21
-59
lines changed

1 file changed

+21
-59
lines changed

booster/library/Booster/Pattern/Rewrite.hs

Lines changed: 21 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -222,59 +222,27 @@ rewriteStep cutLabels terminalLabels pat = do
222222
NE.fromList $
223223
map (\(r, p) -> (ruleLabelOrLocT r, uniqueId r, p)) rxs
224224

225+
-- | Rewrite rule application transformer: may throw exceptions on non-applicable or trivial rule applications
226+
type RewriteRuleAppT m a = ExceptT RewriteRuleAppException m a
227+
228+
data RewriteRuleAppException = RewriteRuleNotApplied | RewriteRuleTrivial deriving (Show, Eq)
229+
230+
runRewriteRuleAppT :: Monad m => RewriteRuleAppT m a -> m (RewriteRuleAppResult a)
231+
runRewriteRuleAppT action =
232+
runExceptT action >>= \case
233+
Left RewriteRuleNotApplied -> pure NotApplied
234+
Left RewriteRuleTrivial -> pure Trivial
235+
Right result -> pure (Applied result)
236+
225237
data RewriteRuleAppResult a
226238
= Applied a
227239
| NotApplied
228240
| Trivial
229241
deriving (Show, Eq, Functor)
230242

231-
newtype RewriteRuleAppT m a = RewriteRuleAppT {runRewriteRuleAppT :: m (RewriteRuleAppResult a)}
232-
deriving (Functor)
233-
234-
instance Monad m => Applicative (RewriteRuleAppT m) where
235-
pure = RewriteRuleAppT . return . Applied
236-
{-# INLINE pure #-}
237-
mf <*> mx = RewriteRuleAppT $ do
238-
mb_f <- runRewriteRuleAppT mf
239-
case mb_f of
240-
NotApplied -> return NotApplied
241-
Trivial -> return Trivial
242-
Applied f -> do
243-
mb_x <- runRewriteRuleAppT mx
244-
case mb_x of
245-
NotApplied -> return NotApplied
246-
Trivial -> return Trivial
247-
Applied x -> return (Applied (f x))
248-
{-# INLINE (<*>) #-}
249-
m *> k = m >> k
250-
{-# INLINE (*>) #-}
251-
252-
instance Monad m => Monad (RewriteRuleAppT m) where
253-
return = pure
254-
{-# INLINE return #-}
255-
x >>= f = RewriteRuleAppT $ do
256-
v <- runRewriteRuleAppT x
257-
case v of
258-
Applied y -> runRewriteRuleAppT (f y)
259-
NotApplied -> return NotApplied
260-
Trivial -> return Trivial
261-
{-# INLINE (>>=) #-}
262-
263-
instance MonadTrans RewriteRuleAppT where
264-
lift :: Monad m => m a -> RewriteRuleAppT m a
265-
lift = RewriteRuleAppT . fmap Applied
266-
{-# INLINE lift #-}
267-
268-
instance Monad m => MonadFail (RewriteRuleAppT m) where
269-
fail _ = RewriteRuleAppT (return NotApplied)
270-
{-# INLINE fail #-}
271-
272-
instance MonadIO m => MonadIO (RewriteRuleAppT m) where
273-
liftIO = lift . liftIO
274-
{-# INLINE liftIO #-}
275-
276-
instance LoggerMIO m => LoggerMIO (RewriteRuleAppT m) where
277-
withLogger l (RewriteRuleAppT m) = RewriteRuleAppT $ withLogger l m
243+
returnTrivial, returnNotApplied :: Monad m => RewriteRuleAppT m a
244+
returnTrivial = throwE RewriteRuleTrivial
245+
returnNotApplied = throwE RewriteRuleNotApplied
278246

279247
{- | Tries to apply one rewrite rule:
280248
@@ -310,7 +278,7 @@ applyRule pat@Pattern{ceilConditions} rule =
310278
failRewrite $ InternalMatchError $ renderText $ pretty' @mods err
311279
MatchFailed reason -> do
312280
withContext CtxFailure $ logPretty' @mods reason
313-
fail "Rule matching failed"
281+
returnNotApplied
314282
MatchIndeterminate remainder -> do
315283
withContext CtxIndeterminate $
316284
logMessage $
@@ -417,12 +385,6 @@ applyRule pat@Pattern{ceilConditions} rule =
417385
failRewrite :: RewriteFailed "Rewrite" -> RewriteRuleAppT (RewriteT io) a
418386
failRewrite = lift . (throw)
419387

420-
notAppliedIfBottom :: RewriteRuleAppT (RewriteT io) a
421-
notAppliedIfBottom = RewriteRuleAppT $ pure NotApplied
422-
423-
trivialIfBottom :: RewriteRuleAppT (RewriteT io) a
424-
trivialIfBottom = RewriteRuleAppT $ pure Trivial
425-
426388
checkConstraint ::
427389
(Predicate -> a) ->
428390
RewriteRuleAppT (RewriteT io) (Maybe a) ->
@@ -457,7 +419,7 @@ applyRule pat@Pattern{ceilConditions} rule =
457419

458420
-- simplify the constraints (one by one in isolation). Stop if false, abort rewrite if indeterminate.
459421
unclearRequires <-
460-
catMaybes <$> mapM (checkConstraint id notAppliedIfBottom pat.constraints) toCheck
422+
catMaybes <$> mapM (checkConstraint id returnNotApplied pat.constraints) toCheck
461423

462424
-- unclear conditions may have been simplified and
463425
-- could now be syntactically present in the path constraints, filter again
@@ -473,7 +435,7 @@ applyRule pat@Pattern{ceilConditions} rule =
473435
SMT.IsInvalid -> do
474436
-- requires is actually false given the prior
475437
withContext CtxFailure $ logMessage ("Required clauses evaluated to #Bottom." :: Text)
476-
RewriteRuleAppT $ pure NotApplied
438+
returnNotApplied
477439
SMT.IsValid ->
478440
pure [] -- can proceed
479441
checkEnsures ::
@@ -483,7 +445,7 @@ applyRule pat@Pattern{ceilConditions} rule =
483445
let ruleEnsures =
484446
concatMap (splitBoolPredicates . coerce . substituteInTerm matchingSubst . coerce) rule.ensures
485447
newConstraints <-
486-
catMaybes <$> mapM (checkConstraint id trivialIfBottom pat.constraints) ruleEnsures
448+
catMaybes <$> mapM (checkConstraint id returnTrivial pat.constraints) ruleEnsures
487449

488450
-- check all new constraints together with the known side constraints
489451
solver <- lift $ RewriteT $ (.smtSolver) <$> ask
@@ -492,10 +454,10 @@ applyRule pat@Pattern{ceilConditions} rule =
492454
(lift $ SMT.checkPredicates solver pat.constraints mempty (Set.fromList newConstraints)) >>= \case
493455
SMT.IsInvalid -> do
494456
withContext CtxSuccess $ logMessage ("New constraints evaluated to #Bottom." :: Text)
495-
RewriteRuleAppT $ pure Trivial
457+
returnTrivial
496458
SMT.IsUnknown SMT.InconsistentGroundTruth -> do
497459
withContext CtxSuccess $ logMessage ("Ground truth is #Bottom." :: Text)
498-
RewriteRuleAppT $ pure Trivial
460+
returnTrivial
499461
SMT.IsUnknown SMT.ImplicationIndeterminate -> do
500462
-- the new constraint is satisfiable, continue
501463
pure ()

0 commit comments

Comments
 (0)