Functions and types can now be post-hoc vectorised; i.e., in modules where they are...
authorManuel M T Chakravarty <chak@cse.unsw.edu.au>
Mon, 22 Aug 2011 13:53:04 +0000 (23:53 +1000)
committerManuel M T Chakravarty <chak@cse.unsw.edu.au>
Wed, 24 Aug 2011 12:44:09 +0000 (22:44 +1000)
- Types already gained this functionality already in a previous commit
- This commit adds the capability for functions

This is a crucial step towards being able to use the standard Prelude, instead of a special vectorised one.

compiler/hsSyn/HsDecls.lhs
compiler/rename/RnSource.lhs
compiler/typecheck/TcBinds.lhs
compiler/vectorise/Vectorise.hs
compiler/vectorise/Vectorise/Env.hs
compiler/vectorise/Vectorise/Monad/Global.hs

index 5015838..5461500 100644 (file)
@@ -1021,18 +1021,6 @@ A vectorisation pragma, one of
   {-# VECTORISE type T = ty #-}
   {-# VECTORISE SCALAR type T #-}
   
-Note [Typechecked vectorisation pragmas]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-In case of the first variant of vectorisation pragmas (with an explicit expression),
-we need to infer the type of that expression during type checking and then keep that type
-around until vectorisation, so that it can be checked against the *vectorised* type of 'f'.
-(We cannot determine vectorised types during type checking due to internal information of
-the vectoriser being needed.)
-
-To this end, we annotate the 'Id' of 'f' (the variable mentioned in the PRAGMA) with the
-inferred type of the expression.  This is slightly dodgy, as this is really the type of
-'$v_f' (the name of the vectorised function).
-
 \begin{code}
 type LVectDecl name = Located (VectDecl name)
 
index a5a9a54..dc076cf 100644 (file)
@@ -636,13 +636,13 @@ badRuleLhsErr name lhs bad_e
 \begin{code}
 rnHsVectDecl :: VectDecl RdrName -> RnM (VectDecl Name, FreeVars)
 rnHsVectDecl (HsVect var Nothing)
-  = do { var' <- lookupLocatedTopBndrRn var
+  = do { var' <- lookupLocatedOccRn var
        ; return (HsVect var' Nothing, unitFV (unLoc var'))
        }
 -- FIXME: For the moment, the right-hand side is restricted to be a variable as we cannot properly
 --        typecheck a complex right-hand side without invoking 'vectType' from the vectoriser.
 rnHsVectDecl (HsVect var (Just rhs@(L _ (HsVar _))))
-  = do { var' <- lookupLocatedTopBndrRn var
+  = do { var' <- lookupLocatedOccRn var
        ; (rhs', fv_rhs) <- rnLExpr rhs
        ; return (HsVect var' (Just rhs'), fv_rhs `addOneFV` unLoc var')
        }
@@ -652,7 +652,7 @@ rnHsVectDecl (HsVect _var (Just _rhs))
                , ptext (sLit "must be an identifier")
                ]
 rnHsVectDecl (HsNoVect var)
-  = do { var' <- lookupLocatedTopBndrRn var
+  = do { var' <- lookupLocatedTopBndrRn var           -- only applies to local (not imported) names
        ; return (HsNoVect var', unitFV (unLoc var'))
        }
 rnHsVectDecl (HsVectTypeIn tycon Nothing)
index 9f5fd4d..6787bbd 100644 (file)
@@ -644,20 +644,19 @@ tcVect :: VectDecl Name -> TcM (VectDecl TcId)
 -- FIXME: We can't typecheck the expression of a vectorisation declaration against the vectorised
 --   type of the original definition as this requires internals of the vectoriser not available
 --   during type checking.  Instead, constrain the rhs of a vectorisation declaration to be a single
---   identifier (this is checked in 'rnHsVectDecl').
+--   identifier (this is checked in 'rnHsVectDecl').  Fix this by enabling the use of 'vectType'
+--   from the vectoriser here.
 tcVect (HsVect name Nothing)
   = addErrCtxt (vectCtxt name) $
-    do { id <- wrapLocM tcLookupId name
-       ; return $ HsVect id Nothing
+    do { var <- wrapLocM tcLookupId name
+       ; return $ HsVect var Nothing
        }
-tcVect (HsVect lname@(L loc name) (Just rhs))
-  = addErrCtxt (vectCtxt lname) $
-    do { id <- tcLookupId name
-
+tcVect (HsVect name (Just rhs))
+  = addErrCtxt (vectCtxt name) $
+    do { var <- wrapLocM tcLookupId name
        ; let L rhs_loc (HsVar rhs_var_name) = rhs
        ; rhs_id <- tcLookupId rhs_var_name
-       ; let typedId = setIdType id (idType rhs_id)
-       ; return $ HsVect (L loc typedId) (Just $ L rhs_loc (HsVar rhs_id))
+       ; return $ HsVect var (Just $ L rhs_loc (HsVar rhs_id))
        }
 
 {- OLD CODE:
@@ -688,8 +687,8 @@ tcVect (HsVect lname@(L loc name) (Just rhs))
  -}
 tcVect (HsNoVect name)
   = addErrCtxt (vectCtxt name) $
-    do { id <- wrapLocM tcLookupId name
-       ; return $ HsNoVect id
+    do { var <- wrapLocM tcLookupId name
+       ; return $ HsNoVect var
        }
 tcVect (HsVectTypeIn lname@(L _ name) ty)
   = addErrCtxt (vectCtxt lname) $
index 1d54b38..2f9035e 100644 (file)
@@ -33,9 +33,10 @@ import Util                 ( zipLazy )
 import MonadUtils
 
 import Control.Monad
+import Data.Maybe
 
 
--- | Vectorise a single module.
+-- |Vectorise a single module.
 --
 vectorise :: ModGuts -> CoreM ModGuts
 vectorise guts
@@ -43,7 +44,7 @@ vectorise guts
       ; liftIO $ vectoriseIO hsc_env guts
       }
 
--- Vectorise a single monad, given the dynamic compiler flags and HscEnv.
+-- Vectorise a single monad, given the dynamic compiler flags and HscEnv.
 --
 vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
 vectoriseIO hsc_env guts
@@ -58,7 +59,7 @@ vectoriseIO hsc_env guts
       ; return (guts' { mg_vect_info = info' })
       }
 
--- Vectorise a single module, in the VM monad.
+-- Vectorise a single module, in the VM monad.
 --
 vectModule :: ModGuts -> VM ModGuts
 vectModule guts@(ModGuts { mg_types      = types
@@ -73,21 +74,23 @@ vectModule guts@(ModGuts { mg_types      = types
           -- representaions, and the conrresponding data constructors.  Moreover, we produce
           -- bindings for dfuns and family instances of the classes and type families used in the
           -- DPH library to represent array types.
-      ; (types', new_fam_insts, tc_binds) <- vectTypeEnv types [vd | vd@(VectType _ _) <- vect_decls]
+      ; (types', new_fam_insts, tc_binds) <- vectTypeEnv types [vd 
+                                                               | vd@(VectType _ _) <- vect_decls]
 
       ; (_, fam_inst_env) <- readGEnv global_fam_inst_env
 
-          -- Vectorise all the top level bindings.
-      ; binds'  <- mapM vectTopBind binds
+          -- Vectorise all the top level bindings and VECTORISE declarations on imported identifiers
+      ; binds_top <- mapM vectTopBind binds
+      ; binds_imp <- mapM vectImpBind [imp_id | Vect imp_id _ <- vect_decls, isGlobalId imp_id]
 
       ; return $ guts { mg_types        = types'
-                      , mg_binds        = Rec tc_binds : binds'
+                      , mg_binds        = Rec tc_binds : (binds_top ++ binds_imp)
                       , mg_fam_inst_env = fam_inst_env
                       , mg_fam_insts    = fam_insts ++ new_fam_insts
                       }
       }
 
--- |Try to vectorise a top-level binding.  If it doesn't vectorise then return it unharmed.
+-- Try to vectorise a top-level binding.  If it doesn't vectorise then return it unharmed.
 --
 -- For example, for the binding 
 --
@@ -198,7 +201,25 @@ vectTopBind b@(Rec bs)
              else vectorise                             -- no binding has a 'NOVECTORISE' decl
            }
     noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
-     
+
+-- Add a vectorised binding to an imported top-level variable that has a VECTORISE [SCALAR] pragma
+-- in this module.
+--
+vectImpBind :: Id -> VM CoreBind
+vectImpBind var
+  = do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it
+           -- to the vectorisation map.  For the non-lifted version, we refer to the original
+           -- definition — i.e., 'Var var'.
+       ; (inline, isScalar, expr') <- vectTopRhs [] var (Var var)
+       ; var' <- vectTopBinder var inline expr'
+       ; when isScalar $ 
+           addGlobalScalar var
+
+           -- We add any newly created hoisted top-level bindings.
+       ; hs <- takeHoisted
+       ; return . Rec $ (var', expr') : hs
+       }
+
 -- | Make the vectorised version of this top level binder, and add the mapping
 --   between it and the original to the state. For some binder @foo@ the vectorised
 --   version is @$v_foo@
@@ -215,13 +236,13 @@ vectTopBinder var inline expr
       ; vty  <- vectType (idType var)
       
           -- If there is a vectorisation declartion for this binding, make sure that its type
-          --  matches
+          -- matches
       ; vectDecl <- lookupVectDecl var
       ; case vectDecl of
-          Nothing                 -> return ()
+          Nothing             -> return ()
           Just (vdty, _) 
             | eqType vty vdty -> return ()
-            | otherwise           -> 
+            | otherwise       -> 
               cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
                 (text "Expected type" <+> ppr vty)
                 $$
@@ -263,10 +284,11 @@ vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
                  , CoreExpr)  -- (3) the vectorised right-hand side
 vectTopRhs recFs var expr
   = closedV
-  $ do { traceVt ("vectTopRhs of " ++ show var) $ ppr expr
-  
-       ; globalScalar <- isGlobalScalar var
+  $ do { globalScalar <- isGlobalScalar var
        ; vectDecl     <- lookupVectDecl var
+
+       ; traceVt ("vectTopRhs of " ++ show var ++ info globalScalar vectDecl) $ ppr expr
+
        ; rhs globalScalar vectDecl
        }
   where
@@ -278,10 +300,15 @@ vectTopRhs recFs var expr
            }
     rhs False         Nothing                         -- Case (3)
       = do { let fvs = freeVars expr
-           ; (inline, isScalar, vexpr) <- inBind var $
-                                          vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs fvs
+           ; (inline, isScalar, vexpr) 
+               <- inBind var $
+                    vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs fvs
            ; return (inline, isScalar, vectorised vexpr)
            }
+    
+    info True  _                          = " [VECTORISE SCALAR]"
+    info False vectDecl | isJust vectDecl = " [VECTORISE]"
+                        | otherwise       = " (no pragma)"
 
 -- | Project out the vectorised version of a binding from some closure,
 --   or return the original body if that doesn't work or the binding is scalar. 
index a13c021..5220d5a 100644 (file)
@@ -98,9 +98,6 @@ data GlobalEnv
           -- ^Variables that are not vectorised.  (They may be referenced in the right-hand sides
           -- of vectorisation declarations, though.)
 
-        , global_exported_vars        :: VarEnv (Var, Var)
-          -- ^Exported variables which have a vectorised version.
-
         , global_tycons               :: NameEnv TyCon
           -- ^Mapping from TyCons to their vectorised versions.
           -- TyCons which do not have to be vectorised are mapped to themselves.
@@ -134,7 +131,6 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs
   , global_scalar_vars          = vectInfoScalarVars info   `extendVarSetList` scalar_vars
   , global_scalar_tycons        = vectInfoScalarTyCons info `addListToNameSet` scalar_tycons
   , global_novect_vars          = mkVarSet novects
-  , global_exported_vars        = emptyVarEnv
   , global_tycons               = mapNameEnv snd $ vectInfoTyCon info
   , global_datacons             = mapNameEnv snd $ vectInfoDataCon info
   , global_pa_funs              = mapNameEnv snd $ vectInfoPADFun info
@@ -144,10 +140,14 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs
   , global_bindings             = []
   }
   where
-    vects         = [(var, (varType var, exp)) | Vect     var   (Just exp) <- vectDecls]
-    scalar_vars   = [var                       | Vect     var   Nothing    <- vectDecls]
-    novects       = [var                       | NoVect   var              <- vectDecls]
-    scalar_tycons = [tyConName tycon           | VectType tycon Nothing    <- vectDecls]
+    vects         = [(var, (ty, exp)) | Vect     var   (Just exp@(Var rhs_var)) <- vectDecls
+                                      , let ty = varType rhs_var]
+                                        -- FIXME: we currently only allow RHSes consisting of a
+                                        --   single variable to be able to obtain the type without
+                                        --   inference — see also 'TcBinds.tcVect'
+    scalar_vars   = [var              | Vect     var   Nothing                  <- vectDecls]
+    novects       = [var              | NoVect   var                            <- vectDecls]
+    scalar_tycons = [tyConName tycon  | VectType tycon Nothing                  <- vectDecls]
 
 
 -- Operators on Global Environments -------------------------------------------
@@ -198,13 +198,14 @@ setPRFunsEnv ps genv
 
 -- |Compute vectorisation information that goes into 'ModGuts' (and is stored in interface files).
 -- The incoming 'vectInfo' is that from the 'HscEnv' and 'EPS'.  The outgoing one contains only the
--- definitions for the currently compiled module; this includes variables, type constructors, and
--- data constructors referenced in VECTORISE pragmas.
+-- declarations for the currently compiled module; this includes variables, type constructors, and
+-- data constructors referenced in VECTORISE pragmas, even if they are defined in an imported
+-- module.
 --
 modVectInfo :: GlobalEnv -> TypeEnv -> [CoreVect]-> VectInfo -> VectInfo
 modVectInfo env tyenv vectDecls info
   = info 
-    { vectInfoVar          = global_exported_vars env
+    { vectInfoVar          = mk_env ids      (global_vars     env)
     , vectInfoTyCon        = mk_env tyCons   (global_tycons   env)
     , vectInfoDataCon      = mk_env dataCons (global_datacons env)
     , vectInfoPADFun       = mk_env tyCons   (global_pa_funs  env)
@@ -212,9 +213,12 @@ modVectInfo env tyenv vectDecls info
     , vectInfoScalarTyCons = global_scalar_tycons env `minusNameSet` vectInfoScalarTyCons info
     }
   where
+    vectIds        = [id    | Vect     id    _ <- vectDecls]
     vectTypeTyCons = [tycon | VectType tycon _ <- vectDecls]
+    vectDataCons   = concatMap tyConDataCons vectTypeTyCons
+    ids            = typeEnvIds      tyenv ++ vectIds
     tyCons         = typeEnvTyCons   tyenv ++ vectTypeTyCons
-    dataCons       = typeEnvDataCons tyenv ++ concatMap tyConDataCons vectTypeTyCons
+    dataCons       = typeEnvDataCons tyenv ++ vectDataCons
     
     -- Produce an entry for every declaration that is mentioned in the domain of the 'inspectedEnv'
     mk_env decls inspectedEnv
index 0624e35..5639c23 100644 (file)
@@ -39,9 +39,9 @@ import TyCon
 import DataCon
 import NameEnv
 import NameSet
-import Var
 import VarEnv
 import VarSet
+import Outputable
 
 
 -- Global Environment ---------------------------------------------------------
@@ -67,13 +67,11 @@ updGEnv f = VM $ \_ genv lenv -> return (Yes (f genv) lenv ())
 -- |Add a mapping between a global var and its vectorised version to the state.
 --
 defGlobalVar :: Var -> Var -> VM ()
-defGlobalVar v v' = updGEnv $ \env ->
-  env { global_vars = extendVarEnv (global_vars env) v v'
-      , global_exported_vars = upd (global_exported_vars env)
-      }
-  where
-    upd env | isExportedId v = extendVarEnv env v (v, v')
-            | otherwise      = env
+defGlobalVar v v'
+  = do { traceVt "add global var mapping:" (ppr v <+> text "-->" <+> ppr v') 
+
+       ; updGEnv $ \env -> env { global_vars = extendVarEnv (global_vars env) v v' }
+       }
 
 
 -- Vectorisation declarations -------------------------------------------------