VECTORISE pragmas for type classes and instances
[ghc.git] / compiler / vectorise / Vectorise.hs
1 -- Main entry point to the vectoriser. It is invoked iff the option '-fvectorise' is passed.
2 --
3 -- This module provides the function 'vectorise', which vectorises an entire (desugared) module.
4 -- It vectorises all type declarations and value bindings. It also processes all VECTORISE pragmas
5 -- (aka vectorisation declarations), which can lead to the vectorisation of imported data types
6 -- and the enrichment of imported functions with vectorised versions.
7
8 module Vectorise ( vectorise )
9 where
10
11 import Vectorise.Type.Env
12 import Vectorise.Type.Type
13 import Vectorise.Convert
14 import Vectorise.Utils.Hoisting
15 import Vectorise.Exp
16 import Vectorise.Vect
17 import Vectorise.Env
18 import Vectorise.Monad
19
20 import HscTypes hiding ( MonadThings(..) )
21 import CoreUnfold ( mkInlineUnfolding )
22 import CoreFVs
23 import PprCore
24 import CoreSyn
25 import CoreMonad ( CoreM, getHscEnv )
26 import Type
27 import Id
28 import DynFlags
29 import BasicTypes ( isStrongLoopBreaker )
30 import Outputable
31 import Util ( zipLazy )
32 import MonadUtils
33
34 import Control.Monad
35 import Data.Maybe
36
37
-- |Vectorise a single module.
--
-- Fetches the 'HscEnv' from the 'CoreM' monad and hands it, together with the module, to
-- 'vectoriseIO', whose 'IO' action is lifted back into 'CoreM'.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts = getHscEnv >>= \hsc_env -> liftIO (vectoriseIO hsc_env guts)
45
-- Vectorise a single module in 'IO', given the 'HscEnv'.
--
-- The vectorisation info returned by the VM computation is stored in the 'mg_vect_info' field of
-- the resulting module.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
  = do
      -- Information about the currently loaded external packages.
      eps <- hscEPS hsc_env

      -- Vectorisation info of the current module combined with that of the external ones.
      let combinedInfo = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

      -- Run the main VM computation.
      -- NB: The 'Just' pattern is partial; if 'initV' yields 'Nothing', this bind fails.
      Just (vectInfo, vectGuts) <- initV hsc_env guts combinedInfo (vectModule guts)
      return (vectGuts { mg_vect_info = vectInfo })
60
-- Vectorise a single module, in the VM monad.
--
-- The result is the input module extended with the vectorised type constructors, family
-- instances and bindings; the family instance environment is replaced by the one read from the
-- global environment after the type environment has been vectorised.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_tcs        = tycons
                         , mg_clss       = classes
                         , mg_insts      = insts
                         , mg_binds      = binds
                         , mg_fam_insts  = fam_insts
                         , mg_vect_decls = vect_decls
                         })
  = do
      dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" (pprCoreBindings binds)

      -- Vectorise the type environment.  This will add vectorised type constructors, their
      -- representations, and the corresponding data constructors.  Moreover, we produce
      -- bindings for dfuns and family instances of the classes and type families used in the
      -- DPH library to represent array types.
      (vectTycons, vectFamInsts, tyConBinds) <- vectTypeEnv tycons typeDecls

      -- Family instance environment for /all/ home-package modules including those instances
      -- generated by 'vectTypeEnv'.
      (_, famInstEnv) <- readGEnv global_fam_inst_env

      -- Vectorise all the top level bindings and VECTORISE declarations on imported
      -- identifiers.
      topBinds <- mapM vectTopBind binds
      impBinds <- mapM vectImpBind importedVectIds

      return $ guts { mg_tcs          = tycons ++ vectTycons
                    , mg_clss         = classes ++ vectClasses
                    , mg_insts        = insts ++ vectInsts
                    , mg_binds        = Rec tyConBinds : (topBinds ++ impBinds)
                    , mg_fam_inst_env = famInstEnv
                    , mg_fam_insts    = fam_insts ++ vectFamInsts
                    }
  where
    -- Vectorisation declarations for types.
    typeDecls = [decl | decl@(VectType _ _ _) <- vect_decls]

    -- Global identifiers that carry a vectorisation declaration in this module.
    importedVectIds = [imp_id | Vect imp_id _ <- vect_decls, isGlobalId imp_id]

    -- !!!FIXME: no vectorised classes or instances are produced yet.
    -- !!!we need to compute an extended 'mg_inst_env' as well!!!
    vectClasses = []
    vectInsts   = []
103
104 -- Try to vectorise a top-level binding. If it doesn't vectorise then return it unharmed.
105 --
106 -- For example, for the binding
107 --
108 -- @
109 -- foo :: Int -> Int
110 -- foo = \x -> x + x
111 -- @
112 --
113 -- we get
114 -- @
115 -- foo :: Int -> Int
116 -- foo = \x -> vfoo $: x
117 --
118 -- v_foo :: Closure void vfoo lfoo
119 -- v_foo = closure vfoo lfoo void
120 --
121 -- vfoo :: Void -> Int -> Int
122 -- vfoo = ...
123 --
124 -- lfoo :: PData Void -> PData Int -> PData Int
125 -- lfoo = ...
126 -- @
127 --
128 -- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
129 -- function foo, but takes an explicit environment.
130 --
131 -- @lfoo@ is the "lifted" version that works on arrays.
132 --
133 -- @v_foo@ combines both of these into a `Closure` that also contains the
134 -- environment.
135 --
136 -- The original binding @foo@ is rewritten to call the vectorised version
137 -- present in the closure.
138 --
139 -- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma. If this
140 -- pragma is used in a group of mutually recursive bindings, either all or no binding must have
141 -- the pragma. If only some bindings are annotated, a fatal error is raised.
142 -- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
143 -- we may emit a warning and refrain from vectorising the entire group.
144 --
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = unlessNoVectDecl $
      do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it
           -- to the vectorisation map.
         ; (inline, isScalar, expr') <- vectTopRhs [] var expr
         ; var' <- vectTopBinder var inline expr'
         ; when isScalar $
             addGlobalScalar var

           -- We replace the original top-level binding by a value projected from the vectorised
           -- closure and add any newly created hoisted top-level bindings.
         ; cexpr <- tryConvert var var' expr
         ; hs <- takeHoisted
         ; return . Rec $ (var, cexpr) : (var', expr') : hs
         }
      `orElseErrV`
      do { -- Vectorisation failed: report it and keep the binding unchanged.
         ; emitVt " Could NOT vectorise top-level binding" $ ppr var
         ; return b
         }
  where
    -- Run the given vectorisation action unless the binder has a 'NOVECTORISE' declaration, in
    -- which case the original binding is returned.
    unlessNoVectDecl doVectorise
      = do { hasNoVectDecl <- noVectDecl var
           ; when hasNoVectDecl $
               traceVt "NOVECTORISE" $ ppr var
           ; if hasNoVectDecl then return b else doVectorise
           }
vectTopBind b@(Rec bs)
  = unlessSomeNoVectDecl $
      do { (vars', _, exprs', hs) <- fixV $
             -- The lazy pattern ties the knot: the inline specifications and right-hand sides
             -- produced below are fed back to 'vectTopBinder', which must not force them.
             \ ~(_, knotInlines, knotRhss, _) ->
               do { -- Vectorise the right-hand sides, create appropriate top-level bindings
                    -- and add them to the vectorisation map.
                  ; vars' <- sequence [ vectTopBinder var inline rhs
                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip knotInlines knotRhss) ]
                  ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
                  ; hs <- takeHoisted
                  ; if and areScalars
                      then -- (1) Entire recursive group is scalar
                           --     => add all variables to the global set of scalars
                           do { mapM_ addGlobalScalar vars
                              ; return (vars', inlines, exprs', hs)
                              }
                      else -- (2) At least one binding is not scalar
                           --     => vectorise again with empty set of local scalars
                           do { (inlines', _, exprs'') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
                              ; hs' <- takeHoisted
                              ; return (vars', inlines', exprs'', hs')
                              }
                  }

           -- Replace the original top-level bindings by values projected from the vectorised
           -- closures and add any newly created hoisted top-level bindings to the group.
         ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
         ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
         }
      `orElseErrV`
      do { -- Vectorisation failed: report it (as the non-recursive case does) and keep the
           -- group unchanged.
         ; emitVt " Could NOT vectorise top-level bindings" $ ppr vars
         ; return b
         }
  where
    (vars, exprs) = unzip bs

    -- Run the given vectorisation action only if no binder of the group has a 'NOVECTORISE'
    -- declaration.  If all binders have one, the group is returned unchanged; a mix is an error.
    unlessSomeNoVectDecl doVectorise
      = do { hasNoVectDecls <- mapM noVectDecl vars
           ; when (and hasNoVectDecls) $
               traceVt "NOVECTORISE" $ ppr vars
           ; if and hasNoVectDecls
               then return b                               -- all bindings have 'NOVECTORISE'
               else if or hasNoVectDecls
                 then cantVectorise noVectoriseErr (ppr b) -- some (but not all) have 'NOVECTORISE'
                 else doVectorise                          -- no binding has a 'NOVECTORISE' decl
           }
    noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
217
-- Add a vectorised binding to an imported top-level variable that has a VECTORISE [SCALAR] pragma
-- in this module.
--
-- The result is a recursive group containing the vectorised binding together with any bindings
-- that were hoisted while producing it.
--
vectImpBind :: Id -> VM CoreBind
vectImpBind var
  = do
      -- Vectorise the right-hand side, create an appropriate top-level binding and add it to
      -- the vectorisation map.  The non-lifted version is the original definition itself, so
      -- the "right-hand side" handed over is simply 'Var var'.
      (inline, isScalar, vectRhs) <- vectTopRhs [] var (Var var)
      vectVar <- vectTopBinder var inline vectRhs
      when isScalar (addGlobalScalar var)

      -- Collect any newly created hoisted top-level bindings.
      hoisted <- takeHoisted
      return (Rec ((vectVar, vectRhs) : hoisted))
235
-- | Make the vectorised version of this top level binder, and add the mapping
--   between it and the original to the state. For some binder @foo@ the vectorised
--   version is @$v_foo@
--
--   NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is
--         used inside of 'fixV' in 'vectTopBind'.  Both arguments are only used to build the
--         unfolding, which is attached with 'setIdUnfoldingLazily'.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
  = do
      -- Vectorise the type attached to the var.
      vectTy <- vectType (idType var)

      -- If there is a vectorisation declaration for this binding, make sure that its type
      -- matches.
      lookupVectDecl var >>= maybe (return ()) (checkDeclType vectTy . fst)

      -- Make the vectorised version of the binding's name, and set the unfolding used for
      -- inlining.
      vectVar <- liftM (`setIdUnfoldingLazily` inlineUnfolding) (mkVectId var vectTy)

      -- Add the mapping between the plain and vectorised name to the state.
      defGlobalVar var vectVar

      return vectVar
  where
    -- Abort vectorisation if the declared type differs from the vectorised type of the binder.
    checkDeclType vectTy declTy
      | eqType vectTy declTy = return ()
      | otherwise
      = cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
          (text "Expected type" <+> ppr vectTy)
          $$
          (text "Inferred type" <+> ppr declTy)

    -- Unfolding for the vectorised binder; kept as a lazy binding (see NOTE above).
    inlineUnfolding
      = case inline of
          Inline arity -> mkInlineUnfolding (Just arity) expr
          DontInline   -> noUnfolding
277
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--
-- Three cases are distinguished:
--
-- (1) The variable has a (non-scalar) vectorisation declaration, which explicitly provides
--     vectorised code implemented by the user
--     => no automatic vectorisation; the user-supplied code is used instead
--
-- (2) The variable has a scalar vectorisation declaration
--     => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
--
-- (3) The variable has no vectorisation declaration
--     => perform automatic vectorisation of the RHS
--
vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
           -> Var             -- ^ Name of the binding.
           -> CoreExpr        -- ^ Body of the binding.
           -> VM ( Inline     -- (1) inline specification for the binding
                 , Bool       -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr)  -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV $
      do
        isScalarDecl <- isGlobalScalar var
        mbDecl       <- lookupVectDecl var

        traceVt ("vectTopRhs of " ++ show var ++ pragmaNote isScalarDecl mbDecl) (ppr expr)

        case (mbDecl, isScalarDecl) of
          -- Case (1): use the code supplied by the vectorisation declaration.
          (Just (_, declRhs), _) ->
            return (inlineMe, False, declRhs)

          -- Case (2): lift the scalar function.
          (Nothing, True) ->
            do scalarRhs <- vectScalarFun True recFs expr
               return (inlineMe, True, vectorised scalarRhs)

          -- Case (3): vectorise the right-hand side automatically.
          (Nothing, False) ->
            do (inline, isScalar, vectRhs)
                 <- inBind var $
                      vectPolyExpr (isStrongLoopBreaker (idOccInfo var)) recFs (freeVars expr)
               return (inline, isScalar, vectorised vectRhs)
  where
    -- Suffix for the trace message, naming the kind of pragma that applies to the binding.
    pragmaNote isScalarDecl mbDecl
      | isScalarDecl  = " [VECTORISE SCALAR]"
      | isJust mbDecl = " [VECTORISE]"
      | otherwise     = " (no pragma)"
325
-- |Project out the vectorised version of a binding from some closure,
--  or return the original body if that doesn't work or the binding is scalar.
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = isGlobalScalar var >>= convert
  where
    -- Scalar bindings keep their original body.
    convert True  = return rhs
    -- Otherwise convert from the vectorised variable, falling back to the original body.
    convert False = fromVect (idType var) (Var vect_var) `orElseErrV` keepOriginal

    -- Report that the conversion failed and keep the original body.
    keepOriginal
      = emitVt " Could NOT call vectorised from original version" (ppr var)
        >> return rhs