Bag: Add Foldable instance
[ghc.git] / compiler / utils / Bag.hs
1 {-
2 (c) The University of Glasgow 2006
3 (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4
5
6 Bag: an unordered collection with duplicates
7 -}
8
9 {-# LANGUAGE DeriveDataTypeable, ScopedTypeVariables #-}
10
module Bag
  ( Bag -- abstract type

  , emptyBag, unitBag, unionBags, unionManyBags
  , mapBag
  , elemBag, lengthBag
  , filterBag, partitionBag, partitionBagWith
  , concatBag, catBagMaybes, foldBag, foldrBag, foldlBag
  , isEmptyBag, isSingletonBag, consBag, snocBag, anyBag
  , listToBag, bagToList
  , foldrBagM, foldlBagM, mapBagM, mapBagM_
  , flatMapBagM, flatMapBagPairM
  , mapAndUnzipBagM, mapAccumBagLM
  ) where
25
26 import Outputable
27 import Util
28
29 import MonadUtils
30 import Data.Data
31 import Data.List ( partition )
32 import qualified Data.Foldable as Foldable
33
-- Fixities for the infix forms of the single-element insertions:
-- 'consBag' groups to the right, 'snocBag' groups to the left.
infixr 3 `consBag`
infixl 3 `snocBag`
36
-- | An unordered collection that may contain duplicates.
data Bag a
  = EmptyBag                  -- ^ No elements.
  | UnitBag a                 -- ^ Exactly one element.
  | TwoBags (Bag a) (Bag a)   -- ^ Two sub-bags. INVARIANT: neither branch is empty
  | ListBag [a]               -- ^ Elements kept in a list. INVARIANT: the list is non-empty
  deriving (Typeable)
43
-- | The bag with no elements.
emptyBag :: Bag a
emptyBag = EmptyBag

-- | The bag holding exactly the given element.
unitBag :: a -> Bag a
unitBag x = UnitBag x
49
-- | Number of elements in the bag, counting duplicates.
lengthBag :: Bag a -> Int
lengthBag bag = case bag of
  EmptyBag    -> 0
  UnitBag _   -> 1
  TwoBags l r -> lengthBag l + lengthBag r
  ListBag xs  -> length xs
55
-- | Does the bag contain an element equal to the given one?
-- The left sub-bag of a 'TwoBags' is searched before the right one.
elemBag :: Eq a => a -> Bag a -> Bool
elemBag needle = go
  where
    go EmptyBag      = False
    go (UnitBag y)   = needle == y
    go (TwoBags l r) = go l || go r
    go (ListBag ys)  = any (needle ==) ys
61
-- | Union of a list of bags, combined with a right fold of 'unionBags'
-- starting from the empty bag.
unionManyBags :: [Bag a] -> Bag a
unionManyBags = foldr unionBags EmptyBag
64
-- | Union of two bags.
--
-- An 'EmptyBag' on either side is dropped, so a 'TwoBags' node is only built
-- when neither argument is 'EmptyBag'.  The first argument is always
-- inspected; the second is inspected only when the first is not 'EmptyBag'.
unionBags :: Bag a -> Bag a -> Bag a
unionBags left right = case (left, right) of
  (EmptyBag, _) -> right
  (_, EmptyBag) -> left
  _             -> TwoBags left right
71
-- | Add one element on the left of a bag.
consBag :: a -> Bag a -> Bag a
consBag elt bag = unionBags (unitBag elt) bag

-- | Add one element on the right of a bag.
snocBag :: Bag a -> a -> Bag a
snocBag bag elt = unionBags bag (unitBag elt)
77
-- | Is the bag empty?  Only the outermost constructor is inspected; this
-- relies on the invariants that 'TwoBags' and 'ListBag' are never empty.
isEmptyBag :: Bag a -> Bool
isEmptyBag bag = case bag of
  EmptyBag -> True
  _        -> False -- NB invariants
81
-- | Does the bag hold exactly one element?
isSingletonBag :: Bag a -> Bool
isSingletonBag bag = case bag of
  UnitBag _  -> True
  ListBag xs -> isSingleton xs
  -- EmptyBag has no elements; TwoBags has at least two, since by the
  -- invariant neither branch is empty.
  _          -> False
87
-- | Keep only the elements satisfying the predicate.
-- Results are rebuilt with 'unionBags' / 'listToBag', so emptied sub-bags
-- collapse to 'EmptyBag'.
filterBag :: (a -> Bool) -> Bag a -> Bag a
filterBag keep = go
  where
    go EmptyBag = EmptyBag
    go unit@(UnitBag x)
      | keep x    = unit
      | otherwise = EmptyBag
    go (TwoBags l r) = go l `unionBags` go r
    go (ListBag xs)  = listToBag (filter keep xs)
95
-- | Does any element of the bag satisfy the predicate?
-- The left sub-bag of a 'TwoBags' is tested before the right one.
anyBag :: (a -> Bool) -> Bag a -> Bool
anyBag p bag = case bag of
  EmptyBag    -> False
  UnitBag v   -> p v
  TwoBags l r -> anyBag p l || anyBag p r
  ListBag xs  -> any p xs
101
-- | Flatten a bag of bags by taking the union of all the inner bags.
concatBag :: Bag (Bag a) -> Bag a
concatBag = foldrBag unionBags emptyBag
106
-- | Collect the payloads of the 'Just' elements, dropping every 'Nothing'.
catBagMaybes :: Bag (Maybe a) -> Bag a
catBagMaybes = foldrBag keepJust emptyBag
  where
    -- A 'Nothing' leaves the accumulated bag alone; a 'Just' is consed on.
    keepJust mb rest = maybe rest (`consBag` rest) mb
112
-- | Split a bag into the elements that satisfy the predicate and those
-- that do not.
partitionBag :: (a -> Bool) -> Bag a -> (Bag a {- Satisfy predicate -},
                                         Bag a {- Don't -})
partitionBag p bag = case bag of
  EmptyBag -> (EmptyBag, EmptyBag)
  UnitBag x
    | p x       -> (bag, EmptyBag)
    | otherwise -> (EmptyBag, bag)
  TwoBags l r ->
    let (yesL, noL) = partitionBag p l
        (yesR, noR) = partitionBag p r
    in  (yesL `unionBags` yesR, noL `unionBags` noR)
  ListBag xs ->
    let (yes, no) = partition p xs
    in  (listToBag yes, listToBag no)
124
125
-- | Classify every element with the given function and split the results
-- into a bag of the 'Left' payloads and a bag of the 'Right' payloads.
partitionBagWith :: (a -> Either b c) -> Bag a
                 -> (Bag b {- Left  -},
                     Bag c {- Right -})
partitionBagWith classify bag = case bag of
  EmptyBag -> (EmptyBag, EmptyBag)
  UnitBag x -> case classify x of
    Left  l -> (UnitBag l, EmptyBag)
    Right r -> (EmptyBag, UnitBag r)
  TwoBags b1 b2 ->
    let (lefts1, rights1) = partitionBagWith classify b1
        (lefts2, rights2) = partitionBagWith classify b2
    in  (lefts1 `unionBags` lefts2, rights1 `unionBags` rights2)
  ListBag xs ->
    let (lefts, rights) = partitionWith classify xs
    in  (listToBag lefts, listToBag rights)
140
-- | Reduce a bag by replacing its structure with the supplied operations.
foldBag :: (r -> r -> r) -- ^ Replace TwoBags with this; should be associative
        -> (a -> r)      -- ^ Replace UnitBag with this
        -> r             -- ^ Replace EmptyBag with this
        -> Bag a
        -> r

{- Standard definition, kept for reference only:
foldBag t u e EmptyBag        = e
foldBag t u e (UnitBag x)     = u x
foldBag t u e (TwoBags b1 b2) = (foldBag t u e b1) `t` (foldBag t u e b2)
foldBag t u e (ListBag xs)    = foldr (t.u) e xs
-}

-- More tail-recursive definition, exploiting associativity of "t": the
-- result of folding the right sub-bag is threaded through as the "empty"
-- value while folding the left one.  Note that a 'UnitBag' therefore
-- yields @u x `t` e@ rather than the plain @u x@ of the standard definition.
foldBag t u = go
  where
    go e EmptyBag        = e
    go e (UnitBag x)     = u x `t` e
    go e (TwoBags b1 b2) = go (go e b2) b1
    go e (ListBag xs)    = foldr (\x acc -> u x `t` acc) e xs
159
-- | Right fold over the elements of a bag.  For 'TwoBags' the right
-- sub-bag is folded first and its result seeds the fold of the left one.
foldrBag :: (a -> r -> r) -> r
         -> Bag a
         -> r
foldrBag k = go
  where
    go z EmptyBag      = z
    go z (UnitBag x)   = k x z
    go z (TwoBags l r) = go (go z r) l
    go z (ListBag xs)  = foldr k z xs
168
-- | Left fold over the elements of a bag.  For 'TwoBags' the left
-- sub-bag is folded first and its result seeds the fold of the right one.
foldlBag :: (r -> a -> r) -> r
         -> Bag a
         -> r
foldlBag k z bag = case bag of
  EmptyBag    -> z
  UnitBag x   -> k z x
  TwoBags l r -> foldlBag k (foldlBag k z l) r
  ListBag xs  -> foldl k z xs
177
-- | Monadic right fold.  For 'TwoBags' the right sub-bag is processed
-- first and its result is passed on to the fold of the left sub-bag.
-- The 'ListBag' case is delegated to 'foldrM'.
foldrBagM :: (Monad m) => (a -> b -> m b) -> b -> Bag a -> m b
foldrBagM k z bag = case bag of
  EmptyBag    -> return z
  UnitBag x   -> k x z
  TwoBags l r -> foldrBagM k z r >>= \z' -> foldrBagM k z' l
  ListBag xs  -> foldrM k z xs
183
-- | Monadic left fold.  For 'TwoBags' the left sub-bag is processed
-- first and its result is passed on to the fold of the right sub-bag.
-- The 'ListBag' case is delegated to 'foldlM'.
foldlBagM :: (Monad m) => (b -> a -> m b) -> b -> Bag a -> m b
foldlBagM k = go
  where
    go z EmptyBag      = return z
    go z (UnitBag x)   = k z x
    go z (TwoBags l r) = go z l >>= \z' -> go z' r
    go z (ListBag xs)  = foldlM k z xs
189
-- | Apply a function to every element.  The constructor shape of the bag
-- is kept as it is, so the non-emptiness invariants carry over.
mapBag :: (a -> b) -> Bag a -> Bag b
mapBag f bag = case bag of
  EmptyBag    -> EmptyBag
  UnitBag x   -> UnitBag (f x)
  TwoBags l r -> TwoBags (mapBag f l) (mapBag f r)
  ListBag xs  -> ListBag (map f xs)
195
-- | Monadic map over a bag.  The constructor shape of the bag is kept;
-- for 'TwoBags' the left sub-bag is processed before the right one.
mapBagM :: Monad m => (a -> m b) -> Bag a -> m (Bag b)
mapBagM f bag = case bag of
  EmptyBag  -> return EmptyBag
  UnitBag x -> f x >>= \y -> return (UnitBag y)
  TwoBags l r ->
    mapBagM f l >>= \l' ->
    mapBagM f r >>= \r' ->
    return (TwoBags l' r')
  ListBag xs -> mapM f xs >>= \ys -> return (ListBag ys)
205
-- | Run an action on every element, discarding the results.
-- For 'TwoBags' the left sub-bag is processed before the right one.
mapBagM_ :: Monad m => (a -> m b) -> Bag a -> m ()
mapBagM_ f = go
  where
    go EmptyBag      = return ()
    go (UnitBag x)   = f x >> return ()
    go (TwoBags l r) = go l >> go r
    go (ListBag xs)  = mapM_ f xs
211
-- | Map every element to a bag inside the monad and take the union of
-- all the resulting bags.
flatMapBagM :: Monad m => (a -> m (Bag b)) -> Bag a -> m (Bag b)
flatMapBagM f bag = case bag of
  EmptyBag  -> return EmptyBag
  UnitBag x -> f x
  TwoBags l r -> do
    fromL <- flatMapBagM f l
    fromR <- flatMapBagM f r
    return (fromL `unionBags` fromR)
  ListBag xs -> foldrM step EmptyBag xs
  where
    -- Put the bag produced for one list element in front of the
    -- accumulated result.
    step x acc = do
      here <- f x
      return (here `unionBags` acc)
221
-- | Like 'flatMapBagM', but every element yields a pair of bags; the
-- first components and the second components are unioned separately.
flatMapBagPairM :: Monad m => (a -> m (Bag b, Bag c)) -> Bag a -> m (Bag b, Bag c)
flatMapBagPairM f bag = case bag of
  EmptyBag  -> return (EmptyBag, EmptyBag)
  UnitBag x -> f x
  TwoBags l r ->
    flatMapBagPairM f l >>= \(bsL, csL) ->
    flatMapBagPairM f r >>= \(bsR, csR) ->
    return (bsL `unionBags` bsR, csL `unionBags` csR)
  ListBag xs -> foldrM step (EmptyBag, EmptyBag) xs
  where
    -- Put the pair produced for one list element in front of the
    -- accumulated pair, component by component.
    step x (bsAcc, csAcc) =
      f x >>= \(bs, cs) ->
      return (bs `unionBags` bsAcc, cs `unionBags` csAcc)
232
-- | Monadic map producing a pair per element, returned as two bags that
-- both have the same constructor shape as the input.
mapAndUnzipBagM :: Monad m => (a -> m (b,c)) -> Bag a -> m (Bag b, Bag c)
mapAndUnzipBagM f bag = case bag of
  EmptyBag  -> return (EmptyBag, EmptyBag)
  UnitBag x -> f x >>= \(r, s) -> return (UnitBag r, UnitBag s)
  TwoBags b1 b2 ->
    mapAndUnzipBagM f b1 >>= \(rs1, ss1) ->
    mapAndUnzipBagM f b2 >>= \(rs2, ss2) ->
    return (TwoBags rs1 rs2, TwoBags ss1 ss2)
  ListBag xs ->
    mapM f xs >>= \pairs ->
    let (rs, ss) = unzip pairs
    in  return (ListBag rs, ListBag ss)
243
-- | Monadic map that threads an accumulator through the bag.  For
-- 'TwoBags' the state leaving the left sub-bag is fed into the right one;
-- the 'ListBag' case is delegated to 'mapAccumLM'.
mapAccumBagLM :: Monad m
              => (acc -> x -> m (acc, y)) -- ^ combining function
              -> acc                      -- ^ initial state
              -> Bag x                    -- ^ inputs
              -> m (acc, Bag y)           -- ^ final state, outputs
mapAccumBagLM f = go
  where
    go s EmptyBag = return (s, EmptyBag)
    go s (UnitBag x) =
      f s x >>= \(s', y) -> return (s', UnitBag y)
    go s (TwoBags l r) =
      go s l  >>= \(s1, l') ->
      go s1 r >>= \(s2, r') ->
      return (s2, TwoBags l' r')
    go s (ListBag xs) =
      mapAccumLM f s xs >>= \(s', ys) -> return (s', ListBag ys)
256
-- | Build a bag from a list.  The empty list becomes 'EmptyBag', which
-- keeps the invariant that a 'ListBag' is never empty.
listToBag :: [a] -> Bag a
listToBag vs = case vs of
  [] -> EmptyBag
  _  -> ListBag vs
260
-- | All the elements of a bag as a list, built with 'foldrBag'.
bagToList :: Bag a -> [a]
bagToList = foldrBag (:) []
263
-- | A bag is printed as its elements, comma-separated, inside braces.
instance (Outputable a) => Outputable (Bag a) where
  ppr = braces . pprWithCommas ppr . bagToList
266
-- | 'Bag' is kept abstract for generic programming: traversal goes
-- through the list of elements, and unfolding is not supported.
instance Data a => Data (Bag a) where
  -- traverse abstract type abstractly: rebuild from the element list
  gfoldl k z b = k (z listToBag) (bagToList b)
  toConstr _   = abstractConstr ("Bag(" ++ show (typeOf (undefined :: a)) ++ ")")
  gunfold _ _  = error "gunfold"
  dataTypeOf _ = mkNoRepType "Bag"
  dataCast1 x  = gcast1 x
273
-- | Only 'Foldable.foldr' is given explicitly, as 'foldrBag'; every other
-- class method falls back to its default definition.
instance Foldable.Foldable Bag where
  foldr k z bag = foldrBag k z bag