8f6e32130f1a5d6080f70166e22875f2ada8dbbd
[ghc.git] / compiler / vectorise / Vectorise.hs
1 -- Main entry point to the vectoriser. It is invoked iff the option '-fvectorise' is passed.
2 --
3 -- This module provides the function 'vectorise', which vectorises an entire (desugared) module.
4 -- It vectorises all type declarations and value bindings. It also processes all VECTORISE pragmas
5 -- (aka vectorisation declarations), which can lead to the vectorisation of imported data types
6 -- and the enrichment of imported functions with vectorised versions.
7
8 module Vectorise ( vectorise )
9 where
10
11 import Vectorise.Type.Env
12 import Vectorise.Type.Type
13 import Vectorise.Convert
14 import Vectorise.Utils.Hoisting
15 import Vectorise.Exp
16 import Vectorise.Vect
17 import Vectorise.Env
18 import Vectorise.Monad
19
20 import HscTypes hiding ( MonadThings(..) )
21 import CoreUnfold ( mkInlineUnfolding )
22 import CoreFVs
23 import PprCore
24 import CoreSyn
25 import CoreMonad ( CoreM, getHscEnv )
26 import Type
27 import Id
28 import DynFlags
29 import BasicTypes ( isStrongLoopBreaker )
30 import Outputable
31 import Util ( zipLazy )
32 import MonadUtils
33
34 import Control.Monad
35 import Data.Maybe
36
37
-- |Vectorise a single module.
--
-- Fetches the current 'HscEnv' from the 'CoreM' monad and hands the module over to
-- 'vectoriseIO', which does the work in 'IO'.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
  = getHscEnv >>= \hsc_env -> liftIO (vectoriseIO hsc_env guts)
45
-- Vectorise a single module, given the 'HscEnv'.
--
-- The vectorisation info of the home package table is combined with that of the external
-- package state, and the result is used to run 'vectModule' via 'initV'. The vectorisation
-- info produced by that run is stored in the 'mg_vect_info' field of the returned module.
--
-- NB: The result of 'initV' is matched against 'Just'; should it be 'Nothing', the pattern
--     match fails in the 'IO' monad.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
  = do { -- Information about the currently loaded external packages.
         eps <- hscEPS hsc_env

         -- Vectorisation info of the home package modules plus that of external packages.
       ; let vectInfo = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

         -- Run the main VM computation.
       ; Just (vectInfo', guts') <- initV hsc_env guts vectInfo (vectModule guts)
       ; return $ guts' { mg_vect_info = vectInfo' }
       }
60
-- Vectorise a single module, in the VM monad.
--
-- In order: dump the incoming bindings (under 'Opt_D_dump_vt_trace'), vectorise the type
-- environment, read the family instance environment, vectorise imported identifiers that have
-- a vectorisation declaration in this module, and finally vectorise the module's own
-- top-level bindings.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_tcs        = tcs
                         , mg_binds      = topBinds
                         , mg_fam_insts  = famInsts
                         , mg_vect_decls = decls
                         })
  = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" (pprCoreBindings topBinds)

         -- Pick out all 'VECTORISE type' and 'VECTORISE class' pragmas.
       ; let tyDecls  = [d | d@VectType{}  <- decls]
             clsDecls = [d | d@VectClass{} <- decls]

         -- Vectorise the type environment. This adds vectorised type constructors, their
         -- representations, and the corresponding data constructors. Moreover, we produce
         -- bindings for dfuns and family instances of the classes and type families used in
         -- the DPH library to represent array types.
       ; (newTcs, newFamInsts, tcBinds) <- vectTypeEnv tcs tyDecls clsDecls

         -- Family instance environment for /all/ home-package modules, including those
         -- instances generated by 'vectTypeEnv'.
       ; (_, famInstEnv) <- readGEnv global_fam_inst_env

         -- Vectorise all the top-level bindings and the VECTORISE declarations on imported
         -- identifiers.
         -- NB: The imported bindings need to be vectorised first (local bindings may depend
         --     on them).
       ; let importedIds = filter isGlobalId $
                             [v | Vect v _   <- decls] ++
                             [v | VectInst v <- decls]
       ; impBinds <- mapM vectImpBind importedIds
       ; locBinds <- mapM vectTopBind topBinds

         -- We produce no new classes or instances, only new class type constructors and
         -- dfuns.
       ; return $ guts { mg_tcs          = tcs ++ newTcs
                       , mg_binds        = Rec tcBinds : (locBinds ++ impBinds)
                       , mg_fam_inst_env = famInstEnv
                       , mg_fam_insts    = famInsts ++ newFamInsts
                       }
       }
103
-- Try to vectorise a top-level binding. If it doesn't vectorise, return it unharmed.
--
-- For example, for the binding
--
-- @
--    foo :: Int -> Int
--    foo = \x -> x + x
-- @
--
-- we get
--
-- @
--    foo  :: Int -> Int
--    foo  = \x -> vfoo $: x
--
--    v_foo :: Closure void vfoo lfoo
--    v_foo = closure vfoo lfoo void
--
--    vfoo :: Void -> Int -> Int
--    vfoo = ...
--
--    lfoo :: PData Void -> PData Int -> PData Int
--    lfoo = ...
-- @
--
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original function
-- @foo@, but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
-- @v_foo@ combines both of these into a `Closure` that also contains the environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version present in the
-- closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma. If this
-- pragma is used in a group of mutually recursive bindings, either all or none of the bindings
-- must have the pragma. If only some bindings are annotated, a fatal error is raised.
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group,
--        or we may emit a warning and refrain from vectorising the entire group.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = unlessNoVectDecl (vectoriseBinding `orElseErrV` keepOriginal)
  where
    -- Vectorise the right-hand side, create an appropriate top-level binding and add it to
    -- the vectorisation map. The original binding is then replaced by a value projected from
    -- the vectorised closure, and any newly hoisted top-level bindings join the group.
    vectoriseBinding
      = do { (inline, isScalar, vexpr) <- vectTopRhs [] var expr
           ; vvar <- vectTopBinder var inline vexpr
           ; when isScalar (addGlobalScalarVar var)
           ; cexpr   <- tryConvert var vvar expr
           ; hoisted <- takeHoisted
           ; return (Rec ((var, cexpr) : (vvar, vexpr) : hoisted))
           }

    -- Fallback if vectorisation fails: report it and keep the binding as it was.
    keepOriginal
      = do { emitVt " Could NOT vectorise top-level binding" (ppr var)
           ; return b
           }

    -- Run the given computation only if the binder has no 'NOVECTORISE' declaration;
    -- otherwise trace that fact and return the binding untouched.
    unlessNoVectDecl action
      = do { suppressed <- noVectDecl var
           ; when suppressed $
               traceVt "NOVECTORISE" (ppr var)
           ; if suppressed then return b else action
           }
vectTopBind b@(Rec bs)
  = unlessSomeNoVectDecl (vectoriseGroup `orElseErrV` return b)
  where
    (vars, exprs) = unzip bs

    -- Replace the original top-level bindings by values projected from the vectorised
    -- closures and add any newly created hoisted top-level bindings to the group.
    vectoriseGroup
      = do { (vvars, _, vexprs, hoisted) <- fixV vectoriseRhss
           ; cexprs <- sequence (zipWith3 tryConvert vars vvars exprs)
           ; return (Rec (zip vars cexprs ++ zip vvars vexprs ++ hoisted))
           }

    -- Body of the knot tied by 'fixV'. The argument is the (lazily matched) result of this
    -- very computation: the binders are created from inline specifications and right-hand
    -- sides that are only produced further down, so nothing here may force them early.
    vectoriseRhss ~(_, knotInlines, knotRhss, _)
      = do { -- Create the top-level binders and add them to the vectorisation map.
             vvars <- sequence [ vectTopBinder v inl rhs
                               | (v, ~(inl, rhs)) <- zipLazy vars (zip knotInlines knotRhss) ]

             -- Vectorise the right-hand sides, treating the group's variables as local
             -- scalars.
           ; (inlines, areScalars, vexprs) <- mapAndUnzip3M (uncurry (vectTopRhs vars)) bs
           ; hoisted <- takeHoisted
           ; if and areScalars
               then -- (1) Entire recursive group is scalar
                    --      => add all variables to the global set of scalars
                 do { mapM_ addGlobalScalarVar vars
                    ; return (vvars, inlines, vexprs, hoisted)
                    }
               else -- (2) At least one binding is not scalar
                    --      => vectorise again with empty set of local scalars; the results
                    --         (and hoisted bindings) of the first attempt are not returned
                 do { (inlines', _, vexprs') <- mapAndUnzip3M (uncurry (vectTopRhs [])) bs
                    ; hoisted' <- takeHoisted
                    ; return (vvars, inlines', vexprs', hoisted')
                    }
           }

    -- Run the given computation only if no binder of the group has a 'NOVECTORISE'
    -- declaration. If all have one, the group is returned untouched; a mix is an error.
    unlessSomeNoVectDecl action
      = do { suppressed <- mapM noVectDecl vars
           ; let allSuppressed = and suppressed
           ; when allSuppressed $
               traceVt "NOVECTORISE" (ppr vars)
           ; case () of
               _ | allSuppressed -> return b                                -- all have 'NOVECTORISE'
                 | or suppressed -> cantVectorise noVectoriseErr (ppr b)    -- some, but not all
                 | otherwise     -> action                                  -- none has it
           }

    noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
217
-- Add a vectorised binding to an imported top-level variable that has a VECTORISE [SCALAR]
-- pragma in this module.
--
-- RESTRICTION: Currently, we cannot use the pragma for mutually recursive definitions.
--
vectImpBind :: Id -> VM CoreBind
vectImpBind var
  = do { (vvar, _, vexpr) <- fixV vectoriseImported

         -- We add any newly created hoisted top-level bindings.
       ; hoisted <- takeHoisted
       ; return (Rec ((vvar, vexpr) : hoisted))
       }
  where
    -- Vectorise the right-hand side, create an appropriate top-level binding and add it to
    -- the vectorisation map. For the non-lifted version, we refer to the original
    -- definition, i.e., 'Var var'.
    -- NB: To support recursive definitions, we tie a lazy knot: the argument is the lazily
    --     matched result of this very computation.
    vectoriseImported ~(_, knotInline, knotRhs)
      = do { vvar <- vectTopBinder var knotInline knotRhs
           ; (inline, isScalar, vexpr) <- vectTopRhs [] var (Var var)
           ; when isScalar (addGlobalScalarVar var)
           ; return (vvar, inline, vexpr)
           }
243
-- | Make the vectorised version of this top-level binder, and add the mapping between it and
-- the original to the state. For some binder @foo@ the vectorised version is @$v_foo@.
--
-- If there is a vectorisation declaration for the binder, its type must equal the vectorised
-- type of the binder; otherwise vectorisation is aborted via 'cantVectorise'.
--
-- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is used inside of
--       'fixV' in 'vectTopBind'.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
  = do { -- Vectorise the type attached to the var.
         vty <- vectType (idType var)

         -- If there is a vectorisation declaration for this binding, make sure that its type
         -- matches.
       ; lookupVectDecl var >>= checkDeclType vty

         -- Make the vectorised version of the binding's name, and set the unfolding used for
         -- inlining.
       ; vvar <- liftM attachUnfolding (mkVectId var vty)

         -- Add the mapping between the plain and vectorised name to the state.
       ; defGlobalVar var vvar

       ; return vvar
       }
  where
    checkDeclType _   Nothing = return ()
    checkDeclType vty (Just (vdty, _))
      | eqType vty vdty = return ()
      | otherwise
      = cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
          (text "Expected type" <+> ppr vty)
          $$
          (text "Inferred type" <+> ppr vdty)

    -- The unfolding is attached lazily, so neither 'inline' nor 'expr' is demanded here.
    attachUnfolding v = v `setIdUnfoldingLazily` unfolding

    unfolding = case inline of
                  Inline arity -> mkInlineUnfolding (Just arity) expr
                  DontInline   -> noUnfolding
{-
!!!TODO: dfuns and unfoldings:
           -- Do not inline the dfun; instead give it a magic DFunFunfolding
           -- See Note [ClassOp/DFun selection]
           -- See also note [Single-method classes]
        dfun_id_w_fun
           | isNewTyCon class_tc
           = dfun_id `setInlinePragma` alwaysInlinePragma { inl_sat = Just 0 }
           | otherwise
           = dfun_id `setIdUnfolding`  mkDFunUnfolding dfun_ty dfun_args
                     `setInlinePragma` dfunInlinePragma
 -}
297
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--
-- We need to distinguish four cases:
--
-- (1) We have a (non-scalar) vectorisation declaration for the variable (which explicitly
--     provides vectorised code implemented by the user)
--     => no automatic vectorisation & instead use the user-supplied code
--
-- (2) We have a scalar vectorisation declaration for a variable that is no dfun
--     => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
--
-- (3) We have a scalar vectorisation declaration for a variable that *is* a dfun
--     => generate vectorised code according to the "Note [Scalar dfuns]" below
--
-- (4) There is no vectorisation declaration for the variable
--     => perform automatic vectorisation of the RHS (the definition may or may not be a dfun;
--        vectorisation proceeds differently depending on which it is)
--
-- Note [Scalar dfuns]
-- ~~~~~~~~~~~~~~~~~~~
--
-- Here is the translation scheme for scalar dfuns; assume the instance declaration:
--
--   instance Num Int where
--     (+) = primAdd
--   {-# VECTORISE SCALAR instance Num Int #-}
--
-- It desugars to
--
--   $dNumInt :: Num Int
--   $dNumInt = D:Num primAdd
--
-- We vectorise it to
--
--   $v$dNumInt :: V:Num Int
--   $v$dNumInt = D:V:Num (closure2 ((+) $dNumInt) (scalar_zipWith ((+) $dNumInt)))
--
-- while adding the following entry to the vectorisation map: '$dNumInt' --> '$v$dNumInt'.
--
-- See "Note [Vectorising classes]" in 'Vectorise.Type.Env' for the definition of 'V:Num'.
--
-- NB: The outlined vectorisation scheme does not require the right-hand side of the original
--     dfun. In fact, we definitely want to refer to the dfun variable instead of the
--     right-hand side to ensure that the dictionary selection rules fire.
--
vectTopRhs :: [Var]          -- ^ Names of all functions in the rec block
           -> Var            -- ^ Name of the binding.
           -> CoreExpr       -- ^ Body of the binding.
           -> VM ( Inline    -- (1) inline specification for the binding
                 , Bool      -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr) -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV $
      do { globalScalar <- isGlobalScalarVar var
         ; vectDecl     <- lookupVectDecl var
         ; let isDFun = isDFunId var
               header = "vectTopRhs of " ++ show var
                        ++ pragmaInfo globalScalar isDFun vectDecl ++ ":"

         ; traceVt header (ppr expr)

         ; case (globalScalar, isDFun, vectDecl) of
             (_,     _,     Just (_, userExpr)) -> return (inlineMe, False, userExpr) -- Case (1)
             (True,  False, Nothing)            -> scalarFun                          -- Case (2)
             (True,  True,  Nothing)            -> scalarDFun                         -- Case (3)
             (False, False, Nothing)            -> autoFun                            -- Case (4), not a dfun
             (False, True,  Nothing)            -> autoDFun                           -- Case (4), is a dfun
         }
  where
    scalarFun
      = do { vexpr <- vectScalarFun recFs expr
           ; return (inlineMe, True, vectorised vexpr)
           }

    scalarDFun
      = do { vexpr <- vectScalarDFun var recFs
           ; return (DontInline, True, vexpr)
           }

    autoFun
      = do { let exprFvs     = freeVars expr
                 loopBreaker = isStrongLoopBreaker (idOccInfo var)
           ; (inline, isScalar, vexpr) <- inBind var (vectPolyExpr loopBreaker recFs exprFvs)
           ; return (inline, isScalar, vectorised vexpr)
           }

    autoDFun
      = do { vexpr <- vectDictExpr expr
           ; return (DontInline, True, vexpr)
           }

    -- Suffix for the trace header describing which kind of pragma (if any) applies.
    pragmaInfo True isDFun _
      | isDFun    = " [VECTORISE SCALAR instance]"
      | otherwise = " [VECTORISE SCALAR]"
    pragmaInfo False _ decl
      = maybe " (no pragma)" (const " [VECTORISE]") decl
387
-- |Project out the vectorised version of a binding from some closure, or return the original
-- body if that doesn't work or the binding is scalar.
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = isGlobalScalarVar var >>= convert
  where
    -- Global scalars keep their original body.
    convert True  = return rhs
    -- Otherwise, project from the vectorised variable; fall back to the original body if
    -- that fails.
    convert False = fromVect (idType var) (Var vect_var) `orElseErrV` keepOriginal

    keepOriginal
      = do { emitVt " Could NOT call vectorised from original version" (ppr var)
           ; return rhs
           }