GHCi support for levity-polymorphic join points
authorRichard Eisenberg <rae@richarde.dev>
Tue, 4 Jun 2019 18:31:08 +0000 (14:31 -0400)
committerBen Gamari <ben@smart-cactus.org>
Tue, 25 Jun 2019 18:37:38 +0000 (14:37 -0400)
Fixes #16509.

See Note [Levity-polymorphic join points] in ByteCodeGen,
which tells the full story.

This commit also adds some comments and cleans some code
in the byte-code generator, as I was exploring around trying
to understand it.

test case: ghci/scripts/T16509

(cherry picked from commit 392210bf8a27b3604f8642d76c39e391c2d4b5e0)

compiler/ghci/ByteCodeAsm.hs
compiler/ghci/ByteCodeGen.hs
compiler/ghci/ByteCodeInstr.hs
compiler/simplStg/RepType.hs
testsuite/tests/ghci/scripts/T16509.hs [new file with mode: 0644]
testsuite/tests/ghci/scripts/T16509.script [new file with mode: 0644]
testsuite/tests/ghci/scripts/all.T

index 0776e40..e3c18b9 100644 (file)
@@ -156,7 +156,11 @@ assembleOneBCO hsc_env pbco = do
   return ubco'
 
 assembleBCO :: DynFlags -> ProtoBCO Name -> IO UnlinkedBCO
-assembleBCO dflags (ProtoBCO nm instrs bitmap bsize arity _origin _malloced) = do
+assembleBCO dflags (ProtoBCO { protoBCOName       = nm
+                             , protoBCOInstrs     = instrs
+                             , protoBCOBitmap     = bitmap
+                             , protoBCOBitmapSize = bsize
+                             , protoBCOArity      = arity }) = do
   -- pass 1: collect up the offsets of the local labels.
   let asm = mapM_ (assembleI dflags) instrs
 
index 1136907..0f5d649 100644 (file)
@@ -26,6 +26,7 @@ import Platform
 import Name
 import MkId
 import Id
+import Var             ( updateVarType )
 import ForeignCall
 import HscTypes
 import CoreUtils
@@ -61,7 +62,6 @@ import Data.Char
 
 import UniqSupply
 import Module
-import Control.Arrow ( second )
 
 import Control.Exception
 import Data.Array
@@ -90,7 +90,7 @@ byteCodeGen hsc_env this_mod binds tycs mb_modBreaks
                 (const ()) $ do
         -- Split top-level binds into strings and others.
         -- See Note [generating code for top-level string literal bindings].
-        let (strings, flatBinds) = partitionEithers $ do
+        let (strings, flatBinds) = partitionEithers $ do  -- list monad
                 (bndr, rhs) <- flattenBinds binds
                 return $ case exprIsTickedString_maybe rhs of
                     Just str -> Left (bndr, str)
@@ -181,29 +181,13 @@ coreExprToBCOs hsc_env this_mod expr
   where dflags = hsc_dflags hsc_env
 
 -- The regular freeVars function gives more information than is useful to
--- us here. simpleFreeVars does the impedance matching.
+-- us here. We need only the free variables, not everything in an FVAnn.
+-- Historical note: At one point FVAnn was more sophisticated than just
+-- a set. Now it isn't. So this function is much simpler. Keeping it around
+-- so that if someone changes FVAnn, they will get a nice type error right
+-- here.
 simpleFreeVars :: CoreExpr -> AnnExpr Id DVarSet
-simpleFreeVars = go . freeVars
-  where
-    go :: AnnExpr Id FVAnn -> AnnExpr Id DVarSet
-    go (ann, e) = (freeVarsOfAnn ann, go' e)
-
-    go' :: AnnExpr' Id FVAnn -> AnnExpr' Id DVarSet
-    go' (AnnVar id)                  = AnnVar id
-    go' (AnnLit lit)                 = AnnLit lit
-    go' (AnnLam bndr body)           = AnnLam bndr (go body)
-    go' (AnnApp fun arg)             = AnnApp (go fun) (go arg)
-    go' (AnnCase scrut bndr ty alts) = AnnCase (go scrut) bndr ty (map go_alt alts)
-    go' (AnnLet bind body)           = AnnLet (go_bind bind) (go body)
-    go' (AnnCast expr (ann, co))     = AnnCast (go expr) (freeVarsOfAnn ann, co)
-    go' (AnnTick tick body)          = AnnTick tick (go body)
-    go' (AnnType ty)                 = AnnType ty
-    go' (AnnCoercion co)             = AnnCoercion co
-
-    go_alt (con, args, expr) = (con, args, go expr)
-
-    go_bind (AnnNonRec bndr rhs) = AnnNonRec bndr (go rhs)
-    go_bind (AnnRec pairs)       = AnnRec (map (second go) pairs)
+simpleFreeVars = freeVars
 
 -- -----------------------------------------------------------------------------
 -- Compilation schema for the bytecode generator
@@ -256,6 +240,7 @@ mkProtoBCO
    -> name
    -> BCInstrList
    -> Either  [AnnAlt Id DVarSet] (AnnExpr Id DVarSet)
+        -- ^ original expression; for debugging only
    -> Int
    -> Word16
    -> [StgWord]
@@ -368,6 +353,9 @@ schemeR fvs (nm, rhs)
 -}
    = schemeR_wrk fvs nm rhs (collect rhs)
 
+-- If an expression is a lambda (after apply bcView), return the
+-- list of arguments to the lambda (in R-to-L order) and the
+-- underlying expression
 collect :: AnnExpr Id DVarSet -> ([Var], AnnExpr' Id DVarSet)
 collect (_, e) = go [] e
   where
@@ -382,8 +370,8 @@ collect (_, e) = go [] e
 schemeR_wrk
     :: [Id]
     -> Id
-    -> AnnExpr Id DVarSet
-    -> ([Var], AnnExpr' Var DVarSet)
+    -> AnnExpr Id DVarSet             -- expression e, for debugging only
+    -> ([Var], AnnExpr' Var DVarSet)  -- result of collect on e
     -> BcM (ProtoBCO Name)
 schemeR_wrk fvs nm original_body (args, body)
    = do
@@ -508,8 +496,16 @@ schemeE d s p e@(AnnLit lit)     = returnUnboxedAtom d s p e (typeArgRep (litera
 schemeE d s p e@(AnnCoercion {}) = returnUnboxedAtom d s p e V
 
 schemeE d s p e@(AnnVar v)
+      -- See Note [Levity-polymorphic join points], step 3.
+    | isLPJoinPoint v           = schemeT d s p $
+                                  AnnApp (bogus_fvs, AnnVar (protectLPJoinPointId v))
+                                         (bogus_fvs, AnnVar voidPrimId)
+                         -- schemeT will call splitApp, dropping the fvs.
+
     | isUnliftedType (idType v) = returnUnboxedAtom d s p e (bcIdArgRep v)
     | otherwise                 = schemeT d s p e
+    where
+      bogus_fvs = pprPanic "schemeE bogus_fvs" (ppr v)
 
 schemeE d s p (AnnLet (AnnNonRec x (_,rhs)) (_,body))
    | (AnnVar v, args_r_to_l) <- splitApp rhs,
@@ -534,19 +530,22 @@ schemeE d s p (AnnLet binds (_,body)) = do
 
          fvss  = map (fvsToEnv p' . fst) rhss
 
+           -- See Note [Levity-polymorphic join points], step 2.
+         (xs',rhss') = zipWithAndUnzip protectLPJoinPointBind xs rhss
+
          -- Sizes of free vars
          size_w = trunc16W . idSizeW dflags
          sizes = map (\rhs_fvs -> sum (map size_w rhs_fvs)) fvss
 
          -- the arity of each rhs
-         arities = map (genericLength . fst . collect) rhss
+         arities = map (genericLength . fst . collect) rhss'
 
          -- This p', d' defn is safe because all the items being pushed
          -- are ptrs, so all have size 1 word.  d' and p' reflect the stack
          -- after the closures have been allocated in the heap (but not
          -- filled in), and pointers to them parked on the stack.
          offsets = mkStackOffsets d (genericReplicate n_binds (wordSize dflags))
-         p' = Map.insertList (zipE xs offsets) p
+         p' = Map.insertList (zipE xs' offsets) p
          d' = d + wordsToBytes dflags n_binds
          zipE = zipEqual "schemeE"
 
@@ -587,7 +586,7 @@ schemeE d s p (AnnLet binds (_,body)) = do
          compile_binds =
             [ compile_bind d' fvs x rhs size arity (trunc16W n)
             | (fvs, x, rhs, size, arity, n) <-
-                zip6 fvss xs rhss sizes arities [n_binds, n_binds-1 .. 1]
+                zip6 fvss xs' rhss' sizes arities [n_binds, n_binds-1 .. 1]
             ]
      body_code <- schemeE d' s p' body
      thunk_codes <- sequence compile_binds
@@ -681,6 +680,30 @@ schemeE _ _ _ expr
    = pprPanic "ByteCodeGen.schemeE: unhandled case"
                (pprCoreExpr (deAnnotate' expr))
 
+-- Is this Id a levity-polymorphic join point?
+-- See Note [Levity-polymorphic join points], step 1
+isLPJoinPoint :: Id -> Bool
+isLPJoinPoint x = isJoinId x &&
+                  isNothing (isLiftedType_maybe (idType x))
+
+-- If necessary, modify this Id and body to protect levity-polymorphic join points.
+-- See Note [Levity-polymorphic join points], step 2.
+protectLPJoinPointBind :: Id -> AnnExpr Id DVarSet -> (Id, AnnExpr Id DVarSet)
+protectLPJoinPointBind x rhs@(fvs, _)
+  | isLPJoinPoint x
+  = (protectLPJoinPointId x, (fvs, AnnLam voidArgId rhs))
+
+  | otherwise
+  = (x, rhs)
+
+-- Update an Id's type to take a Void# argument.
+-- Precondition: the Id is a levity-polymorphic join point.
+-- See Note [Levity-polymorphic join points]
+protectLPJoinPointId :: Id -> Id
+protectLPJoinPointId x
+  = ASSERT( isLPJoinPoint x )
+    updateVarType (voidPrimTy `mkFunTy`) x
+
 {-
    Ticked Expressions
    ------------------
@@ -689,6 +712,41 @@ schemeE _ _ _ expr
   the code. When we find such a thing, we pull out the useful information,
   and then compile the code as if it was just the expression E.
 
+Note [Levity-polymorphic join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+A join point variable is essentially a goto-label: it is, for example,
+never used as an argument to another function, and it is called only
+in tail position. See Note [Join points] and Note [Invariants on join points],
+both in CoreSyn. Because join points do not compile to true, red-blooded
+variables (with, e.g., registers allocated to them), they are allowed
+to be levity-polymorphic. (See invariant #6 in Note [Invariants on join points]
+in CoreSyn.)
+
+However, in this byte-code generator, join points *are* treated just as
+ordinary variables. There is no check whether a binding is for a join point
+or not; they are all treated uniformly. (Perhaps there is a missed optimization
+opportunity here, but that is beyond the scope of my (Richard E's) Thursday.)
+
+We thus must have *some* strategy for dealing with levity-polymorphic join
+points (LPJPs), because we cannot have a levity-polymorphic variable.
+(Not having such a strategy led to #16509, which panicked in the isUnliftedType
+check in the AnnVar case of schemeE.) Here is the strategy:
+
+1. Detect LPJPs. This is done in isLPJoinPoint.
+
+2. When binding an LPJP, add a `\ (_ :: Void#) ->` to its RHS, and modify the
+   type to tack on a `Void# ->`. (Void# is written voidPrimTy within GHC.)
+   Note that functions are never levity-polymorphic, so this transformation
+   changes an LPJP to a non-levity-polymorphic join point. This is done
+   in protectLPJoinPointBind, called from the AnnLet case of schemeE.
+
+3. At an occurrence of an LPJP, add an application to void# (called voidPrimId),
+   being careful to note the new type of the LPJP. This is done in the AnnVar
+   case of schemeE, with help from protectLPJoinPointId.
+
+It's a bit hacky, but it works well in practice and is local. I suspect the
+Right Fix is to take advantage of join points as goto-labels.
+
 -}
 
 -- Compile code to do a tail call.  Specifically, push the fn,
index 07dcd22..d405e1a 100644 (file)
@@ -45,7 +45,7 @@ data ProtoBCO a
         protoBCOBitmap     :: [StgWord],
         protoBCOBitmapSize :: Word16,
         protoBCOArity      :: Int,
-        -- what the BCO came from
+        -- what the BCO came from, for debugging only
         protoBCOExpr       :: Either  [AnnAlt Id DVarSet] (AnnExpr Id DVarSet),
         -- malloc'd pointers
         protoBCOFFIs       :: [FFIInfo]
@@ -179,7 +179,13 @@ data BCInstr
 -- Printing bytecode instructions
 
 instance Outputable a => Outputable (ProtoBCO a) where
-   ppr (ProtoBCO name instrs bitmap bsize arity origin ffis)
+   ppr (ProtoBCO { protoBCOName       = name
+                 , protoBCOInstrs     = instrs
+                 , protoBCOBitmap     = bitmap
+                 , protoBCOBitmapSize = bsize
+                 , protoBCOArity      = arity
+                 , protoBCOExpr       = origin
+                 , protoBCOFFIs       = ffis })
       = (text "ProtoBCO" <+> ppr name <> char '#' <> int arity
                 <+> text (show ffis) <> colon)
         $$ nest 3 (case origin of
index 4d437d3..522eeb1 100644 (file)
@@ -64,7 +64,7 @@ isNvUnaryType ty
   = False
 
 -- INVARIANT: the result list is never empty.
-typePrimRepArgs :: Type -> [PrimRep]
+typePrimRepArgs :: HasDebugCallStack => Type -> [PrimRep]
 typePrimRepArgs ty
   | [] <- reps
   = [VoidRep]
diff --git a/testsuite/tests/ghci/scripts/T16509.hs b/testsuite/tests/ghci/scripts/T16509.hs
new file mode 100644 (file)
index 0000000..6f35e3c
--- /dev/null
@@ -0,0 +1,11 @@
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE ViewPatterns #-}
+
+module PatternPanic where
+
+pattern TestPat :: (Int, Int)
+pattern TestPat <- (isSameRef -> True, 0)
+
+isSameRef :: Int -> Bool
+isSameRef e | 0 <- e = True
+isSameRef _ = False
diff --git a/testsuite/tests/ghci/scripts/T16509.script b/testsuite/tests/ghci/scripts/T16509.script
new file mode 100644 (file)
index 0000000..3e40de0
--- /dev/null
@@ -0,0 +1 @@
+:l T16509
index 5162a3c..b6772d4 100755 (executable)
@@ -295,5 +295,6 @@ test('T15941', normal, ghci_script, ['T15941.script'])
 test('T16030', normal, ghci_script, ['T16030.script'])
 test('T11606', normal, ghci_script, ['T11606.script'])
 test('T16089', normal, ghci_script, ['T16089.script'])
+test('T16509', normal, ghci_script, ['T16509.script'])
 test('T16527', normal, ghci_script, ['T16527.script'])
 test('T16767', normal, ghci_script, ['T16767.script'])