daa2ed0725d0ee0fc45e4372c8b9d1812b36f420
[ghc.git] / compiler / vectorise / Vectorise.hs
1 -- Main entry point to the vectoriser. It is invoked iff the option '-fvectorise' is passed.
2 --
3 -- This module provides the function 'vectorise', which vectorises an entire (desugared) module.
4 -- It vectorises all type declarations and value bindings. It also processes all VECTORISE pragmas
5 -- (aka vectorisation declarations), which can lead to the vectorisation of imported data types
6 -- and the enrichment of imported functions with vectorised versions.
7
8 module Vectorise ( vectorise )
9 where
10
11 import Vectorise.Type.Env
12 import Vectorise.Type.Type
13 import Vectorise.Convert
14 import Vectorise.Utils.Hoisting
15 import Vectorise.Exp
16 import Vectorise.Vect
17 import Vectorise.Env
18 import Vectorise.Monad
19
20 import HscTypes hiding ( MonadThings(..) )
21 import CoreUnfold ( mkInlineUnfolding )
22 import CoreFVs
23 import PprCore
24 import CoreSyn
25 import CoreMonad ( CoreM, getHscEnv )
26 import Type
27 import Id
28 import DynFlags
29 import BasicTypes ( isStrongLoopBreaker )
30 import Outputable
31 import Util ( zipLazy )
32 import MonadUtils
33
34 import Control.Monad
35 import Data.Maybe
36
37
-- | Vectorise a single module.
--
-- This is the 'CoreM' entry point: it fetches the 'HscEnv' from the 'CoreM' monad and hands the
-- actual work over to 'vectoriseIO', lifting the resulting 'IO' computation back into 'CoreM'.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts = getHscEnv >>= \hsc_env -> liftIO (vectoriseIO hsc_env guts)
45
-- | Vectorise a single module, given the 'HscEnv'.
--
-- The vectorisation info from 'hptVectInfo' and from the currently loaded external packages is
-- combined and used as the starting point for running 'vectModule' in the VM monad.  The
-- vectorisation info produced by that run is stored in the 'mg_vect_info' field of the returned
-- 'ModGuts'.
--
-- If 'initV' returns 'Nothing', we abort with a program error that names the module.  The
-- previous partial pattern @Just (..) <- initV ...@ died with an anonymous pattern-match failure
-- instead.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
  = do {   -- Get information about currently loaded external packages.
       ; eps <- hscEPS hsc_env

           -- Combine the vectorisation info of the home package with that of the external
           -- packages.
       ; let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

           -- Run the main VM computation.
       ; result <- initV hsc_env guts info (vectModule guts)
       ; case result of
           Just (info', guts') -> return (guts' { mg_vect_info = info' })
           Nothing             -> pprPgmError "Vectorisation failed for module"
                                              (ppr (mg_module guts))
       }
60
-- | Vectorise a single module, in the VM monad.
--
-- The result extends the input 'ModGuts' with the vectorised type constructors, the newly
-- created family instances, the updated family instance environment and the vectorised bindings.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_tcs        = tycons
                         , mg_binds      = binds
                         , mg_fam_insts  = fam_insts
                         , mg_vect_decls = vect_decls
                         })
  = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $
           pprCoreBindings binds

           -- Vectorise the type environment.  This will add vectorised type constructors, their
           -- representations, and the corresponding data constructors.  Moreover, we produce
           -- bindings for dfuns and family instances of the classes and type families used in
           -- the DPH library to represent array types.
       ; (tycons', new_fam_insts, tc_binds) <- vectTypeEnv tycons [vd
                                                                  | vd@(VectType _ _ _) <- vect_decls]

       ; (_, fam_inst_env) <- readGEnv global_fam_inst_env

           -- Vectorise all the top level bindings and VECTORISE declarations on imported
           -- identifiers.
       ; binds_top <- mapM vectTopBind binds
       ; binds_imp <- mapM vectImpBind [imp_id | Vect imp_id _ <- vect_decls, isGlobalId imp_id]

           -- The bindings produced for the type environment form a single recursive group that
           -- precedes all other bindings.
       ; return $ guts { mg_tcs          = tycons'
                       , mg_binds        = Rec tc_binds : (binds_top ++ binds_imp)
                       , mg_fam_inst_env = fam_inst_env
                       , mg_fam_insts    = fam_insts ++ new_fam_insts
                       }
       }
93
-- | Try to vectorise a top-level binding.  If it doesn't vectorise then return it unharmed.
--
-- For example, for the binding
--
-- @
-- foo :: Int -> Int
-- foo = \x -> x + x
-- @
--
-- we get
--
-- @
-- foo :: Int -> Int
-- foo = \x -> vfoo $: x
--
-- v_foo :: Closure void vfoo lfoo
-- v_foo = closure vfoo lfoo void
--
-- vfoo :: Void -> Int -> Int
-- vfoo = ...
--
-- lfoo :: PData Void -> PData Int -> PData Int
-- lfoo = ...
-- @
--
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
-- function foo, but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
-- @v_foo@ combines both of these into a `Closure` that also contains the
-- environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version
-- present in the closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma.  If this
-- pragma is used in a group of mutually recursive bindings, either all or no binding must have
-- the pragma.  If only some bindings are annotated, a fatal error is raised (via 'cantVectorise').
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
-- we may emit a warning and refrain from vectorising the entire group.
--
-- If vectorisation fails, a trace message is emitted and the original binding (or recursive
-- group) is returned unchanged.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = unlessNoVectDecl $
      do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it
             -- to the vectorisation map.
         ; (inline, isScalar, expr') <- vectTopRhs [] var expr
         ; var' <- vectTopBinder var inline expr'
         ; when isScalar $
             addGlobalScalar var

             -- We replace the original top-level binding by a value projected from the vectorised
             -- closure and add any newly created hoisted top-level bindings.
         ; cexpr <- tryConvert var var' expr
         ; hs <- takeHoisted
         ; return . Rec $ (var, cexpr) : (var', expr') : hs
         }
     `orElseErrV`
       do { emitVt " Could NOT vectorise top-level binding" $ ppr var
          ; return b
          }
  where
    -- Leave the binding alone if it carries a 'NOVECTORISE' pragma.
    unlessNoVectDecl vectorise
      = do { hasNoVectDecl <- noVectDecl var
           ; when hasNoVectDecl $
               traceVt "NOVECTORISE" $ ppr var
           ; if hasNoVectDecl then return b else vectorise
           }
vectTopBind b@(Rec bs)
  = unlessSomeNoVectDecl $
      do { (vars', _, exprs', hs) <- fixV $
             \ ~(_, inlines, rhss, _) ->
               do {   -- Vectorise the right-hand sides, create an appropriate top-level bindings
                      -- and add them to the vectorisation map.
                  ; vars' <- sequence [vectTopBinder var inline rhs
                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
                  ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
                  ; hs <- takeHoisted
                  ; if and areScalars
                    then      -- (1) Entire recursive group is scalar
                              --     => add all variables to the global set of scalars
                      do { mapM_ addGlobalScalar vars
                         ; return (vars', inlines, exprs', hs)
                         }
                    else      -- (2) At least one binding is not scalar
                              --     => vectorise again with empty set of local scalars
                      do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
                         ; hs <- takeHoisted
                         ; return (vars', inlines, exprs', hs)
                         }
                  }

             -- Replace the original top-level bindings by values projected from the vectorised
             -- closures and add any newly created hoisted top-level bindings to the group.
         ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
         ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
         }
     `orElseErrV`
       -- Report the failure just like the non-recursive case does, then keep the group as is.
       do { emitVt " Could NOT vectorise recursive top-level bindings" $ ppr vars
          ; return b
          }
  where
    (vars, exprs) = unzip bs

    -- 'NOVECTORISE' must be used on all or none of the bindings of the group.
    unlessSomeNoVectDecl vectorise
      = do { hasNoVectDecls <- mapM noVectDecl vars
           ; when (and hasNoVectDecls) $
               traceVt "NOVECTORISE" $ ppr vars
           ; if and hasNoVectDecls
             then return b                                -- all bindings have 'NOVECTORISE'
             else if or hasNoVectDecls
             then cantVectorise noVectoriseErr (ppr b)    -- some (but not all) have 'NOVECTORISE'
             else vectorise                               -- no binding has a 'NOVECTORISE' decl
           }
    noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
207
-- | Add a vectorised binding to an imported top-level variable that has a VECTORISE [SCALAR]
-- pragma in this module.
--
-- The right-hand side passed to 'vectTopRhs' is just @Var var@, so the non-lifted version refers
-- to the original (imported) definition.  The vectorised binder is registered via
-- 'vectTopBinder'.  The variable is recorded as a global scalar if its right-hand side turned out
-- to be scalar.  The result is a single recursive group containing the new binding together with
-- any hoisted top-level bindings.
--
vectImpBind :: Id -> VM CoreBind
vectImpBind var = do
  (inlineSpec, scalar, vrhs) <- vectTopRhs [] var (Var var)
  vvar <- vectTopBinder var inlineSpec vrhs
  when scalar $ addGlobalScalar var
  hoisted <- takeHoisted
  return $ Rec ((vvar, vrhs) : hoisted)
225
-- | Make the vectorised version of this top level binder, and add the mapping
-- between it and the original to the state.  For some binder @foo@ the vectorised
-- version is @$v_foo@.
--
-- If the binder has a vectorisation declaration, the type recorded for it must be equal (by
-- 'eqType') to the vectorised type of the binder.  Otherwise we fail via 'cantVectorise',
-- reporting both types.
--
-- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is
-- used inside of 'fixV' in 'vectTopBind'.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
  = do {   -- Vectorise the type attached to the var.
       ; vty <- vectType (idType var)

           -- If there is a vectorisation declaration for this binding, make sure that its type
           -- matches.
       ; vectDecl <- lookupVectDecl var
       ; case vectDecl of
           Nothing             -> return ()
           Just (vdty, _)
             | eqType vty vdty -> return ()
             | otherwise       ->
               cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
                 (text "Expected type" <+> ppr vty)
                 $$
                 (text "Inferred type" <+> ppr vdty)

           -- Make the vectorised version of the binding's name, and set the unfolding used for
           -- inlining.
       ; var' <- liftM (`setIdUnfoldingLazily` unfolding)
                 $ mkVectId var vty

           -- Add the mapping between the plain and vectorised name to the state.
       ; defGlobalVar var var'

       ; return var'
       }
  where
    -- Neither 'inline' nor 'expr' may be forced in the body above (see the NOTE); the
    -- unfolding is only attached through 'setIdUnfoldingLazily'.
    unfolding = case inline of
                  Inline arity -> mkInlineUnfolding (Just arity) expr
                  DontInline   -> noUnfolding
267
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--
-- We need to distinguish three cases, which are tried in this order:
--
-- (1) We have a (non-scalar) vectorisation declaration for the variable (which explicitly provides
--     vectorised code implemented by the user)
--     => no automatic vectorisation & instead use the user-supplied code
--
-- (2) We have a scalar vectorisation declaration for the variable
--     => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
--
-- (3) There is no vectorisation declaration for the variable
--     => perform automatic vectorisation of the RHS
--
-- The trace message emitted on entry labels the binding using the same precedence.  A variable
-- with a vectorisation declaration is therefore reported as "[VECTORISE]" even if it is also
-- registered as a global scalar.
--
vectTopRhs :: [Var]       -- ^ Names of all functions in the rec block
           -> Var         -- ^ Name of the binding.
           -> CoreExpr    -- ^ Body of the binding.
           -> VM ( Inline     -- (1) inline specification for the binding
                 , Bool       -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr)  -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV
  $ do { globalScalar <- isGlobalScalar var
       ; vectDecl     <- lookupVectDecl var

       ; traceVt ("vectTopRhs of " ++ show var ++ info globalScalar vectDecl) $ ppr expr

       ; rhs globalScalar vectDecl
       }
  where
    rhs _globalScalar (Just (_, expr'))               -- Case (1)
      = return (inlineMe, False, expr')
    rhs True          Nothing                         -- Case (2)
      = do { expr' <- vectScalarFun True recFs expr
           ; return (inlineMe, True, vectorised expr')
           }
    rhs False         Nothing                         -- Case (3)
      = do { let fvs = freeVars expr
           ; (inline, isScalar, vexpr)
               <- inBind var $
                    vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs fvs
           ; return (inline, isScalar, vectorised vexpr)
           }

    -- Label for the trace message.  It must agree with the case selection in 'rhs': a
    -- vectorisation declaration takes precedence over the global-scalar flag.
    info globalScalar vectDecl
      | isJust vectDecl = " [VECTORISE]"
      | globalScalar    = " [VECTORISE SCALAR]"
      | otherwise       = " (no pragma)"
315
-- | Project out the vectorised version of a binding from some closure, or return the original
-- body if that doesn't work or the binding is scalar.
--
-- Global scalars keep their original right-hand side.  For every other binding we try
-- 'fromVect' on the vectorised binder.  If that fails, a trace message is emitted and the
-- original right-hand side is kept.
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = isGlobalScalar var >>= \isScalar ->
      if isScalar then keepOriginal else convert `orElseErrV` giveUp
  where
    keepOriginal = return rhs

    convert = fromVect (idType var) (Var vect_var)

    giveUp = do { emitVt " Could NOT call vectorised from original version" $ ppr var
                ; keepOriginal
                }