Typo
[packages/base.git] / tests / Memo2.lhs
1 % $Id: Memo.lhs,v 1.1 2005/12/16 10:46:05 simonmar Exp $
2 %
3 % (c) The GHC Team, 1999
4 %
5 % Hashing memo tables.
6
7 \begin{code}
8 {-# LANGUAGE CPP #-}
9
10 module Memo2
11 #ifndef __PARALLEL_HASKELL__
12         ( memo          -- :: (a -> b) -> a -> b
13         , memoSized     -- :: Int -> (a -> b) -> a -> b
14         ) 
15 #endif
16         where
17
18 #ifndef __PARALLEL_HASKELL__
19
20 import System.Mem.StableName    ( StableName, makeStableName, hashStableName )
21 import System.Mem.Weak          ( Weak, mkWeakPtr, mkWeak, deRefWeak, finalize )
22 import Data.Array.IO            ( IOArray, newArray, readArray, writeArray )
23 import System.IO.Unsafe         ( unsafePerformIO )
24 import Control.Concurrent.MVar  ( MVar, newMVar, putMVar, takeMVar )
25 \end{code}
26
27 -----------------------------------------------------------------------------
28 Memo table representation.
29
30 The representation is this: a fixed-size hash table where each bucket
31 is a list of table entries, of the form (key,value).
32
33 The key in this case is (StableName key), and we use hashStableName to
34 hash it.
35
36 It's important that we can garbage collect old entries in the table
37 when the key is no longer reachable in the heap.  Hence the value part
38 of each table entry is (Weak val), where the weak pointer "key" is the
39 key for our memo table, and 'val' is the value of this memo table
40 entry.  When the key becomes unreachable, a finalizer will fire and
41 remove this entry from the hash bucket, and further attempts to
42 dereference the weak pointer will return Nothing.  References from
43 'val' to the key are ignored (see the semantics of weak pointers in
44 the documentation).
45
46 \begin{code}
47 type MemoTable key val
48         = MVar (
49             Int,        -- current table size
50             IOArray Int [MemoEntry key val]   -- hash table
51            )
52
53 -- a memo table entry: compile with -funbox-strict-fields to eliminate
54 -- the boxes around the StableName and Weak fields.
55 data MemoEntry key val = MemoEntry !(StableName key) !(Weak val)
56 \end{code}
57
58 We use an MVar to the hash table, so that several threads may safely
59 access it concurrently.  This includes the finalization threads that
60 remove entries from the table.
61
62 ToDo: Can efficiency be improved at all?
63
64 \begin{code}
65 memo :: (a -> b) -> a -> b
66 memo f = memoSized default_table_size f
67
68 default_table_size = 1001
69
70 -- Our memo functions are *strict*.  Lazy memo functions tend to be
71 -- less useful because it is less likely you'll get a memo table hit
72 -- for a thunk.  This change was made to match Hugs's Memo
73 -- implementation, and as the result of feedback from Conal Elliot
74 -- <conal@microsoft.com>.
75
76 memoSized :: Int -> (a -> b) -> a -> b
77 memoSized size f = strict (lazyMemoSized size f)
78
79 strict = ($!)
80
81 lazyMemoSized :: Int -> (a -> b) -> a -> b
82 lazyMemoSized size f =
83    let (table,weak) = unsafePerformIO (
84                 do { tbl <- newArray (0,size) []
85                    ; mvar <- newMVar (size,tbl)
86                    ; weak <- mkWeakPtr mvar (Just (table_finalizer tbl size))
87                    ; return (mvar,weak)
88                    })
89    in  memo' f table weak
90
91 table_finalizer :: IOArray Int [MemoEntry key val] -> Int -> IO ()
92 table_finalizer table size = 
93    sequence_ [ finalizeBucket i | i <- [0..size] ]
94  where
95    finalizeBucket i = do
96       bucket <- readArray table i 
97       sequence_ [ finalize w | MemoEntry _ w <- bucket ]
98
99 memo' :: (a -> b) -> MemoTable a b -> Weak (MemoTable a b) -> a -> b
100 memo' f ref weak_ref = \k -> unsafePerformIO $ do
101    stable_key <- makeStableName k
102    (size, table) <- takeMVar ref
103    let hash_key = hashStableName stable_key `mod` size
104    bucket <- readArray table hash_key
105    lkp <- lookupSN stable_key bucket
106
107    case lkp of
108      Just result -> do
109         putMVar ref (size,table)
110         return result
111      Nothing -> do
112         let result = f k
113         weak <- mkWeak k result (Just (finalizer hash_key stable_key weak_ref))
114         writeArray table hash_key (MemoEntry stable_key weak : bucket)
115         putMVar ref (size,table)
116         return result
117
118 finalizer :: Int -> StableName a -> Weak (MemoTable a b) -> IO ()
119 finalizer hash_key stable_key weak_ref = 
120   do r <- deRefWeak weak_ref 
121      case r of
122         Nothing -> return ()
123         Just mvar -> do
124                 (size,table) <- takeMVar mvar
125                 bucket <- readArray table hash_key
126                 let new_bucket = [ e | e@(MemoEntry sn weak) <- bucket, 
127                                        sn /= stable_key ]
128                 writeArray table hash_key new_bucket
129                 putMVar mvar (size,table)
130
131 lookupSN :: StableName key -> [MemoEntry key val] -> IO (Maybe val)
132 lookupSN sn [] = sn `seq` return Nothing -- make it strict in sn
133 lookupSN sn (MemoEntry sn' weak : xs)
134    | sn == sn'  = do maybe_item <- deRefWeak weak
135                      case maybe_item of
136                         Nothing -> error ("dead weak pair: " ++ 
137                                                 show (hashStableName sn))
138                         Just v  -> return (Just v)
139    | otherwise  = lookupSN sn xs
140 #endif
141 \end{code}