Rename DynFlag to GeneralFlag
[ghc.git] / compiler / vectorise / Vectorise.hs
1 -- Main entry point to the vectoriser. It is invoked iff the option '-fvectorise' is passed.
2 --
3 -- This module provides the function 'vectorise', which vectorises an entire (desugared) module.
4 -- It vectorises all type declarations and value bindings. It also processes all VECTORISE pragmas
5 -- (aka vectorisation declarations), which can lead to the vectorisation of imported data types
6 -- and the enrichment of imported functions with vectorised versions.
7
8 module Vectorise ( vectorise )
9 where
10
11 import Vectorise.Type.Env
12 import Vectorise.Type.Type
13 import Vectorise.Convert
14 import Vectorise.Utils.Hoisting
15 import Vectorise.Exp
16 import Vectorise.Vect
17 import Vectorise.Env
18 import Vectorise.Monad
19
20 import HscTypes hiding ( MonadThings(..) )
21 import CoreUnfold ( mkInlineUnfolding )
22 import CoreFVs
23 import PprCore
24 import CoreSyn
25 import CoreMonad ( CoreM, getHscEnv )
26 import Type
27 import Id
28 import DynFlags
29 import BasicTypes ( isStrongLoopBreaker )
30 import Outputable
31 import Util ( zipLazy )
32 import MonadUtils
33
34 import Control.Monad
35 import Data.Maybe
36
37
-- | Vectorise a single module.
--
-- This is the 'CoreM' entry point: it fetches the 'HscEnv' of the current compilation and
-- delegates the actual work to 'vectoriseIO', lifted into 'CoreM'.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
  = getHscEnv >>= \hsc_env ->
    liftIO (vectoriseIO hsc_env guts)
45
-- | Vectorise a single module in 'IO', given the 'HscEnv' of the current compilation.
--
-- The vectorisation info of the home package ('hptVectInfo') and of the loaded external
-- packages ('eps_vect_info') is combined and handed to the 'VM' computation 'vectModule';
-- the vectorisation info it returns is stored in the resulting 'ModGuts'.
--
-- If the 'VM' computation fails ('initV' yields 'Nothing'), we panic with a message naming the
-- module instead of failing with an anonymous pattern-match error.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
  = do { -- Get information about currently loaded external packages.
       ; eps <- hscEPS hsc_env

         -- Combine vectorisation info from the home package and from external packages.
       ; let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

         -- Run the main VM computation.
       ; result <- initV hsc_env guts info (vectModule guts)
       ; case result of
           Just (info', guts') -> return (guts' { mg_vect_info = info' })
           Nothing             -> pprPanic "vectoriseIO"
                                    (text "vectorisation failed for module" <+> ppr (mg_module guts))
       }
60
-- | Vectorise a single module, in the 'VM' monad.
--
-- It vectorises the following:
--
-- * the type environment, driven by the 'VECTORISE type' and 'VECTORISE class' pragmas;
-- * the VECTORISE declarations on imported identifiers;
-- * all top-level bindings of the module.
--
-- The results are added to the module's type constructors, family instances, family instance
-- environment and bindings.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts
  = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $
           pprCoreBindings binds

         -- Vectorise the type environment.  This adds vectorised type constructors, their
         -- representations, and the corresponding data constructors.  Moreover, it produces
         -- bindings for dfuns and family instances of the classes and type families used in
         -- the DPH library to represent array types.
       ; (new_tycons, new_fam_insts, tc_binds) <- vectTypeEnv tycons tyDecls classDecls

         -- Family instance environment for /all/ home-package modules, including the instances
         -- generated by 'vectTypeEnv'.
       ; (_, fam_inst_env) <- readGEnv global_fam_inst_env

         -- The bindings for imported identifiers are vectorised first, as local bindings may
         -- depend on them; the module's own top-level bindings follow.
       ; binds_imp <- mapM vectImpBind importedIds
       ; binds_top <- mapM vectTopBind binds

         -- No new classes or instances are produced, only new class type constructors and dfuns.
       ; return $ guts { mg_tcs          = tycons ++ new_tycons
                       , mg_binds        = Rec tc_binds : (binds_top ++ binds_imp)
                       , mg_fam_inst_env = fam_inst_env
                       , mg_fam_insts    = mg_fam_insts guts ++ new_fam_insts
                       }
       }
  where
    tycons     = mg_tcs guts
    binds      = mg_binds guts
    vect_decls = mg_vect_decls guts

    -- The 'VECTORISE type' and 'VECTORISE class' pragmas of this module.
    tyDecls    = [decl | decl@(VectType _ _ _) <- vect_decls]
    classDecls = [decl | decl@(VectClass _)    <- vect_decls]

    -- Imported identifiers ('isGlobalId') that are the subject of a 'Vect' or 'VectInst'
    -- declaration in this module.
    importedIds = [v | Vect v _   <- vect_decls, isGlobalId v]
               ++ [v | VectInst v <- vect_decls, isGlobalId v]
103
-- | Try to vectorise a top-level binding.  If vectorisation fails, the binding is returned
-- unharmed.
--
-- For example, for the binding
--
-- @
--   foo :: Int -> Int
--   foo = \x -> x + x
-- @
--
-- we get
--
-- @
--   foo :: Int -> Int
--   foo = \x -> vfoo $: x
--
--   v_foo :: Closure void vfoo lfoo
--   v_foo = closure vfoo lfoo void
--
--   vfoo :: Void -> Int -> Int
--   vfoo = ...
--
--   lfoo :: PData Void -> PData Int -> PData Int
--   lfoo = ...
-- @
--
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
-- function foo, but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
-- @v_foo@ combines both of these into a `Closure` that also contains the
-- environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version
-- present in the closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma.  If this
-- pragma is used in a group of mutually recursive bindings, either all or no binding must have
-- the pragma.  If only some bindings are annotated, a fatal error is raised.
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
-- we may emit a warning and refrain from vectorising the entire group.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = skipIfNoVectDecl (vectoriseBind `orElseErrV` keepOriginal)
  where
    -- Vectorise the right-hand side and create the vectorised top-level binder (which enters it
    -- into the vectorisation map).  The original binding is then replaced by a value projected
    -- from the vectorised closure, and any newly hoisted top-level bindings join the group.
    vectoriseBind
      = do { (inline, isScalar, expr') <- vectTopRhs [] var expr
           ; var' <- vectTopBinder var inline expr'
           ; when isScalar $
               addGlobalScalarVar var
           ; cexpr <- tryConvert var var' expr
           ; hs    <- takeHoisted
           ; return . Rec $ (var, cexpr) : (var', expr') : hs
           }

    -- Fallback when vectorisation fails: report it and keep the binding as it was.
    keepOriginal
      = do { emitVt " Could NOT vectorise top-level binding" $ ppr var
           ; return b
           }

    -- A binding carrying a 'NOVECTORISE' pragma is left untouched.
    skipIfNoVectDecl act
      = do { hasNoVectDecl <- noVectDecl var
           ; if hasNoVectDecl
               then do { traceVt "NOVECTORISE" $ ppr var
                       ; return b
                       }
               else act
           }
-- A recursive group is vectorised as a whole.  If vectorisation fails, the diagnostic is
-- emitted (as for non-recursive bindings) and the group is returned unchanged.
vectTopBind b@(Rec bs)
  = unlessSomeNoVectDecl $
      do { (vars', _, exprs', hs) <- fixV $
             \ ~(_, inlines, rhss, _) ->
               do {   -- Vectorise the right-hand sides, create appropriate top-level bindings
                      -- and add them to the vectorisation map.
                      -- NB: the inline specs and right-hand sides passed to 'vectTopBinder'
                      --     come from the fixed point itself, hence the irrefutable patterns
                      --     and 'zipLazy'; they must not be forced here.
                  ; vars' <- sequence [ vectTopBinder var inline rhs
                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
                  ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
                  ; hs <- takeHoisted
                  ; if and areScalars
                    then -- (1) Entire recursive group is scalar
                         --     => add all variables to the global set of scalars
                         do { mapM_ addGlobalScalarVar vars
                            ; return (vars', inlines, exprs', hs)
                            }
                    else -- (2) At least one binding is not scalar
                         --     => vectorise again with empty set of local scalars
                         --        (the hoisted bindings of the first attempt are discarded)
                         do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
                            ; hs <- takeHoisted
                            ; return (vars', inlines, exprs', hs)
                            }
                  }

           -- Replace the original top-level bindings by values projected from the vectorised
           -- closures and add any newly created hoisted top-level bindings to the group.
         ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
         ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
         }
      `orElseErrV`
      do { emitVt " Could NOT vectorise recursive top-level bindings" $ ppr vars
         ; return b
         }
  where
    (vars, exprs) = unzip bs

    -- Run 'act' only if no binding of the group has a 'NOVECTORISE' pragma.  If all of them
    -- have one, the group is returned unchanged; if only some do, vectorisation is aborted.
    unlessSomeNoVectDecl act
      = do { hasNoVectDecls <- mapM noVectDecl vars
           ; when (and hasNoVectDecls) $
               traceVt "NOVECTORISE" $ ppr vars
           ; if and hasNoVectDecls
             then return b                     -- all bindings have 'NOVECTORISE'
             else if or hasNoVectDecls
             then do { dflags <- getDynFlags   -- some (but not all) have 'NOVECTORISE'
                     ; cantVectorise dflags noVectoriseErr (ppr b)
                     }
             else act                          -- no binding has a 'NOVECTORISE' decl
           }

    noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
218
-- | Add a vectorised binding for an imported top-level variable that has a VECTORISE [SCALAR]
-- pragma in this module.
--
-- The right-hand side is vectorised and a top-level binder for it is created and added to the
-- vectorisation map.  For the non-lifted version, we refer to the original definition, i.e.,
-- 'Var var'.  Any newly hoisted top-level bindings are returned in the same recursive group.
--
-- RESTRICTION: Currently, we cannot use the pragma for mutually recursive definitions.
--
vectImpBind :: Id -> VM CoreBind
vectImpBind var
  = do { (vVar, _, vRhs) <- fixV tieKnot
       ; hoisted <- takeHoisted
       ; return $ Rec ((vVar, vRhs) : hoisted)
       }
  where
    -- To support recursive definitions we tie a lazy knot: the inline spec and right-hand
    -- side handed to 'vectTopBinder' are those produced by 'vectTopRhs' below, so the
    -- argument must be matched irrefutably.
    tieKnot ~(_, inlineSpec, lazyRhs)
      = do { vVar <- vectTopBinder var inlineSpec lazyRhs
           ; (inlineSpec', isScalar, vRhs) <- vectTopRhs [] var (Var var)
           ; when isScalar $
               addGlobalScalarVar var
           ; return (vVar, inlineSpec', vRhs)
           }
244
-- | Create the vectorised version of a top-level binder and record the mapping between the
-- original and the vectorised binder in the vectorisation state.  For some binder @foo@ the
-- vectorised version is @$v_foo@.
--
-- If there is a vectorisation declaration for the binder, its type must agree with the
-- vectorised type computed here; otherwise vectorisation is aborted via 'cantVectorise'.
--
-- NOTE: 'vectTopBinder' *MUST* be lazy in 'inline' and 'expr', because it is used inside of
-- 'fixV' (in 'vectTopBind' and 'vectImpBind') with arguments taken from the fixed point.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
  = do { vty <- vectType (idType var)
       ; lookupVectDecl var >>= checkDeclaredType vty
       ; vid <- mkVectId var vty
         -- 'unfolding' must not be forced here (see NOTE above).
       ; let var' = vid `setIdUnfoldingLazily` unfolding
       ; defGlobalVar var var'
       ; return var'
       }
  where
    -- A type given by a vectorisation declaration must coincide with the vectorised type.
    checkDeclaredType _   Nothing = return ()
    checkDeclaredType vty (Just (vdty, _))
      | vty `eqType` vdty = return ()
      | otherwise
      = do { dflags <- getDynFlags
           ; cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $
               (text "Expected type" <+> ppr vty)
               $$
               (text "Inferred type" <+> ppr vdty)
           }

    -- The unfolding for the vectorised binder, derived from the inline specification.
    unfolding = case inline of
                  Inline arity -> mkInlineUnfolding (Just arity) expr
                  DontInline   -> noUnfolding
{-
!!!TODO: dfuns and unfoldings:
           -- Do not inline the dfun; instead give it a magic DFunFunfolding
           -- See Note [ClassOp/DFun selection]
           -- See also note [Single-method classes]
        dfun_id_w_fun
           | isNewTyCon class_tc
           = dfun_id `setInlinePragma` alwaysInlinePragma { inl_sat = Just 0 }
           | otherwise
           = dfun_id `setIdUnfolding`  mkDFunUnfolding dfun_ty dfun_args
                     `setInlinePragma` dfunInlinePragma
 -}
299
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--
-- We need to distinguish four cases:
--
-- (1) We have a (non-scalar) vectorisation declaration for the variable (which explicitly provides
--     vectorised code implemented by the user)
--     => no automatic vectorisation & instead use the user-supplied code
--
-- (2) We have a scalar vectorisation declaration for a variable that is not a dfun
--     => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
--
-- (3) We have a scalar vectorisation declaration for a variable that *is* a dfun
--     => generate vectorised code according to the "Note [Scalar dfuns]" below
--
-- (4) There is no vectorisation declaration for the variable
--     => perform automatic vectorisation of the RHS (the definition may or may not be a dfun;
--        vectorisation proceeds differently depending on which it is)
--
-- Note [Scalar dfuns]
-- ~~~~~~~~~~~~~~~~~~~
--
-- Here is the translation scheme for scalar dfuns — assume the instance declaration:
--
--   instance Num Int where
--     (+) = primAdd
--   {-# VECTORISE SCALAR instance Num Int #-}
--
-- It desugars to
--
--   $dNumInt :: Num Int
--   $dNumInt = D:Num primAdd
--
-- We vectorise it to
--
--   $v$dNumInt :: V:Num Int
--   $v$dNumInt = D:V:Num (closure2 ((+) $dNumInt) (scalar_zipWith ((+) $dNumInt))))
--
-- while adding the following entry to the vectorisation map: '$dNumInt' --> '$v$dNumInt'.
--
-- See "Note [Vectorising classes]" in 'Vectorise.Type.Env' for the definition of 'V:Num'.
--
-- NB: The outlined vectorisation scheme does not require the right-hand side of the original dfun.
--     In fact, we definitely want to refer to the dfun variable instead of the right-hand side to
--     ensure that the dictionary selection rules fire.
--
vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
           -> Var             -- ^ Name of the binding.
           -> CoreExpr        -- ^ Body of the binding.
           -> VM ( Inline     -- (1) inline specification for the binding
                 , Bool       -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr)  -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV $
      do { globalScalar <- isGlobalScalarVar var
         ; vectDecl     <- lookupVectDecl var
         ; dflags       <- getDynFlags
         ; let isDFun = isDFunId var

         ; traceVt ("vectTopRhs of " ++ showPpr dflags var ++ pragmaInfo globalScalar isDFun vectDecl ++ ":") $
             ppr expr

         ; case (vectDecl, globalScalar, isDFun) of
             -- Case (1): user-supplied vectorised code
             (Just (_, vexpr), _, _)
               -> return (inlineMe, False, vexpr)
             -- Case (2): scalar, not a dfun
             (Nothing, True, False)
               -> do { vexpr <- vectScalarFun expr
                     ; return (inlineMe, True, vectorised vexpr)
                     }
             -- Case (3): scalar dfun
             (Nothing, True, True)
               -> do { vexpr <- vectScalarDFun var
                     ; return (DontInline, True, vexpr)
                     }
             -- Case (4), not a dfun: automatic vectorisation
             (Nothing, False, False)
               -> do { (inline, isScalar, vexpr)
                         <- inBind var $
                              vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs (freeVars expr) Nothing
                     ; return (inline, isScalar, vectorised vexpr)
                     }
             -- Case (4), a dfun
             (Nothing, False, True)
               -> do { vexpr <- vectDictExpr expr
                     ; return (DontInline, True, vexpr)
                     }
         }
  where
    -- Describes the pragma situation of the binding for the trace output.
    pragmaInfo True  True  _ = " [VECTORISE SCALAR instance]"
    pragmaInfo True  False _ = " [VECTORISE SCALAR]"
    pragmaInfo False _     decl
      | isJust decl = " [VECTORISE]"
      | otherwise   = " (no pragma)"
390
-- | Project the vectorised version of a binding out of its closure.
--
-- The original right-hand side is returned unchanged instead if the binding is a global scalar,
-- or if the projection fails (in which case a diagnostic is emitted).
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = do { isScalar <- isGlobalScalarVar var
       ; if isScalar then return rhs else project
       }
  where
    project = fromVect (idType var) (Var vect_var) `orElseErrV` keepOriginal

    keepOriginal
      = do { emitVt " Could NOT call vectorised from original version" $ ppr var
           ; return rhs
           }