Trac #9878: Have StaticPointers support dynamic loading.
authorAlexander Vershilov <alexander.vershilov@gmail.com>
Mon, 12 Jan 2015 11:29:18 +0000 (05:29 -0600)
committerAustin Seipp <austin@well-typed.com>
Tue, 13 Jan 2015 16:10:38 +0000 (10:10 -0600)
Summary:
A mutex is used to protect the SPT.

unsafeLookupStaticPtr and staticPtrKeys in GHC.StaticPtr are made
monadic.

SPT entries are removed in a destructor function of modules.

Authored-by: Facundo Domínguez <facundo.dominguez@tweag.io>
Authored-by: Alexander Vershilov <alexander.vershilov@tweag.io>
Test Plan: ./validate

Reviewers: austin, simonpj, hvr

Subscribers: carter, thomie, qnikst, mboes

Differential Revision: https://phabricator.haskell.org/D587

GHC Trac Issues: #9878

compiler/deSugar/StaticPtrTable.hs
includes/rts/StaticPtrTable.h
libraries/base/GHC/StaticPtr.hs
rts/Linker.c
rts/StaticPtrTable.c
testsuite/tests/codeGen/should_run/CgStaticPointers.hs
testsuite/tests/rts/GcStaticPointers.hs
testsuite/tests/rts/ListStaticPointers.hs

index 858a0e8..d1e8e05 100644 (file)
 --
 -- where the constants are fingerprints produced from the static forms.
 --
+-- There is also a finalization function for the time when the module is
+-- unloaded.
+--
+-- > static void hs_hpc_fini_Main(void) __attribute__((destructor));
+-- > static void hs_hpc_fini_Main(void) {
+-- >
+-- >   static StgWord64 k0[2] = {16252233372134256ULL,7370534374096082ULL};
+-- >   hs_spt_remove(k0);
+-- >
+-- >   static StgWord64 k1[2] = {12545634534567898ULL,5409674567544151ULL};
+-- >   hs_spt_remove(k1);
+-- >
+-- > }
+--
 module StaticPtrTable (sptInitCode) where
 
 import CoreSyn
@@ -62,6 +76,15 @@ sptInitCode this_mod entries = vcat
         <> semi
         |  (i, (fp, (n, _))) <- zip [0..] entries
         ]
+    , text "static void hs_spt_fini_" <> ppr this_mod
+           <> text "(void) __attribute__((destructor));"
+    , text "static void hs_spt_fini_" <> ppr this_mod <> text "(void)"
+    , braces $ vcat $
+        [  text "StgWord64 k" <> int i <> text "[2] = "
+           <> pprFingerprint fp <> semi
+        $$ text "hs_spt_remove" <> parens (char 'k' <> int i) <> semi
+        | (i, (fp, _)) <- zip [0..] entries
+        ]
     ]
 
   where
index 87a905c..d863160 100644 (file)
  * */
 void hs_spt_insert (StgWord64 key[2],void* spe_closure);
 
+/** Removes an entry from the Static Pointer Table.
+ *
+ * This function is called from the code generated by
+ * compiler/deSugar/StaticPtrTable.sptInitCode
+ *
+ * */
+void hs_spt_remove (StgWord64 key[2]);
+
 #endif /* RTS_STATICPTRTABLE_H */
index b58564e..efaabf2 100644 (file)
@@ -24,9 +24,9 @@
 --
 -- To solve such concern, the references provided by this module offer a key
 -- that can be used to locate the values on each process. Each process maintains
--- a global and immutable table of references which can be looked up with a
--- given key. This table is known as the Static Pointer Table. The reference can
--- then be dereferenced to obtain the value.
+-- a global table of references which can be looked up with a given key. This
+-- table is known as the Static Pointer Table. The reference can then be
+-- dereferenced to obtain the value.
 --
 -----------------------------------------------------------------------------
 
@@ -48,7 +48,6 @@ import Foreign.Ptr         (castPtr)
 import GHC.Exts            (addrToAny#)
 import GHC.Ptr             (Ptr(..), nullPtr)
 import GHC.Fingerprint     (Fingerprint(..))
-import System.IO.Unsafe    (unsafePerformIO)
 
 
 -- | A reference to a value of type 'a'.
@@ -74,8 +73,15 @@ staticKey (StaticPtr k _ _) = k
 -- This function is unsafe because the program behavior is undefined if the type
 -- of the returned 'StaticPtr' does not match the expected one.
 --
-unsafeLookupStaticPtr :: StaticKey -> Maybe (StaticPtr a)
-unsafeLookupStaticPtr k = unsafePerformIO $ sptLookup k
+unsafeLookupStaticPtr :: StaticKey -> IO (Maybe (StaticPtr a))
+unsafeLookupStaticPtr (Fingerprint w1 w2) = do
+    ptr@(Ptr addr) <- withArray [w1,w2] (hs_spt_lookup . castPtr)
+    if (ptr == nullPtr)
+    then return Nothing
+    else case addrToAny# addr of
+           (# spe #) -> return (Just spe)
+
+foreign import ccall unsafe hs_spt_lookup :: Ptr () -> IO (Ptr a)
 
 -- | Miscelaneous information available for debugging purposes.
 data StaticPtrInfo = StaticPtrInfo
@@ -96,20 +102,9 @@ data StaticPtrInfo = StaticPtrInfo
 staticPtrInfo :: StaticPtr a -> StaticPtrInfo
 staticPtrInfo (StaticPtr _ n _) = n
 
--- | Like 'unsafeLookupStaticPtr' but evaluates in 'IO'.
-sptLookup :: StaticKey -> IO (Maybe (StaticPtr a))
-sptLookup (Fingerprint w1 w2) = do
-    ptr@(Ptr addr) <- withArray [w1,w2] (hs_spt_lookup . castPtr)
-    if (ptr == nullPtr)
-    then return Nothing
-    else case addrToAny# addr of
-           (# spe #) -> return (Just spe)
-
-foreign import ccall unsafe hs_spt_lookup :: Ptr () -> IO (Ptr a)
-
 -- | A list of all known keys.
-staticPtrKeys :: [StaticKey]
-staticPtrKeys = unsafePerformIO $ do
+staticPtrKeys :: IO [StaticKey]
+staticPtrKeys = do
     keyCount <- hs_spt_key_count
     allocaArray (fromIntegral keyCount) $ \p -> do
       count <- hs_spt_keys p keyCount
index 4a0e5ea..6bf06ed 100644 (file)
@@ -1420,6 +1420,7 @@ typedef struct _RtsSymbolVal {
       SymI_HasProto(atomic_dec)                                         \
       SymI_HasProto(hs_spt_lookup)                                      \
       SymI_HasProto(hs_spt_insert)                                      \
+      SymI_HasProto(hs_spt_remove)                                      \
       SymI_HasProto(hs_spt_keys)                                        \
       SymI_HasProto(hs_spt_key_count)                                   \
       RTS_USER_SIGNALS_SYMBOLS                                          \
index bd45080..f7fe066 100644 (file)
@@ -8,12 +8,18 @@
  *
  */
 
-#include "Rts.h"
 #include "StaticPtrTable.h"
+#include "Rts.h"
+#include "RtsUtils.h"
 #include "Hash.h"
+#include "Stable.h"
 
 static HashTable * spt = NULL;
 
+#ifdef THREADED_RTS
+static Mutex spt_lock;
+#endif
+
 /// Hash function for the SPT.
 static int hashFingerprint(HashTable *table, StgWord64 key[2]) {
   // Take half of the key to compute the hash.
@@ -28,21 +34,59 @@ static int compareFingerprint(StgWord64 ptra[2], StgWord64 ptrb[2]) {
 void hs_spt_insert(StgWord64 key[2],void *spe_closure) {
   // hs_spt_insert is called from constructor functions, so
   // the SPT needs to be initialized here.
-  if (spt == NULL)
+  if (spt == NULL) {
     spt = allocHashTable_( (HashFunction *)hashFingerprint
                          , (CompareFunction *)compareFingerprint
                          );
+#ifdef THREADED_RTS
+    initMutex(&spt_lock);
+#endif
+  }
+
+  StgStablePtr * entry = stgMallocBytes( sizeof(StgStablePtr)
+                                       , "hs_spt_insert: entry"
+                                       );
+  *entry = getStablePtr(spe_closure);
+  ACQUIRE_LOCK(&spt_lock);
+  insertHashTable(spt, (StgWord)key, entry);
+  RELEASE_LOCK(&spt_lock);
+}
 
-  getStablePtr(spe_closure);
-  insertHashTable(spt, (StgWord)key, spe_closure);
+static void freeSptEntry(void* entry) {
+  freeStablePtr(*(StgStablePtr*)entry);
+  stgFree(entry);
+}
+
+void hs_spt_remove(StgWord64 key[2]) {
+   if (spt) {
+     ACQUIRE_LOCK(&spt_lock);
+     StgStablePtr* entry = removeHashTable(spt, (StgWord)key, NULL);
+     RELEASE_LOCK(&spt_lock);
+
+     if (entry)
+       freeSptEntry(entry);
+   }
 }
 
 StgPtr hs_spt_lookup(StgWord64 key[2]) {
-  return spt ? lookupHashTable(spt, (StgWord)key) : NULL;
+  if (spt) {
+    ACQUIRE_LOCK(&spt_lock);
+    const StgStablePtr * entry = lookupHashTable(spt, (StgWord)key);
+    RELEASE_LOCK(&spt_lock);
+    const StgPtr ret = entry ? deRefStablePtr(*entry) : NULL;
+    return ret;
+  } else
+    return NULL;
 }
 
 int hs_spt_keys(StgPtr keys[], int szKeys) {
-  return spt ? keysHashTable(spt, (StgWord*)keys, szKeys) : 0;
+  if (spt) {
+    ACQUIRE_LOCK(&spt_lock);
+    const int ret = keysHashTable(spt, (StgWord*)keys, szKeys);
+    RELEASE_LOCK(&spt_lock);
+    return ret;
+  } else
+    return 0;
 }
 
 int hs_spt_key_count() {
@@ -51,7 +95,10 @@ int hs_spt_key_count() {
 
 void exitStaticPtrTable() {
   if (spt) {
-    freeHashTable(spt, NULL);
+    freeHashTable(spt, freeSptEntry);
     spt = NULL;
+#ifdef THREADED_RTS
+    closeMutex(&spt_lock);
+#endif
   }
 }
index 5576f43..f7776b0 100644 (file)
@@ -1,4 +1,5 @@
 {-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE LambdaCase         #-}
 {-# LANGUAGE StaticPointers     #-}
 
 -- | A test to use symbols produced by the static form.
@@ -9,15 +10,15 @@ import GHC.StaticPtr
 
 main :: IO ()
 main = do
-  print $ lookupKey (static (id . id)) (1 :: Int)
-  print $ lookupKey (static method :: StaticPtr (Char -> Int)) 'a'
+  lookupKey (static (id . id)) >>= \f -> print $ f (1 :: Int)
+  lookupKey (static method :: StaticPtr (Char -> Int)) >>= \f -> print $ f 'a'
   print $ deRefStaticPtr (static g)
   print $ deRefStaticPtr p0 'a'
   print $ deRefStaticPtr (static t_field) $ T 'b'
 
-lookupKey :: StaticPtr a -> a
-lookupKey p = case unsafeLookupStaticPtr (staticKey p) of
-  Just p -> deRefStaticPtr p
+lookupKey :: StaticPtr a -> IO a
+lookupKey p = unsafeLookupStaticPtr (staticKey p) >>= \case
+  Just p -> return $ deRefStaticPtr p
   Nothing -> error $ "couldn't find " ++ show (staticPtrInfo p)
 
 g :: String
index c498af5..3bf02d9 100644 (file)
@@ -26,7 +26,7 @@ main = do
   print z
   performGC
   threadDelay 1000000
-  let Just p = unsafeLookupStaticPtr nats_key
+  Just p <- unsafeLookupStaticPtr nats_key
   print (deRefStaticPtr (unsafeCoerce p) !! 800 :: Integer)
   -- Uncommenting the next line keeps 'nats' alive and would prevent a segfault
   -- if 'nats' were garbage collected.
index 5ddb636..01c747d 100644 (file)
@@ -7,10 +7,12 @@ import Data.List ((\\))
 import GHC.StaticPtr
 import System.Exit
 
-main = when (not $ eqBags staticPtrKeys expected) $ do
-    print ("expected", expected)
-    print ("found", staticPtrKeys)
-    exitFailure
+main = do
+    found <- staticPtrKeys
+    when (not $ eqBags found expected) $ do
+      print ("expected", expected)
+      print ("found", found)
+      exitFailure
   where
 
     expected =