66e52ed9f4761837710ee2415990ae1772ccf584
[ghc.git] / compiler / utils / ListT.hs
1 {-# LANGUAGE CPP #-}
2 {-# LANGUAGE DeriveFunctor #-}
3 {-# LANGUAGE UndecidableInstances #-}
4 {-# LANGUAGE Rank2Types #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7
8 -------------------------------------------------------------------------
9 -- |
10 -- Module : Control.Monad.Logic
11 -- Copyright : (c) Dan Doel
12 -- License : BSD3
13 --
14 -- Maintainer : dan.doel@gmail.com
15 -- Stability : experimental
16 -- Portability : non-portable (multi-parameter type classes)
17 --
18 -- A backtracking, logic programming monad.
19 --
20 -- Adapted from the paper
21 -- /Backtracking, Interleaving, and Terminating
22 -- Monad Transformers/, by
23 -- Oleg Kiselyov, Chung-chieh Shan, Daniel P. Friedman, Amr Sabry
24 -- (<http://www.cs.rutgers.edu/~ccshan/logicprog/ListT-icfp2005.pdf>).
25 -------------------------------------------------------------------------
26
27 module ListT (
28 ListT(..),
29 runListT,
30 select,
31 fold
32 ) where
33
34 import GhcPrelude
35
36 import Control.Applicative
37
38 import Control.Monad
39 import Control.Monad.Fail as MonadFail
40
41 -------------------------------------------------------------------------
-- | A monad transformer for backtracking computations layered over an
-- underlying monad @m@, represented in continuation-passing style.
--
-- 'unListT' takes a success continuation, which is applied to a result
-- together with the computation producing the remaining results, and a
-- failure continuation, which is used when no (further) results are
-- available.
newtype ListT m a = ListT
  { unListT :: forall r. (a -> m r -> m r) -> m r -> m r
  }
  deriving Functor
47
-- | Lift a list of alternatives into 'ListT'. The resulting computation
-- succeeds once for each element of the list, in list order, and then
-- falls through to the failure continuation.
--
-- This has the same results as @'foldr' ('<|>') 'empty' ('map' 'pure' xs)@,
-- but it feeds the list straight into the continuations instead of first
-- building and then combining a chain of singleton computations.
--
-- No 'Monad' constraint is required: none of the 'ListT' instances in this
-- module constrain the underlying monad. Dropping it only widens the set of
-- valid call sites.
select :: [a] -> ListT m a
select xs = ListT (\sk fk -> foldr sk fk xs)
50
-- | Eliminate a 'ListT' computation by handing it a success continuation
-- (given each result plus the computation for the remaining results) and a
-- failure continuation. This is an alias for 'runListT'.
fold :: ListT m a -> (a -> m r -> m r) -> m r -> m r
fold computation = runListT computation
53
54 -------------------------------------------------------------------------
-- | Run a 'ListT' computation by supplying its two continuations:
--
-- * a success continuation, which receives a result together with the
--   computation that produces the remaining results;
--
-- * a failure continuation, which is used when no (further) results are
--   available.
runListT :: ListT m a -> (a -> m r -> m r) -> m r -> m r
runListT (ListT withContinuations) = withContinuations
59
-- 'pure' produces exactly one result. @fs '<*>' xs@ applies every function
-- produced by @fs@ to every value produced by @xs@: all results of @xs@ for
-- the first function, then all of them for the next function, and so on.
instance Applicative (ListT f) where
  pure x = ListT (\sk fk -> sk x fk)

  ListT runFs <*> ListT runXs =
    ListT (\sk fk -> runFs (\g rest -> runXs (\x -> sk (g x)) rest) fk)
63
-- 'empty' produces no results and goes straight to the failure
-- continuation. @l '<|>' r@ produces every result of @l@, and uses the
-- results of @r@ as @l@'s failure continuation, so they follow afterwards.
instance Alternative (ListT f) where
  empty = ListT (\_ onExhausted -> onExhausted)

  ListT runLeft <|> ListT runRight =
    ListT (\sk fk -> runLeft sk (runRight sk fk))
67
-- @m '>>=' k@ feeds each result of @m@, in order, to @k@. It produces all
-- results of @k@ applied to the first value, then to the next, and so on.
instance Monad (ListT m) where
  ListT runM >>= k =
    ListT (\sk fk -> runM (\a rest -> unListT (k a) sk rest) fk)
#if !MIN_VERSION_base(4,13,0)
  -- Older base versions still have 'fail' as a 'Monad' method; forward it
  -- to the 'MonadFail' instance so the two definitions agree.
  fail = MonadFail.fail
#endif
73
-- 'fail' ignores its message and behaves like 'empty': no results.
instance MonadFail.MonadFail (ListT m) where
  fail _message = empty
76
-- 'mzero' and 'mplus' are exactly the 'Alternative' operations: 'empty'
-- (no results) and ('<|>') (all results of the left computation followed by
-- all results of the right one). Delegating, rather than repeating the
-- definitions, keeps the two instances from drifting apart.
instance MonadPlus (ListT m) where
  mzero = empty
  mplus = (<|>)