Vectoriser: distinguish vectorised from parallel types and functions
authorManuel M T Chakravarty <chak@cse.unsw.edu.au>
Sun, 9 Dec 2012 06:24:05 +0000 (17:24 +1100)
committerManuel M T Chakravarty <chak@cse.unsw.edu.au>
Sun, 9 Dec 2012 06:24:05 +0000 (17:24 +1100)
- We sometimes need to vectorise types and functions because they might be needed in a vectorised context, not because they do directly introduce parallelism.

compiler/vectorise/Vectorise/Exp.hs
compiler/vectorise/Vectorise/Type/Classify.hs
compiler/vectorise/Vectorise/Type/Env.hs

index b300335..5c20507 100644 (file)
@@ -243,20 +243,29 @@ liftSimpleAndCase aexpr@((fvs, _vi), AnnCase expr bndr t alts)
     { vi <- vectAvoidInfoTypeOf expr
     ; if (vi == VISimple)
       then
-        return $ liftSimple aexpr  -- if the scrutinee is scalar, we need no special treatment
+        liftSimple aexpr  -- if the scrutinee is scalar, we need no special treatment
       else do
       { alts' <- mapM (\(ac, bndrs, aexpr) -> (ac, bndrs,) <$> liftSimpleAndCase aexpr) alts
       ; return ((fvs, vi), AnnCase expr bndr t alts')
       }
     }
-liftSimpleAndCase aexpr = return $ liftSimple aexpr
+liftSimpleAndCase aexpr = liftSimple aexpr
 
-liftSimple :: CoreExprWithVectInfo -> CoreExprWithVectInfo
-liftSimple ((fvs, vi), expr) 
-  = ASSERT(vi == VISimple)
-    mkAnnApps (mkAnnLams vars fvs expr) vars
+liftSimple :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
+liftSimple aexpr@((fvs_orig, VISimple), expr) 
+  = do 
+    { let liftedExpr = mkAnnApps (mkAnnLams vars fvs expr) vars
+
+    ; traceVt "encapsulate:" $ ppr (deAnnotate aexpr) $$ text "==>" $$ ppr (deAnnotate liftedExpr)
+
+    ; return $ liftedExpr
+    }
   where
     vars = varSetElems fvs
+    fvs  = filterVarSet isToplevel fvs_orig -- only include 'Id's that are not toplevel
+    
+    isToplevel v | isId v    = not . uf_is_top . realIdUnfolding $ v
+                 | otherwise = False
 
     mkAnnLams :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo
     mkAnnLams []     fvs expr = ASSERT(isEmptyVarSet fvs)
@@ -270,23 +279,31 @@ liftSimple ((fvs, vi), expr)
     mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo
     mkAnnApp aexpr@((fvs, _vi), _expr) v 
       = ((fvs `extendVarSet` v, VISimple), AnnApp aexpr ((unitVarSet v, VISimple), AnnVar v))
+liftSimple aexpr
+  = pprPanic "Vectorise.Exp.liftSimple: not simple" $ ppr (deAnnotate aexpr)
+
 
 -- |Vectorise an expression.
 --
 vectExpr :: CoreExprWithVectInfo -> VM VExpr
 
-vectExpr (_, AnnVar v)
+-- !!!FIXME: needs to check for VIEncaps regardless of syntactic form first; in case it is of functional type
+
+vectExpr aexpr@(_, AnnVar v)
+  | (isFunTy . varType $ v) && isVIEncaps aexpr
+  = vectFnExpr False False aexpr
+  | otherwise
   = vectVar v
 
 vectExpr (_, AnnLit lit)
   = vectConst $ Lit lit
 
-vectExpr e@(_, AnnLam bndr _)
-  | isId bndr = vectFnExpr True False e
+vectExpr aexpr@(_, AnnLam bndr _)
+  | isId bndr = vectFnExpr True False aexpr
   | otherwise 
   = do 
     { dflags <- getDynFlags
-    ; cantVectorise dflags "Unexpected type lambda (vectExpr)" $ ppr (deAnnotate e)
+    ; cantVectorise dflags "Unexpected type lambda (vectExpr)" $ ppr (deAnnotate aexpr)
     }
 
   -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
@@ -408,14 +425,18 @@ vectFnExpr inline loop_breaker expr@(_ann, AnnLam bndr body)
     ; vbody <- vectFnExpr inline loop_breaker body
     ; return $ mapVect (mkLams [vectorised vBndr]) vbody
     }
-    -- non-predicate abstraction: vectorise as a scalar computation
+    -- encapsulated non-predicate abstraction: vectorise as a scalar computation
   | isId bndr && isVIEncaps expr
   = vectScalarFun . deAnnotate $ expr
     -- non-predicate abstraction: vectorise as a non-scalar computation
   | isId bndr
   = vectLam inline loop_breaker expr
-vectFnExpr _ _  expr
-    -- not an abstraction: vectorise as a vanilla expression
+vectFnExpr _ _ expr
+    -- encapsulated function: vectorise as a scalar computation
+  | (isFunTy . annExprType $ expr) && isVIEncaps expr
+  = vectScalarFun . deAnnotate $ expr
+  | otherwise
+    -- not an abstraction: vectorise as a non-scalar vanilla expression
   = vectExpr expr
 
 -- |Vectorise type and dictionary applications.
@@ -543,7 +564,7 @@ vectDictExpr (Coercion coe)
 vectScalarFun :: CoreExpr -> VM VExpr
 vectScalarFun expr 
   = do 
-    { traceVt "vectScalarFun" (ppr expr) 
+    { traceVt "vectorise scalar functions:" (ppr expr) 
     ; let (arg_tys, res_ty) = splitFunTys (exprType expr)
     ; mkScalarFun arg_tys res_ty expr
     }
index e1cd43a..1632589 100644 (file)
@@ -42,37 +42,37 @@ import Digraph
 -- * tycons which haven't been converted (because they can't or weren't vectorised) are not
 --   elements of the map
 --
-classifyTyCons :: UniqFM Bool                   -- ^type constructor vectorisation status
-               -> NameSet                       -- ^tycons involving parallel arrays
-               -> [TyCon]                       -- ^type constructors that need to be classified
-               -> ( [TyCon]                     -- to be converted
-                  , [TyCon]                     -- need not be converted (but could be)
-                  , [TyCon]                     -- can't be converted, but involve parallel arrays
-                  , [TyCon]                     -- can't be converted and have no parallel arrays
+classifyTyCons :: UniqFM Bool                  -- ^type constructor vectorisation status
+               -> NameSet                      -- ^tycons involving parallel arrays
+               -> [TyCon]                      -- ^type constructors that need to be classified
+               -> ( [TyCon]                    -- to be converted
+                  , [TyCon]                    -- need not be converted (but could be)
+                  , [TyCon]                    -- involve parallel arrays (whether converted or not)
+                  , [TyCon]                    -- can't be converted
                   )
 classifyTyCons convStatus parTyCons tcs = classify [] [] [] [] convStatus parTyCons (tyConGroups tcs)
   where
     classify conv keep par novect _  _   []               = (conv, keep, par, novect)
     classify conv keep par novect cs pts ((tcs, ds) : rs)
       | can_convert && must_convert
-      = classify (tcs ++ conv) keep par novect (cs `addListToUFM` [(tc, True)  | tc <- tcs]) pts' rs
+      = classify (tcs ++ conv) keep (par ++ tcs_par) novect (cs `addListToUFM` [(tc, True)  | tc <- tcs]) pts' rs
       | can_convert
-      = classify conv (tcs ++ keep) par novect (cs `addListToUFM` [(tc, False) | tc <- tcs]) pts' rs
-      | has_parr
-      = classify conv keep (tcs ++ par) novect cs pts' rs
+      = classify conv (tcs ++ keep) (par ++ tcs_par) novect (cs `addListToUFM` [(tc, False) | tc <- tcs]) pts' rs
       | otherwise
-      = classify conv keep par (tcs ++ novect) cs pts' rs
+      = classify conv keep (par ++ tcs_par) (tcs ++ novect) cs pts' rs
       where
         refs = ds `delListFromUniqSet` tcs
         
-        pts' | has_parr  = pts `addListToNameSet` map tyConName tcs
-             | otherwise = pts
+          -- the tycons that directly or indirectly depend on parallel arrays
+        tcs_par | any ((`elemNameSet` parTyCons) . tyConName) . eltsUFM $ refs = tcs
+                | otherwise                                                    = []
+
+        pts' = pts `addListToNameSet` map tyConName tcs_par
 
         can_convert  = (isNullUFM (refs `minusUFM` cs) && all convertable tcs)
                        || isShowClass tcs
         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
                        && (not . isShowClass $ tcs)
-        has_parr     = any ((`elemNameSet` parTyCons) . tyConName) . eltsUFM $ refs
 
         -- We currently admit Haskell 2011-style data and newtype declarations as well as type
         -- constructors representing classes.
index faa80a8..9553e5c 100644 (file)
@@ -205,11 +205,13 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls
            -- these are being handled separately.  NB: Some type constructors may be marked SCALAR
            -- /and/ have an explicit right-hand side.)
            --
-           -- Furthermore, 'par_tcs' and 'drop_tcs' are those type constructors that we cannot
-           -- vectorise, and of those, only the 'par_tcs' involve parallel arrays.
-       ; parallelTyCons <- globalParallelTyCons
+           -- Furthermore, 'par_tcs' are those type constructors (converted or not) whose
+           -- definition, directly or indirectly, depends on parallel arrays. Finally, 'drop_tcs'
+           -- are all type constructors that cannot be vectorised.
+       ; parallelTyCons <- (`addListToNameSet` map (tyConName . fst3) vectTyConsWithRHS) <$> 
+                             globalParallelTyCons
        ; let maybeVectoriseTyCons = filter notVectSpecialTyCon tycons ++ impVectTyCons
-             (conv_tcs, keep_tcs, par_tcs, drop_tcs) 
+             (conv_tcs, keep_tcs, par_tcs, drop_tcs)
                = classifyTyCons vectTyConFlavour parallelTyCons maybeVectoriseTyCons
              
        ; traceVt " VECT SCALAR    : " $ ppr (scalarTyConsNoRHS ++ 
@@ -223,12 +225,12 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls
            -- warn the user about unvectorised type constructors
        ; let explanation    = ptext (sLit "(They use unsupported language extensions") $$
                               ptext (sLit "or depend on type constructors that are not vectorised)")
-             drop_tcs_nosyn = filter (not . isSynTyCon) (par_tcs ++ drop_tcs)
+             drop_tcs_nosyn = filter (not . isSynTyCon) drop_tcs
        ; unless (null drop_tcs_nosyn) $
            emitVt "Warning: cannot vectorise these type constructors:" $ 
              pprQuotedList drop_tcs_nosyn $$ explanation
 
-       ; mapM_ addParallelTyConAndCons $ conv_tcs ++ par_tcs
+       ; mapM_ addParallelTyConAndCons $ par_tcs ++ [tc | (tc, _, False) <- vectTyConsWithRHS]
 
        ; let mapping =      
                     -- Type constructors that we found we don't need to vectorise and those