expand definitions of Applicative and Alternative methods (fixes #4)
[packages/transformers.git] / Control / Monad / Trans / State / Strict.hs
1 {-# LANGUAGE CPP #-}
2 #if __GLASGOW_HASKELL__ >= 702
3 {-# LANGUAGE Safe #-}
4 #endif
5 #if __GLASGOW_HASKELL__ >= 710
6 {-# LANGUAGE AutoDeriveTypeable #-}
7 #endif
8 -----------------------------------------------------------------------------
9 -- |
10 -- Module : Control.Monad.Trans.State.Strict
11 -- Copyright : (c) Andy Gill 2001,
12 -- (c) Oregon Graduate Institute of Science and Technology, 2001
13 -- License : BSD-style (see the file LICENSE)
14 --
15 -- Maintainer : R.Paterson@city.ac.uk
16 -- Stability : experimental
17 -- Portability : portable
18 --
19 -- Strict state monads, passing an updatable state through a computation.
20 -- See below for examples.
21 --
22 -- Some computations may not require the full power of state transformers:
23 --
24 -- * For a read-only state, see "Control.Monad.Trans.Reader".
25 --
26 -- * To accumulate a value without using it on the way, see
27 -- "Control.Monad.Trans.Writer".
28 --
29 -- In this version, sequencing of computations is strict (but computations
30 -- are not strict in the state unless you force it with 'seq' or the like).
31 -- For a lazy version with the same interface, see
32 -- "Control.Monad.Trans.State.Lazy".
33 -----------------------------------------------------------------------------
34
35 module Control.Monad.Trans.State.Strict (
36 -- * The State monad
37 State,
38 state,
39 runState,
40 evalState,
41 execState,
42 mapState,
43 withState,
44 -- * The StateT monad transformer
45 StateT(..),
46 evalStateT,
47 execStateT,
48 mapStateT,
49 withStateT,
50 -- * State operations
51 get,
52 put,
53 modify,
54 modify',
55 gets,
56 -- * Lifting other operations
57 liftCallCC,
58 liftCallCC',
59 liftCatch,
60 liftListen,
61 liftPass,
62 -- * Examples
63 -- ** State monads
64 -- $examples
65
66 -- ** Counting
67 -- $counting
68
69 -- ** Labelling trees
70 -- $labelling
71 ) where
72
import Control.Monad.IO.Class
import Control.Monad.Signatures
import Control.Monad.Trans.Class
import Data.Functor.Identity

import Control.Applicative
import Control.Monad
#if __GLASGOW_HASKELL__ >= 800
import qualified Control.Monad.Fail as Fail
#endif
import Control.Monad.Fix
81
82 -- ---------------------------------------------------------------------------
-- | A state monad whose state has type @st@.
--
-- This is 'StateT' running over the 'Identity' inner monad: 'return'
-- passes the state through untouched, and @>>=@ feeds the final state
-- of the first computation in as the initial state of the second.
type State st = StateT st Identity
89
-- | Build a state computation out of a pure state-transition function.
-- This is the inverse of 'runState'.
state :: (Monad m)
      => (s -> (a, s))  -- ^ pure state transformer
      -> StateT s m a   -- ^ equivalent state-passing computation
state f = StateT $ \ s -> return (f s)
96
-- | Run a state computation from an initial state, producing both the
-- result and the final state. This is the inverse of 'state'.
runState :: State s a  -- ^ state-passing computation to execute
         -> s          -- ^ initial state
         -> (a, s)     -- ^ return value and final state
runState m s = runIdentity (runStateT m s)
103
-- | Run a state computation from an initial state and keep only its
-- result; the final state is thrown away.
--
-- * @'evalState' m s = 'fst' ('runState' m s)@
evalState :: State s a  -- ^ state-passing computation to execute
          -> s          -- ^ initial value
          -> a          -- ^ return value of the state computation
evalState m = fst . runState m
112
-- | Run a state computation from an initial state and keep only the
-- final state; the result value is thrown away.
--
-- * @'execState' m s = 'snd' ('runState' m s)@
execState :: State s a  -- ^ state-passing computation to execute
          -> s          -- ^ initial value
          -> s          -- ^ final state
execState m = snd . runState m
121
-- | Post-process a computation's result and final state together with a
-- pure function on the (result, state) pair.
--
-- * @'runState' ('mapState' f m) = f . 'runState' m@
mapState :: ((a, s) -> (b, s)) -> State s a -> State s b
mapState f = mapStateT $ \ r -> Identity (f (runIdentity r))
128
-- | @'withState' f m@ runs @m@ starting from the incoming state with @f@
-- already applied to it.
--
-- * @'withState' f m = 'modify' f >> m@
withState :: (s -> s) -> State s a -> State s a
withState f m = withStateT f m
135
136 -- ---------------------------------------------------------------------------
-- | A state transformer monad, parameterized by:
--
--   * @s@ - the type of the state.
--
--   * @m@ - the inner monad.
--
-- 'return' leaves the state as it is, while @>>=@ hands the final state
-- of the first computation to the second as its starting state.
newtype StateT s m a = StateT
    { runStateT :: s -> m (a, s)
    }
147
-- | Run a state computation from an initial state and return only its
-- result inside the inner monad; the final state is discarded.
--
-- * @'evalStateT' m s = 'liftM' 'fst' ('runStateT' m s)@
evalStateT :: (Monad m) => StateT s m a -> s -> m a
evalStateT m s = runStateT m s >>= \ (a, _) -> return a
156
-- | Run a state computation from an initial state and return only the
-- final state inside the inner monad; the result value is discarded.
--
-- * @'execStateT' m s = 'liftM' 'snd' ('runStateT' m s)@
execStateT :: (Monad m) => StateT s m a -> s -> m s
execStateT m s = runStateT m s >>= \ (_, s') -> return s'
165
-- | Transform the inner-monad action that produces the (result, state)
-- pair, possibly changing the inner monad and the result type.
--
-- * @'runStateT' ('mapStateT' f m) = f . 'runStateT' m@
mapStateT :: (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
mapStateT f m = StateT $ \ s -> f (runStateT m s)
172
-- | @'withStateT' f m@ runs @m@ starting from the incoming state with
-- @f@ already applied to it.
--
-- * @'withStateT' f m = 'modify' f >> m@
withStateT :: (s -> s) -> StateT s m a -> StateT s m a
withStateT f m = StateT $ \ s -> runStateT m (f s)
179
-- Mapping touches only the result component; the state passes through.
instance (Functor m) => Functor (StateT s m) where
    fmap f (StateT g) = StateT (fmap step . g)
      where
        step (a, s') = (f a, s')
183
-- Applicative sequencing threads the state left to right: the state
-- produced by the left computation is the starting state of the right
-- one. The method bodies run in the inner monad, hence the 'Monad m'
-- constraint.
instance (Functor m, Monad m) => Applicative (StateT s m) where
    pure a = StateT $ \ s -> return (a, s)
    {-# INLINE pure #-}
    StateT mf <*> StateT mx = StateT $ \ s -> do
        (f, s') <- mf s
        (x, s'') <- mx s'
        return (f x, s'')
    {-# INLINE (<*>) #-}
    -- Discard the left result directly instead of using the class default
    -- @(id <$ m) <*> k@, which costs an extra 'fmap' pass and rebuilds
    -- the result pair.
    StateT m *> StateT k = StateT $ \ s -> do
        (_, s') <- m s
        k s'
    {-# INLINE (*>) #-}
191
-- Choice is delegated to the inner monad's 'mplus'. Both alternatives
-- are started from the same incoming state.
instance (Functor m, MonadPlus m) => Alternative (StateT s m) where
    empty = StateT (const mzero)
    m <|> n = StateT $ \ s -> runStateT m s `mplus` runStateT n s
195
-- '>>=' runs the first computation, then feeds its result and final
-- state to the continuation. The pattern on the pair is strict, which is
-- what makes sequencing strict in this module.
instance (Monad m) => Monad (StateT s m) where
    return a = StateT $ \ s -> return (a, s)
    m >>= k = StateT $ \ s -> do
        (a, s') <- runStateT m s
        runStateT (k a) s'
#if __GLASGOW_HASKELL__ < 808
    -- 'fail' stopped being a 'Monad' method in base-4.13 (GHC 8.8);
    -- defining it here on newer compilers is a compile error.
    fail str = StateT $ \ _ -> fail str
#endif

#if __GLASGOW_HASKELL__ >= 800
-- Failure (e.g. a failed pattern match in @do@) is delegated to the inner
-- monad, discarding the state, exactly as the 'Monad' 'fail' above does.
instance (Fail.MonadFail m) => Fail.MonadFail (StateT s m) where
    fail str = StateT $ \ _ -> Fail.fail str
#endif
202
-- Both branches of 'mplus' start from the same incoming state; the
-- combination itself is the inner monad's 'mplus'.
instance (MonadPlus m) => MonadPlus (StateT s m) where
    mzero = StateT (const mzero)
    StateT f `mplus` StateT g = StateT $ \ s -> f s `mplus` g s
206
-- The knot is tied through the inner monad's 'mfix'. The fed-back pair is
-- only projected with 'fst', never pattern-matched, so it stays lazy.
instance (MonadFix m) => MonadFix (StateT s m) where
    mfix f = StateT $ \ s ->
        let step r = runStateT (f (fst r)) s
        in  mfix step
209
-- A lifted action leaves the state it is given unchanged.
instance MonadTrans (StateT s) where
    lift m = StateT $ \ s -> m >>= \ a -> return (a, s)
214
-- IO is lifted into the inner monad first, then into 'StateT'.
instance (MonadIO m) => MonadIO (StateT s m) where
    liftIO io = lift (liftIO io)
217
-- | Return the current state as the result, leaving the state unchanged.
get :: (Monad m) => StateT s m s
get = StateT $ \ s -> return (s, s)
221
-- | @'put' s@ replaces the current state with @s@.
put :: (Monad m) => s -> StateT s m ()
put s = state (const ((), s))
225
-- | @'modify' f@ replaces the state with @f@ applied to it. The new state
-- is not forced; see 'modify'' for a variant that forces it.
--
-- * @'modify' f = 'get' >>= ('put' . f)@
modify :: (Monad m) => (s -> s) -> StateT s m ()
modify f = StateT $ \ s -> return ((), f s)
232
-- | Like 'modify', but the new state is evaluated to weak head normal form
-- (via '$!') before it is stored.
--
-- * @'modify'' f = 'get' >>= (('$!') 'put' . f)@
modify' :: (Monad m) => (s -> s) -> StateT s m ()
modify' f = get >>= \ s -> put $! f s
239 s <- get
240 put $! f s
241
242 -- | Get a specific component of the state, using a projection function
243 -- supplied.
244 --
245 -- * @'gets' f = 'liftM' f 'get'@
246 gets :: (Monad m) => (s -> a) -> StateT s m a
247 gets f = state $ \ s -> (f s, s)
248
249 -- | Uniform lifting of a @callCC@ operation to the new monad.
250 -- This version rolls back to the original state on entering the
251 -- continuation.
252 liftCallCC :: CallCC m (a,s) (b,s) -> CallCC (StateT s m) a b
253 liftCallCC callCC f = StateT $ \ s ->
254 callCC $ \ c ->
255 runStateT (f (\ a -> StateT $ \ _ -> c (a, s))) s
256
257 -- | In-situ lifting of a @callCC@ operation to the new monad.
258 -- This version uses the current state on entering the continuation.
259 -- It does not satisfy the uniformity property (see "Control.Monad.Signatures").
260 liftCallCC' :: CallCC m (a,s) (b,s) -> CallCC (StateT s m) a b
261 liftCallCC' callCC f = StateT $ \ s ->
262 callCC $ \ c ->
263 runStateT (f (\ a -> StateT $ \ s' -> c (a, s'))) s
264
265 -- | Lift a @catchE@ operation to the new monad.
266 liftCatch :: Catch e m (a,s) -> Catch e (StateT s m) a
267 liftCatch catchE m h =
268 StateT $ \ s -> runStateT m s `catchE` \ e -> runStateT (h e) s
269
270 -- | Lift a @listen@ operation to the new monad.
271 liftListen :: (Monad m) => Listen w m (a,s) -> Listen w (StateT s m) a
272 liftListen listen m = StateT $ \ s -> do
273 ((a, s'), w) <- listen (runStateT m s)
274 return ((a, w), s')
275
276 -- | Lift a @pass@ operation to the new monad.
277 liftPass :: (Monad m) => Pass w m (a,s) -> Pass w (StateT s m) a
278 liftPass pass m = StateT $ \ s -> pass $ do
279 ((a, f), s') <- runStateT m s
280 return ((a, s'), f)
281
282 {- $examples
283
284 Parser from ParseLib with Hugs:
285
286 > type Parser a = StateT String [] a
287 > ==> StateT (String -> [(a,String)])
288
289 For example, item can be written as:
290
291 > item = do (x:xs) <- get
292 > put xs
293 > return x
294 >
295 > type BoringState s a = StateT s Identity a
296 > ==> StateT (s -> Identity (a,s))
297 >
298 > type StateWithIO s a = StateT s IO a
299 > ==> StateT (s -> IO (a,s))
300 >
301 > type StateWithErr s a = StateT s Maybe a
302 > ==> StateT (s -> Maybe (a,s))
303
304 -}
305
306 {- $counting
307
308 A function to increment a counter.
309 Taken from the paper \"Generalising Monads to Arrows\",
310 John Hughes (<http://www.cse.chalmers.se/~rjmh/>), November 1998:
311
312 > tick :: State Int Int
313 > tick = do n <- get
314 > put (n+1)
315 > return n
316
317 Add one to the given number using the state monad:
318
319 > plusOne :: Int -> Int
320 > plusOne n = execState tick n
321
A contrived addition example. Works only when the first argument @n@ is non-negative:
323
324 > plus :: Int -> Int -> Int
325 > plus n x = execState (sequence $ replicate n tick) x
326
327 -}
328
329 {- $labelling
330
331 An example from /The Craft of Functional Programming/, Simon
332 Thompson (<http://www.cs.kent.ac.uk/people/staff/sjt/>),
333 Addison-Wesley 1999: \"Given an arbitrary tree, transform it to a
334 tree of integers in which the original elements are replaced by
335 natural numbers, starting from 0. The same element has to be
336 replaced by the same number at every occurrence, and when we meet
337 an as-yet-unvisited element we have to find a \'new\' number to match
338 it with:\"
339
340 > data Tree a = Nil | Node a (Tree a) (Tree a) deriving (Show, Eq)
341 > type Table a = [a]
342
343 > numberTree :: Eq a => Tree a -> State (Table a) (Tree Int)
344 > numberTree Nil = return Nil
345 > numberTree (Node x t1 t2) = do
346 > num <- numberNode x
347 > nt1 <- numberTree t1
348 > nt2 <- numberTree t2
349 > return (Node num nt1 nt2)
350 > where
351 > numberNode :: Eq a => a -> State (Table a) Int
352 > numberNode x = do
353 > table <- get
354 > case elemIndex x table of
355 > Nothing -> do
356 > put (table ++ [x])
357 > return (length table)
358 > Just i -> return i
359
360 numTree applies numberTree with an initial state:
361
362 > numTree :: (Eq a) => Tree a -> Tree Int
363 > numTree t = evalState (numberTree t) []
364
365 > testTree = Node "Zero" (Node "One" (Node "Two" Nil Nil) (Node "One" (Node "Zero" Nil Nil) Nil)) Nil
366 > numTree testTree => Node 0 (Node 1 (Node 2 Nil Nil) (Node 1 (Node 0 Nil Nil) Nil)) Nil
367
368 -}