Add VECTORISE [SCALAR] type pragma
[ghc.git] / compiler / vectorise / Vectorise.hs
1 -- Main entry point to the vectoriser. It is invoked iff the option '-fvectorise' is passed.
2 --
3 -- This module provides the function 'vectorise', which vectorises an entire (desugared) module.
4 -- It vectorises all type declarations and value bindings. It also processes all VECTORISE pragmas
5 -- (aka vectorisation declarations), which can lead to the vectorisation of imported data types
6 -- and the enrichment of imported functions with vectorised versions.
7
8 module Vectorise ( vectorise )
9 where
10
11 import Vectorise.Type.Env
12 import Vectorise.Type.Type
13 import Vectorise.Convert
14 import Vectorise.Utils.Hoisting
15 import Vectorise.Exp
16 import Vectorise.Vect
17 import Vectorise.Env
18 import Vectorise.Monad
19
20 import HscTypes hiding ( MonadThings(..) )
21 import CoreUnfold ( mkInlineUnfolding )
22 import CoreFVs
23 import PprCore
24 import CoreSyn
25 import CoreMonad ( CoreM, getHscEnv )
26 import Type
27 import Id
28 import OccName
29 import DynFlags
30 import BasicTypes ( isStrongLoopBreaker )
31 import Outputable
32 import Util ( zipLazy )
33 import MonadUtils
34
35 import Control.Monad
36
37
-- | Vectorise a single module.
--
-- Obtains the current 'HscEnv' from the 'CoreM' monad and passes it, together with the module,
-- to 'vectoriseIO', whose 'IO' action is lifted back into 'CoreM'.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
  = getHscEnv >>= \hsc_env -> liftIO (vectoriseIO hsc_env guts)
45
-- | Vectorise a single module in 'IO', given the 'HscEnv'.
--
-- The vectorisation info of the home package table and of the external package state are
-- combined and used as the initial info for the 'VM' computation. The info produced by that
-- computation is stored in the 'mg_vect_info' field of the resulting module.
--
-- NB: The result of 'initV' is matched against 'Just'; a 'Nothing' result makes that pattern
--     match fail inside the 'IO' monad.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
  = do { -- Information about the currently loaded external packages.
         externalState <- hscEPS hsc_env

         -- Vectorisation info of the current module's home packages plus the external ones.
       ; let combinedInfo = hptVectInfo hsc_env `plusVectInfo` eps_vect_info externalState

         -- Run the main VM computation.
       ; Just (resultInfo, resultGuts) <- initV hsc_env guts combinedInfo (vectModule guts)
       ; return resultGuts { mg_vect_info = resultInfo }
       }
60
-- | Vectorise a single module, in the 'VM' monad.
--
-- The type environment is vectorised first (using only the 'VectType' vectorisation
-- declarations), then the global family instance environment is read back, and finally every
-- top-level binding is vectorised with 'vectTopBind'.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_types      = tyEnv
                         , mg_binds      = topBinds
                         , mg_fam_insts  = famInsts
                         , mg_vect_decls = vectDecls
                         })
  = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" (pprCoreBindings topBinds)

         -- Vectorise the type environment. This will add vectorised type constructors, their
         -- representations, and the corresponding data constructors. Moreover, we produce
         -- bindings for dfuns and family instances of the classes and type families used in the
         -- DPH library to represent array types.
       ; (tyEnv', famInsts', tyConBinds) <- vectTypeEnv tyEnv typeDecls

       ; (_, famInstEnv) <- readGEnv global_fam_inst_env

         -- Vectorise all the top level bindings.
       ; topBinds' <- mapM vectTopBind topBinds

       ; return guts { mg_types        = tyEnv'
                     , mg_binds        = Rec tyConBinds : topBinds'
                     , mg_fam_inst_env = famInstEnv
                     , mg_fam_insts    = famInsts ++ famInsts'
                     }
       }
  where
    -- Only the type vectorisation declarations are relevant for the type environment.
    typeDecls = [decl | decl@(VectType _ _) <- vectDecls]
89
90 -- |Try to vectorise a top-level binding. If it doesn't vectorise then return it unharmed.
91 --
92 -- For example, for the binding
93 --
94 -- @
95 -- foo :: Int -> Int
96 -- foo = \x -> x + x
97 -- @
98 --
99 -- we get
100 -- @
101 -- foo :: Int -> Int
102 -- foo = \x -> vfoo $: x
103 --
104 -- v_foo :: Closure void vfoo lfoo
105 -- v_foo = closure vfoo lfoo void
106 --
107 -- vfoo :: Void -> Int -> Int
108 -- vfoo = ...
109 --
110 -- lfoo :: PData Void -> PData Int -> PData Int
111 -- lfoo = ...
112 -- @
113 --
114 -- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
115 -- function foo, but takes an explicit environment.
116 --
117 -- @lfoo@ is the "lifted" version that works on arrays.
118 --
119 -- @v_foo@ combines both of these into a `Closure` that also contains the
120 -- environment.
121 --
122 -- The original binding @foo@ is rewritten to call the vectorised version
123 -- present in the closure.
124 --
125 -- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma. If this
126 -- pragma is used in a group of mutually recursive bindings, either all or none of the bindings
127 -- must have the pragma. If only some bindings are annotated, a fatal error is raised.
128 -- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
129 -- we may emit a warning and refrain from vectorising the entire group.
130 --
131 vectTopBind :: CoreBind -> VM CoreBind
132 vectTopBind b@(NonRec var expr)
133 = unlessNoVectDecl $ -- skip vectorisation altogether if the binder has a NOVECTORISE declaration
134 do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it
135 -- to the vectorisation map. The empty list means that there are no other functions in a rec block.
136 ; (inline, isScalar, expr') <- vectTopRhs [] var expr
137 ; var' <- vectTopBinder var inline expr'
138 ; when isScalar $
139 addGlobalScalar var -- record a scalar binding in the global set of scalars
140
141 -- We replace the original top-level binding by a value projected from the vectorised
142 -- closure and add any newly created hoisted top-level bindings. The result is always a
143 ; cexpr <- tryConvert var var' expr
144 ; hs <- takeHoisted
145 ; return . Rec $ (var, cexpr) : (var', expr') : hs
146 }
147 `orElseV`
148 return b -- the alternative taken by orElseV: the original binding, unchanged
149 where
150 unlessNoVectDecl vectorise
151 = do { hasNoVectDecl <- noVectDecl var
152 ; when hasNoVectDecl $
153 traceVt "NOVECTORISE" $ ppr var
154 ; if hasNoVectDecl then return b else vectorise
155 }
156 vectTopBind b@(Rec bs)
157 = unlessSomeNoVectDecl $ -- all-or-nothing NOVECTORISE check for the whole group
158 do { (vars', _, exprs', hs) <- fixV $
159 \ ~(_, inlines, rhss, _) -> -- lazy pattern: these are results of this same computation, fed back by fixV
160 do { -- Vectorise the right-hand sides, create appropriate top-level bindings
161 -- and add them to the vectorisation map.
162 ; vars' <- sequence [vectTopBinder var inline rhs
163 | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
164 ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
165 ; hs <- takeHoisted
166 ; if and areScalars
167 then -- (1) Entire recursive group is scalar
168 -- => add all variables to the global set of scalars
169 do { mapM_ addGlobalScalar vars
170 ; return (vars', inlines, exprs', hs)
171 }
172 else -- (2) At least one binding is not scalar
173 -- => vectorise again with empty set of local scalars
174 do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
175 ; hs <- takeHoisted -- shadows the hoisted bindings taken after the first pass
176 ; return (vars', inlines, exprs', hs)
177 }
178 }
179
180 -- Replace the original top-level bindings by values projected from the vectorised
181 -- closures and add any newly created hoisted top-level bindings to the group.
182 ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
183 ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
184 }
185 `orElseV`
186 return b -- the alternative taken by orElseV: the original group, unchanged
187 where
188 (vars, exprs) = unzip bs -- binders and right-hand sides of the original group
189
190 unlessSomeNoVectDecl vectorise
191 = do { hasNoVectDecls <- mapM noVectDecl vars
192 ; when (and hasNoVectDecls) $
193 traceVt "NOVECTORISE" $ ppr vars
194 ; if and hasNoVectDecls
195 then return b -- all bindings have 'NOVECTORISE'
196 else if or hasNoVectDecls
197 then cantVectorise noVectoriseErr (ppr b) -- some (but not all) have 'NOVECTORISE'
198 else vectorise -- no binding has a 'NOVECTORISE' decl
199 }
200 noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
201
202 -- | Make the vectorised version of this top level binder, and add the mapping
203 -- between it and the original to the state. For some binder @foo@ the vectorised
204 -- version is @$v_foo@.
205 --
206 -- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is
207 -- used inside of 'fixV' in 'vectTopBind'.
208 --
209 vectTopBinder :: Var -- ^ Name of the binding.
210 -> Inline -- ^ Whether it should be inlined, used to annotate it.
211 -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
212 -> VM Var -- ^ Name of the vectorised binding.
213 vectTopBinder var inline expr
214 = do { -- Vectorise the type attached to the var.
215 ; vty <- vectType (idType var)
216
217 -- If there is a vectorisation declaration for this binding, make sure that its type
218 -- matches the vectorised type computed above; otherwise 'cantVectorise' is invoked.
219 ; vectDecl <- lookupVectDecl var
220 ; case vectDecl of
221 Nothing -> return ()
222 Just (vdty, _)
223 | eqType vty vdty -> return ()
224 | otherwise ->
225 cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
226 (text "Expected type" <+> ppr vty)
227 $$
228 (text "Inferred type" <+> ppr vdty)
229
230 -- Make the vectorised version of binding's name, and set the unfolding used for inlining
231 ; var' <- liftM (`setIdUnfoldingLazily` unfolding)
232 $ cloneId mkVectOcc var vty -- NOTE(review): presumably the lazy setter is what keeps inline and expr unforced - confirm
233
234 -- Add the mapping between the plain and vectorised name to the state.
235 ; defGlobalVar var var'
236
237 ; return var'
238 }
239 where
240 unfolding = case inline of -- the only place where inline and expr are inspected
241 Inline arity -> mkInlineUnfolding (Just arity) expr
242 DontInline -> noUnfolding -- no unfolding is attached when inlining is not requested
243
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--
-- Which code is produced depends on the vectorisation declaration of the variable and on
-- whether the variable is registered as a global scalar:
--
-- (1) There is a vectorisation declaration that supplies an expression for the variable
--     => no automatic vectorisation; the supplied expression is returned as is.
--
-- (2) There is no such declaration, but the variable is a global scalar
--     => the vectorised code is generated by 'vectScalarFun'.
--
-- (3) There is no such declaration and the variable is not a global scalar
--     => the RHS is vectorised automatically by 'vectPolyExpr'.
--
vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
           -> Var             -- ^ Name of the binding.
           -> CoreExpr        -- ^ Body of the binding.
           -> VM ( Inline     -- (1) inline specification for the binding
                 , Bool       -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr)  -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV $
    do { traceVt ("vectTopRhs of " ++ show var) $ ppr expr

       ; globalScalar <- isGlobalScalar var
       ; vectDecl     <- lookupVectDecl var

         -- The declaration is inspected first; the scalar flag only matters without one.
       ; case (vectDecl, globalScalar) of
           (Just (_, declExpr), _) ->                                   -- Case (1)
             return (inlineMe, False, declExpr)

           (Nothing, True) ->                                           -- Case (2)
             do { scalarExpr <- vectScalarFun True recFs expr
                ; return (inlineMe, True, vectorised scalarExpr)
                }

           (Nothing, False) ->                                          -- Case (3)
             do { (inline, isScalar, vexpr)
                    <- inBind var $
                         vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs (freeVars expr)
                ; return (inline, isScalar, vectorised vexpr)
                }
       }
285
-- | Project out the vectorised version of a binding from some closure,
--   or return the original body if that doesn't work or the binding is scalar.
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = isGlobalScalar var >>= project
  where
    -- Global scalars keep their original right-hand side.
    project True  = return rhs
    -- Otherwise convert from the vectorised variable, with the original body as the alternative.
    project False = fromVect (idType var) (Var vect_var) `orElseV` return rhs