Remove getDOptsDs; use getDynFlags instead
[ghc.git] / compiler / deSugar / DsListComp.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5
6 Desugaring list comprehensions, monad comprehensions and array comprehensions
7
8 \begin{code}
9 {-# LANGUAGE NamedFieldPuns #-}
10
11 module DsListComp ( dsListComp, dsPArrComp, dsMonadComp ) where
12
13 #include "HsVersions.h"
14
15 import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds )
16
17 import HsSyn
18 import TcHsSyn
19 import CoreSyn
20 import MkCore
21
22 import TcEvidence
23 import DsMonad          -- the monadery used in the desugarer
24 import DsUtils
25
26 import DynFlags
27 import CoreUtils
28 import Id
29 import Type
30 import TysWiredIn
31 import Match
32 import PrelNames
33 import SrcLoc
34 import Outputable
35 import FastString
36 import TcType
37 \end{code}
38
39 List comprehensions may be desugared in one of two ways: ``ordinary''
40 (as you would expect if you read SLPJ's book) and ``with foldr/build
41 turned on'' (if you read Gill {\em et al.}'s paper on the subject).
42
43 There will be at least one ``qualifier'' in the input.
44
45 \begin{code}
46 dsListComp :: [LStmt Id]
47            -> Type              -- Type of entire list
48            -> DsM CoreExpr
49 dsListComp lquals res_ty = do
50     dflags <- getDynFlags
51     let quals = map unLoc lquals
52         elt_ty = case tcTyConAppArgs res_ty of
53                    [elt_ty] -> elt_ty
54                    _ -> pprPanic "dsListComp" (ppr res_ty $$ ppr lquals)
55
56     if not (dopt Opt_EnableRewriteRules dflags) || dopt Opt_IgnoreInterfacePragmas dflags
57        -- Either rules are switched off, or we are ignoring what there are;
58        -- Either way foldr/build won't happen, so use the more efficient
59        -- Wadler-style desugaring
60        || isParallelComp quals
61        -- Foldr-style desugaring can't handle parallel list comprehensions
62         then deListComp quals (mkNilExpr elt_ty)
63         else mkBuildExpr elt_ty (\(c, _) (n, _) -> dfListComp c n quals)
64              -- Foldr/build should be enabled, so desugar
65              -- into foldrs and builds
66
67   where
68     -- We must test for ParStmt anywhere, not just at the head, because an extension
69     -- to list comprehensions would be to add brackets to specify the associativity
70     -- of qualifier lists. This is really easy to do by adding extra ParStmts into the
71     -- mix of possibly a single element in length, so we do this to leave the possibility open
72     isParallelComp = any isParallelStmt
73
74     isParallelStmt (ParStmt _ _ _ _) = True
75     isParallelStmt _                 = False
76
77
78 -- This function lets you desugar a inner list comprehension and a list of the binders
79 -- of that comprehension that we need in the outer comprehension into such an expression
80 -- and the type of the elements that it outputs (tuples of binders)
81 dsInnerListComp :: ([LStmt Id], [Id]) -> DsM (CoreExpr, Type)
82 dsInnerListComp (stmts, bndrs)
83   = do { expr <- dsListComp (stmts ++ [noLoc $ mkLastStmt (mkBigLHsVarTup bndrs)])
84                             (mkListTy bndrs_tuple_type)
85        ; return (expr, bndrs_tuple_type) }
86   where
87     bndrs_tuple_type = mkBigCoreVarTupTy bndrs
88
89 -- This function factors out commonality between the desugaring strategies for GroupStmt.
90 -- Given such a statement it gives you back an expression representing how to compute the transformed
91 -- list and the tuple that you need to bind from that list in order to proceed with your desugaring
92 dsTransStmt :: Stmt Id -> DsM (CoreExpr, LPat Id)
93 dsTransStmt (TransStmt { trS_form = form, trS_stmts = stmts, trS_bndrs = binderMap
94                        , trS_by = by, trS_using = using }) = do
95     let (from_bndrs, to_bndrs) = unzip binderMap
96         from_bndrs_tys  = map idType from_bndrs
97         to_bndrs_tys    = map idType to_bndrs
98         to_bndrs_tup_ty = mkBigCoreTupTy to_bndrs_tys
99
100     -- Desugar an inner comprehension which outputs a list of tuples of the "from" binders
101     (expr, from_tup_ty) <- dsInnerListComp (stmts, from_bndrs)
102
103     -- Work out what arguments should be supplied to that expression: i.e. is an extraction
104     -- function required? If so, create that desugared function and add to arguments
105     usingExpr' <- dsLExpr using
106     usingArgs <- case by of
107                    Nothing   -> return [expr]
108                    Just by_e -> do { by_e' <- dsLExpr by_e
109                                    ; lam <- matchTuple from_bndrs by_e'
110                                    ; return [lam, expr] }
111
112     -- Create an unzip function for the appropriate arity and element types and find "map"
113     unzip_stuff <- mkUnzipBind form from_bndrs_tys
114     map_id <- dsLookupGlobalId mapName
115
116     -- Generate the expressions to build the grouped list
117     let -- First we apply the grouping function to the inner list
118         inner_list_expr = mkApps usingExpr' usingArgs
119         -- Then we map our "unzip" across it to turn the lists of tuples into tuples of lists
120         -- We make sure we instantiate the type variable "a" to be a list of "from" tuples and
121         -- the "b" to be a tuple of "to" lists!
122         -- Then finally we bind the unzip function around that expression
123         bound_unzipped_inner_list_expr
124           = case unzip_stuff of
125               Nothing -> inner_list_expr
126               Just (unzip_fn, unzip_rhs) -> Let (Rec [(unzip_fn, unzip_rhs)]) $
127                                             mkApps (Var map_id) $
128                                             [ Type (mkListTy from_tup_ty)
129                                             , Type to_bndrs_tup_ty
130                                             , Var unzip_fn
131                                             , inner_list_expr]
132
133     -- Build a pattern that ensures the consumer binds into the NEW binders,
134     -- which hold lists rather than single values
135     let pat = mkBigLHsVarPatTup to_bndrs
136     return (bound_unzipped_inner_list_expr, pat)
137
138 dsTransStmt _ = panic "dsTransStmt: Not given a TransStmt"
139 \end{code}
140
141 %************************************************************************
142 %*                                                                      *
143 \subsection[DsListComp-ordinary]{Ordinary desugaring of list comprehensions}
144 %*                                                                      *
145 %************************************************************************
146
147 Just as in Phil's chapter~7 in SLPJ, using the rules for
148 optimally-compiled list comprehensions.  This is what Kevin followed
149 as well, and I quite happily do the same.  The TQ translation scheme
150 transforms a list of qualifiers (either boolean expressions or
151 generators) into a single expression which implements the list
152 comprehension.  Because we are generating 2nd-order polymorphic
153 lambda-calculus, calls to NIL and CONS must be applied to a type
154 argument, as well as their usual value arguments.
155 \begin{verbatim}
156 TE << [ e | qs ] >>  =  TQ << [ e | qs ] ++ Nil (typeOf e) >>
157
158 (Rule C)
159 TQ << [ e | ] ++ L >> = Cons (typeOf e) TE <<e>> TE <<L>>
160
161 (Rule B)
162 TQ << [ e | b , qs ] ++ L >> =
163     if TE << b >> then TQ << [ e | qs ] ++ L >> else TE << L >>
164
165 (Rule A')
166 TQ << [ e | p <- L1, qs ]  ++  L2 >> =
167   letrec
168     h = \ u1 ->
169           case u1 of
170             []        ->  TE << L2 >>
171             (u2 : u3) ->
172                   (( \ TE << p >> -> ( TQ << [e | qs]  ++  (h u3) >> )) u2)
173                     [] (h u3)
174   in
175     h ( TE << L1 >> )
176
177 "h", "u1", "u2", and "u3" are new variables.
178 \end{verbatim}
179
180 @deListComp@ is the TQ translation scheme.  Roughly speaking, @dsExpr@
181 is the TE translation scheme.  Note that we carry around the @L@ list
182 already desugared.  @dsListComp@ does the top TE rule mentioned above.
183
184 To the above, we add an additional rule to deal with parallel list
185 comprehensions.  The translation goes roughly as follows:
186      [ e | p1 <- e11, let v1 = e12, p2 <- e13
187          | q1 <- e21, let v2 = e22, q2 <- e23]
188      =>
189      [ e | ((x1, .., xn), (y1, ..., ym)) <-
190                zip [(x1,..,xn) | p1 <- e11, let v1 = e12, p2 <- e13]
191                    [(y1,..,ym) | q1 <- e21, let v2 = e22, q2 <- e23]]
192 where (x1, .., xn) are the variables bound in p1, v1, p2
193       (y1, .., ym) are the variables bound in q1, v2, q2
194
195 In the translation below, the ParStmt branch translates each parallel branch
196 into a sub-comprehension, and desugars each independently.  The resulting lists
197 are fed to a zip function, we create a binding for all the variables bound in all
198 the comprehensions, and then we hand things off the the desugarer for bindings.
199 The zip function is generated here a) because it's small, and b) because then we
200 don't have to deal with arbitrary limits on the number of zip functions in the
201 prelude, nor which library the zip function came from.
202 The introduced tuples are Boxed, but only because I couldn't get it to work
203 with the Unboxed variety.
204
205 \begin{code}
206
207 deListComp :: [Stmt Id] -> CoreExpr -> DsM CoreExpr
208
209 deListComp [] _ = panic "deListComp"
210
211 deListComp (LastStmt body _ : quals) list
212   =     -- Figure 7.4, SLPJ, p 135, rule C above
213     ASSERT( null quals )
214     do { core_body <- dsLExpr body
215        ; return (mkConsExpr (exprType core_body) core_body list) }
216
217         -- Non-last: must be a guard
218 deListComp (ExprStmt guard _ _ _ : quals) list = do  -- rule B above
219     core_guard <- dsLExpr guard
220     core_rest <- deListComp quals list
221     return (mkIfThenElse core_guard core_rest list)
222
223 -- [e | let B, qs] = let B in [e | qs]
224 deListComp (LetStmt binds : quals) list = do
225     core_rest <- deListComp quals list
226     dsLocalBinds binds core_rest
227
228 deListComp (stmt@(TransStmt {}) : quals) list = do
229     (inner_list_expr, pat) <- dsTransStmt stmt
230     deBindComp pat inner_list_expr quals list
231
232 deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above
233     core_list1 <- dsLExpr list1
234     deBindComp pat core_list1 quals core_list2
235
236 deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) list
237   = do { exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs
238        ; let (exps, qual_tys) = unzip exps_and_qual_tys
239
240        ; (zip_fn, zip_rhs) <- mkZipBind qual_tys
241
242         -- Deal with [e | pat <- zip l1 .. ln] in example above
243        ; deBindComp pat (Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps))
244                     quals list }
245   where
246         bndrs_s = map snd stmtss_w_bndrs
247
248         -- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above
249         pat  = mkBigLHsPatTup pats
250         pats = map mkBigLHsVarPatTup bndrs_s
251
252 deListComp (RecStmt {} : _) _ = panic "deListComp RecStmt"
253 \end{code}
254
255
256 \begin{code}
257 deBindComp :: OutPat Id
258            -> CoreExpr
259            -> [Stmt Id]
260            -> CoreExpr
261            -> DsM (Expr Id)
262 deBindComp pat core_list1 quals core_list2 = do
263     let
264         u3_ty@u1_ty = exprType core_list1       -- two names, same thing
265
266         -- u1_ty is a [alpha] type, and u2_ty = alpha
267         u2_ty = hsLPatType pat
268
269         res_ty = exprType core_list2
270         h_ty   = u1_ty `mkFunTy` res_ty
271
272     [h, u1, u2, u3] <- newSysLocalsDs [h_ty, u1_ty, u2_ty, u3_ty]
273
274     -- the "fail" value ...
275     let
276         core_fail   = App (Var h) (Var u3)
277         letrec_body = App (Var h) core_list1
278
279     rest_expr <- deListComp quals core_fail
280     core_match <- matchSimply (Var u2) (StmtCtxt ListComp) pat rest_expr core_fail
281
282     let
283         rhs = Lam u1 $
284               Case (Var u1) u1 res_ty
285                    [(DataAlt nilDataCon,  [],       core_list2),
286                     (DataAlt consDataCon, [u2, u3], core_match)]
287                         -- Increasing order of tag
288
289     return (Let (Rec [(h, rhs)]) letrec_body)
290 \end{code}
291
292 %************************************************************************
293 %*                                                                      *
294 \subsection[DsListComp-foldr-build]{Foldr/Build desugaring of list comprehensions}
295 %*                                                                      *
296 %************************************************************************
297
298 @dfListComp@ are the rules used with foldr/build turned on:
299
300 \begin{verbatim}
301 TE[ e | ]            c n = c e n
302 TE[ e | b , q ]      c n = if b then TE[ e | q ] c n else n
303 TE[ e | p <- l , q ] c n = let
304                                 f = \ x b -> case x of
305                                                   p -> TE[ e | q ] c b
306                                                   _ -> b
307                            in
308                            foldr f n l
309 \end{verbatim}
310
311 \begin{code}
312 dfListComp :: Id -> Id -- 'c' and 'n'
313         -> [Stmt Id]   -- the rest of the qual's
314         -> DsM CoreExpr
315
316 dfListComp _ _ [] = panic "dfListComp"
317
318 dfListComp c_id n_id (LastStmt body _ : quals)
319   = ASSERT( null quals )
320     do { core_body <- dsLExpr body
321        ; return (mkApps (Var c_id) [core_body, Var n_id]) }
322
323         -- Non-last: must be a guard
324 dfListComp c_id n_id (ExprStmt guard _ _ _  : quals) = do
325     core_guard <- dsLExpr guard
326     core_rest <- dfListComp c_id n_id quals
327     return (mkIfThenElse core_guard core_rest (Var n_id))
328
329 dfListComp c_id n_id (LetStmt binds : quals) = do
330     -- new in 1.3, local bindings
331     core_rest <- dfListComp c_id n_id quals
332     dsLocalBinds binds core_rest
333
334 dfListComp c_id n_id (stmt@(TransStmt {}) : quals) = do
335     (inner_list_expr, pat) <- dsTransStmt stmt
336     -- Anyway, we bind the newly grouped list via the generic binding function
337     dfBindComp c_id n_id (pat, inner_list_expr) quals
338
339 dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) = do
340     -- evaluate the two lists
341     core_list1 <- dsLExpr list1
342
343     -- Do the rest of the work in the generic binding builder
344     dfBindComp c_id n_id (pat, core_list1) quals
345
346 dfListComp _ _ (ParStmt {} : _) = panic "dfListComp ParStmt"
347 dfListComp _ _ (RecStmt {} : _) = panic "dfListComp RecStmt"
348
349 dfBindComp :: Id -> Id          -- 'c' and 'n'
350        -> (LPat Id, CoreExpr)
351            -> [Stmt Id]                 -- the rest of the qual's
352            -> DsM CoreExpr
353 dfBindComp c_id n_id (pat, core_list1) quals = do
354     -- find the required type
355     let x_ty   = hsLPatType pat
356         b_ty   = idType n_id
357
358     -- create some new local id's
359     [b, x] <- newSysLocalsDs [b_ty, x_ty]
360
361     -- build rest of the comprehesion
362     core_rest <- dfListComp c_id b quals
363
364     -- build the pattern match
365     core_expr <- matchSimply (Var x) (StmtCtxt ListComp)
366                 pat core_rest (Var b)
367
368     -- now build the outermost foldr, and return
369     mkFoldrExpr x_ty b_ty (mkLams [x, b] core_expr) (Var n_id) core_list1
370 \end{code}
371
372 %************************************************************************
373 %*                                                                      *
374 \subsection[DsFunGeneration]{Generation of zip/unzip functions for use in desugaring}
375 %*                                                                      *
376 %************************************************************************
377
378 \begin{code}
379
380 mkZipBind :: [Type] -> DsM (Id, CoreExpr)
381 -- mkZipBind [t1, t2]
382 -- = (zip, \as1:[t1] as2:[t2]
383 --         -> case as1 of
384 --              [] -> []
385 --              (a1:as'1) -> case as2 of
386 --                              [] -> []
387 --                              (a2:as'2) -> (a1, a2) : zip as'1 as'2)]
388
389 mkZipBind elt_tys = do
390     ass  <- mapM newSysLocalDs  elt_list_tys
391     as'  <- mapM newSysLocalDs  elt_tys
392     as's <- mapM newSysLocalDs  elt_list_tys
393
394     zip_fn <- newSysLocalDs zip_fn_ty
395
396     let inner_rhs = mkConsExpr elt_tuple_ty
397                         (mkBigCoreVarTup as')
398                         (mkVarApps (Var zip_fn) as's)
399         zip_body  = foldr mk_case inner_rhs (zip3 ass as' as's)
400
401     return (zip_fn, mkLams ass zip_body)
402   where
403     elt_list_tys      = map mkListTy elt_tys
404     elt_tuple_ty      = mkBigCoreTupTy elt_tys
405     elt_tuple_list_ty = mkListTy elt_tuple_ty
406
407     zip_fn_ty         = mkFunTys elt_list_tys elt_tuple_list_ty
408
409     mk_case (as, a', as') rest
410           = Case (Var as) as elt_tuple_list_ty
411                   [(DataAlt nilDataCon,  [],        mkNilExpr elt_tuple_ty),
412                    (DataAlt consDataCon, [a', as'], rest)]
413                         -- Increasing order of tag
414
415
416 mkUnzipBind :: TransForm -> [Type] -> DsM (Maybe (Id, CoreExpr))
417 -- mkUnzipBind [t1, t2]
418 -- = (unzip, \ys :: [(t1, t2)] -> foldr (\ax :: (t1, t2) axs :: ([t1], [t2])
419 --     -> case ax of
420 --      (x1, x2) -> case axs of
421 --                (xs1, xs2) -> (x1 : xs1, x2 : xs2))
422 --      ([], [])
423 --      ys)
424 --
425 -- We use foldr here in all cases, even if rules are turned off, because we may as well!
426 mkUnzipBind ThenForm _
427  = return Nothing    -- No unzipping for ThenForm
428 mkUnzipBind _ elt_tys
429   = do { ax  <- newSysLocalDs elt_tuple_ty
430        ; axs <- newSysLocalDs elt_list_tuple_ty
431        ; ys  <- newSysLocalDs elt_tuple_list_ty
432        ; xs  <- mapM newSysLocalDs elt_tys
433        ; xss <- mapM newSysLocalDs elt_list_tys
434
435        ; unzip_fn <- newSysLocalDs unzip_fn_ty
436
437        ; [us1, us2] <- sequence [newUniqueSupply, newUniqueSupply]
438
439        ; let nil_tuple = mkBigCoreTup (map mkNilExpr elt_tys)
440              concat_expressions = map mkConcatExpression (zip3 elt_tys (map Var xs) (map Var xss))
441              tupled_concat_expression = mkBigCoreTup concat_expressions
442
443              folder_body_inner_case = mkTupleCase us1 xss tupled_concat_expression axs (Var axs)
444              folder_body_outer_case = mkTupleCase us2 xs folder_body_inner_case ax (Var ax)
445              folder_body = mkLams [ax, axs] folder_body_outer_case
446
447        ; unzip_body <- mkFoldrExpr elt_tuple_ty elt_list_tuple_ty folder_body nil_tuple (Var ys)
448        ; return (Just (unzip_fn, mkLams [ys] unzip_body)) }
449   where
450     elt_tuple_ty       = mkBigCoreTupTy elt_tys
451     elt_tuple_list_ty  = mkListTy elt_tuple_ty
452     elt_list_tys       = map mkListTy elt_tys
453     elt_list_tuple_ty  = mkBigCoreTupTy elt_list_tys
454
455     unzip_fn_ty        = elt_tuple_list_ty `mkFunTy` elt_list_tuple_ty
456
457     mkConcatExpression (list_element_ty, head, tail) = mkConsExpr list_element_ty head tail
458 \end{code}
459
460 %************************************************************************
461 %*                                                                      *
462 \subsection[DsPArrComp]{Desugaring of array comprehensions}
463 %*                                                                      *
464 %************************************************************************
465
466 \begin{code}
467
468 -- entry point for desugaring a parallel array comprehension
469 --
470 --   [:e | qss:] = <<[:e | qss:]>> () [:():]
471 --
472 dsPArrComp :: [Stmt Id]
473             -> DsM CoreExpr
474
475 -- Special case for parallel comprehension
476 dsPArrComp (ParStmt qss _ _ _ : quals) = dePArrParComp qss quals
477
478 -- Special case for simple generators:
479 --
480 --  <<[:e' | p <- e, qs:]>> = <<[: e' | qs :]>> p e
481 --
482 -- if matching again p cannot fail, or else
483 --
484 --  <<[:e' | p <- e, qs:]>> =
485 --    <<[:e' | qs:]>> p (filterP (\x -> case x of {p -> True; _ -> False}) e)
486 --
487 dsPArrComp (BindStmt p e _ _ : qs) = do
488     filterP <- dsDPHBuiltin filterPVar
489     ce <- dsLExpr e
490     let ety'ce  = parrElemType ce
491         false   = Var falseDataConId
492         true    = Var trueDataConId
493     v <- newSysLocalDs ety'ce
494     pred <- matchSimply (Var v) (StmtCtxt PArrComp) p true false
495     let gen | isIrrefutableHsPat p = ce
496             | otherwise            = mkApps (Var filterP) [Type ety'ce, mkLams [v] pred, ce]
497     dePArrComp qs p gen
498
499 dsPArrComp qs = do -- no ParStmt in `qs'
500     sglP <- dsDPHBuiltin singletonPVar
501     let unitArray = mkApps (Var sglP) [Type unitTy, mkCoreTup []]
502     dePArrComp qs (noLoc $ WildPat unitTy) unitArray
503
504
505
506 -- the work horse
507 --
508 dePArrComp :: [Stmt Id]
509            -> LPat Id           -- the current generator pattern
510            -> CoreExpr          -- the current generator expression
511            -> DsM CoreExpr
512
513 dePArrComp [] _ _ = panic "dePArrComp"
514
515 --
516 --  <<[:e' | :]>> pa ea = mapP (\pa -> e') ea
517 --
518 dePArrComp (LastStmt e' _ : quals) pa cea
519   = ASSERT( null quals )
520     do { mapP <- dsDPHBuiltin mapPVar
521        ; let ty = parrElemType cea
522        ; (clam, ty'e') <- deLambda ty pa e'
523        ; return $ mkApps (Var mapP) [Type ty, Type ty'e', clam, cea] }
524 --
525 --  <<[:e' | b, qs:]>> pa ea = <<[:e' | qs:]>> pa (filterP (\pa -> b) ea)
526 --
527 dePArrComp (ExprStmt b _ _ _ : qs) pa cea = do
528     filterP <- dsDPHBuiltin filterPVar
529     let ty = parrElemType cea
530     (clam,_) <- deLambda ty pa b
531     dePArrComp qs pa (mkApps (Var filterP) [Type ty, clam, cea])
532
533 --
534 --  <<[:e' | p <- e, qs:]>> pa ea =
535 --    let ef = \pa -> e
536 --    in
537 --    <<[:e' | qs:]>> (pa, p) (crossMap ea ef)
538 --
539 -- if matching again p cannot fail, or else
540 --
541 --  <<[:e' | p <- e, qs:]>> pa ea =
542 --    let ef = \pa -> filterP (\x -> case x of {p -> True; _ -> False}) e
543 --    in
544 --    <<[:e' | qs:]>> (pa, p) (crossMapP ea ef)
545 --
546 dePArrComp (BindStmt p e _ _ : qs) pa cea = do
547     filterP <- dsDPHBuiltin filterPVar
548     crossMapP <- dsDPHBuiltin crossMapPVar
549     ce <- dsLExpr e
550     let ety'cea = parrElemType cea
551         ety'ce  = parrElemType ce
552         false   = Var falseDataConId
553         true    = Var trueDataConId
554     v <- newSysLocalDs ety'ce
555     pred <- matchSimply (Var v) (StmtCtxt PArrComp) p true false
556     let cef | isIrrefutableHsPat p = ce
557             | otherwise            = mkApps (Var filterP) [Type ety'ce, mkLams [v] pred, ce]
558     (clam, _) <- mkLambda ety'cea pa cef
559     let ety'cef = ety'ce                    -- filter doesn't change the element type
560         pa'     = mkLHsPatTup [pa, p]
561
562     dePArrComp qs pa' (mkApps (Var crossMapP)
563                                  [Type ety'cea, Type ety'cef, cea, clam])
564 --
565 --  <<[:e' | let ds, qs:]>> pa ea =
566 --    <<[:e' | qs:]>> (pa, (x_1, ..., x_n))
567 --                    (mapP (\v@pa -> let ds in (v, (x_1, ..., x_n))) ea)
568 --  where
569 --    {x_1, ..., x_n} = DV (ds)         -- Defined Variables
570 --
571 dePArrComp (LetStmt ds : qs) pa cea = do
572     mapP <- dsDPHBuiltin mapPVar
573     let xs     = collectLocalBinders ds
574         ty'cea = parrElemType cea
575     v <- newSysLocalDs ty'cea
576     clet <- dsLocalBinds ds (mkCoreTup (map Var xs))
577     let'v <- newSysLocalDs (exprType clet)
578     let projBody = mkCoreLet (NonRec let'v clet) $
579                    mkCoreTup [Var v, Var let'v]
580         errTy    = exprType projBody
581         errMsg   = ptext (sLit "DsListComp.dePArrComp: internal error!")
582     cerr <- mkErrorAppDs pAT_ERROR_ID errTy errMsg
583     ccase <- matchSimply (Var v) (StmtCtxt PArrComp) pa projBody cerr
584     let pa'    = mkLHsPatTup [pa, mkLHsPatTup (map nlVarPat xs)]
585         proj   = mkLams [v] ccase
586     dePArrComp qs pa' (mkApps (Var mapP)
587                                    [Type ty'cea, Type errTy, proj, cea])
588 --
589 -- The parser guarantees that parallel comprehensions can only appear as
590 -- singeltons qualifier lists, which we already special case in the caller.
591 -- So, encountering one here is a bug.
592 --
593 dePArrComp (ParStmt _ _ _ _ : _) _ _ =
594   panic "DsListComp.dePArrComp: malformed comprehension AST: ParStmt"
595 dePArrComp (TransStmt {} : _) _ _ = panic "DsListComp.dePArrComp: TransStmt"
596 dePArrComp (RecStmt   {} : _) _ _ = panic "DsListComp.dePArrComp: RecStmt"
597
598 --  <<[:e' | qs | qss:]>> pa ea =
599 --    <<[:e' | qss:]>> (pa, (x_1, ..., x_n))
600 --                     (zipP ea <<[:(x_1, ..., x_n) | qs:]>>)
601 --    where
602 --      {x_1, ..., x_n} = DV (qs)
603 --
604 dePArrParComp :: [([LStmt Id], [Id])] -> [Stmt Id] -> DsM CoreExpr
605 dePArrParComp qss quals = do
606     (pQss, ceQss) <- deParStmt qss
607     dePArrComp quals pQss ceQss
608   where
609     deParStmt []             =
610       -- empty parallel statement lists have no source representation
611       panic "DsListComp.dePArrComp: Empty parallel list comprehension"
612     deParStmt ((qs, xs):qss) = do        -- first statement
613       let res_expr = mkLHsVarTuple xs
614       cqs <- dsPArrComp (map unLoc qs ++ [mkLastStmt res_expr])
615       parStmts qss (mkLHsVarPatTup xs) cqs
616     ---
617     parStmts []             pa cea = return (pa, cea)
618     parStmts ((qs, xs):qss) pa cea = do  -- subsequent statements (zip'ed)
619       zipP <- dsDPHBuiltin zipPVar
620       let pa'      = mkLHsPatTup [pa, mkLHsVarPatTup xs]
621           ty'cea   = parrElemType cea
622           res_expr = mkLHsVarTuple xs
623       cqs <- dsPArrComp (map unLoc qs ++ [mkLastStmt res_expr])
624       let ty'cqs = parrElemType cqs
625           cea'   = mkApps (Var zipP) [Type ty'cea, Type ty'cqs, cea, cqs]
626       parStmts qss pa' cea'
627
628 -- generate Core corresponding to `\p -> e'
629 --
630 deLambda :: Type                        -- type of the argument
631           -> LPat Id                    -- argument pattern
632           -> LHsExpr Id                 -- body
633           -> DsM (CoreExpr, Type)
634 deLambda ty p e =
635     mkLambda ty p =<< dsLExpr e
636
637 -- generate Core for a lambda pattern match, where the body is already in Core
638 --
639 mkLambda :: Type                        -- type of the argument
640          -> LPat Id                     -- argument pattern
641          -> CoreExpr                    -- desugared body
642          -> DsM (CoreExpr, Type)
643 mkLambda ty p ce = do
644     v <- newSysLocalDs ty
645     let errMsg = ptext (sLit "DsListComp.deLambda: internal error!")
646         ce'ty  = exprType ce
647     cerr <- mkErrorAppDs pAT_ERROR_ID ce'ty errMsg
648     res <- matchSimply (Var v) (StmtCtxt PArrComp) p ce cerr
649     return (mkLams [v] res, ce'ty)
650
651 -- obtain the element type of the parallel array produced by the given Core
652 -- expression
653 --
654 parrElemType   :: CoreExpr -> Type
655 parrElemType e  =
656   case splitTyConApp_maybe (exprType e) of
657     Just (tycon, [ty]) | tycon == parrTyCon -> ty
658     _                                                     -> panic
659       "DsListComp.parrElemType: not a parallel array type"
660 \end{code}
661
662 Translation for monad comprehensions
663
664 \begin{code}
665 -- Entry point for monad comprehension desugaring
666 dsMonadComp :: [LStmt Id] -> DsM CoreExpr
667 dsMonadComp stmts = dsMcStmts stmts
668
669 dsMcStmts :: [LStmt Id] -> DsM CoreExpr
670 dsMcStmts []                    = panic "dsMcStmts"
671 dsMcStmts (L loc stmt : lstmts) = putSrcSpanDs loc (dsMcStmt stmt lstmts)
672
673 ---------------
674 dsMcStmt :: Stmt Id -> [LStmt Id] -> DsM CoreExpr
675
676 dsMcStmt (LastStmt body ret_op) stmts
677   = ASSERT( null stmts )
678     do { body' <- dsLExpr body
679        ; ret_op' <- dsExpr ret_op
680        ; return (App ret_op' body') }
681
682 --   [ .. | let binds, stmts ]
683 dsMcStmt (LetStmt binds) stmts
684   = do { rest <- dsMcStmts stmts
685        ; dsLocalBinds binds rest }
686
687 --   [ .. | a <- m, stmts ]
688 dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts
689   = do { rhs' <- dsLExpr rhs
690        ; dsMcBindStmt pat rhs' bind_op fail_op stmts }
691
692 -- Apply `guard` to the `exp` expression
693 --
694 --   [ .. | exp, stmts ]
695 --
696 dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts
697   = do { exp'       <- dsLExpr exp
698        ; guard_exp' <- dsExpr guard_exp
699        ; then_exp'  <- dsExpr then_exp
700        ; rest       <- dsMcStmts stmts
701        ; return $ mkApps then_exp' [ mkApps guard_exp' [exp']
702                                    , rest ] }
703
704 -- Group statements desugar like this:
705 --
706 --   [| (q, then group by e using f); rest |]
707 --   --->  f {qt} (\qv -> e) [| q; return qv |] >>= \ n_tup ->
708 --         case unzip n_tup of qv' -> [| rest |]
709 --
710 -- where   variables (v1:t1, ..., vk:tk) are bound by q
711 --         qv = (v1, ..., vk)
712 --         qt = (t1, ..., tk)
713 --         (>>=) :: m2 a -> (a -> m3 b) -> m3 b
714 --         f :: forall a. (a -> t) -> m1 a -> m2 (n a)
715 --         n_tup :: n qt
716 --         unzip :: n qt -> (n t1, ..., n tk)    (needs Functor n)
717
718 dsMcStmt (TransStmt { trS_stmts = stmts, trS_bndrs = bndrs
719                     , trS_by = by, trS_using = using
720                     , trS_ret = return_op, trS_bind = bind_op
721                     , trS_fmap = fmap_op, trS_form = form }) stmts_rest
722   = do { let (from_bndrs, to_bndrs) = unzip bndrs
723              from_bndr_tys          = map idType from_bndrs     -- Types ty
724
725        -- Desugar an inner comprehension which outputs a list of tuples of the "from" binders
726        ; expr <- dsInnerMonadComp stmts from_bndrs return_op
727
728        -- Work out what arguments should be supplied to that expression: i.e. is an extraction
729        -- function required? If so, create that desugared function and add to arguments
730        ; usingExpr' <- dsLExpr using
731        ; usingArgs <- case by of
732                         Nothing   -> return [expr]
733                         Just by_e -> do { by_e' <- dsLExpr by_e
734                                         ; lam <- matchTuple from_bndrs by_e'
735                                         ; return [lam, expr] }
736
737        -- Generate the expressions to build the grouped list
738        -- Build a pattern that ensures the consumer binds into the NEW binders,
739        -- which hold monads rather than single values
740        ; bind_op' <- dsExpr bind_op
741        ; let bind_ty  = exprType bind_op'    -- m2 (n (a,b,c)) -> (n (a,b,c) -> r1) -> r2
742              n_tup_ty = funArgTy $ funArgTy $ funResultTy bind_ty   -- n (a,b,c)
743              tup_n_ty = mkBigCoreVarTupTy to_bndrs
744
745        ; body       <- dsMcStmts stmts_rest
746        ; n_tup_var  <- newSysLocalDs n_tup_ty
747        ; tup_n_var  <- newSysLocalDs tup_n_ty
748        ; tup_n_expr <- mkMcUnzipM form fmap_op n_tup_var from_bndr_tys
749        ; us         <- newUniqueSupply
750        ; let rhs'  = mkApps usingExpr' usingArgs
751              body' = mkTupleCase us to_bndrs body tup_n_var tup_n_expr
752
753        ; return (mkApps bind_op' [rhs', Lam n_tup_var body']) }
754
755 -- Parallel statements. Use `Control.Monad.Zip.mzip` to zip parallel
756 -- statements, for example:
757 --
758 --   [ body | qs1 | qs2 | qs3 ]
759 --     ->  [ body | (bndrs1, (bndrs2, bndrs3))
760 --                     <- [bndrs1 | qs1] `mzip` ([bndrs2 | qs2] `mzip` [bndrs3 | qs3]) ]
761 --
762 -- where `mzip` has type
763 --   mzip :: forall a b. m a -> m b -> m (a,b)
764 -- NB: we need a polymorphic mzip because we call it several times
765
766 dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest
767  = do  { exps_w_tys  <- mapM ds_inner pairs   -- Pairs (exp :: m ty, ty)
768        ; mzip_op'    <- dsExpr mzip_op
769
770        ; let -- The pattern variables
771              pats = map (mkBigLHsVarPatTup . snd) pairs
772              -- Pattern with tuples of variables
773              -- [v1,v2,v3]  =>  (v1, (v2, v3))
774              pat = foldr1 (\p1 p2 -> mkLHsPatTup [p1, p2]) pats
775              (rhs, _) = foldr1 (\(e1,t1) (e2,t2) ->
776                                  (mkApps mzip_op' [Type t1, Type t2, e1, e2],
777                                   mkBoxedTupleTy [t1,t2]))
778                                exps_w_tys
779
780        ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest }
781   where
782     ds_inner (stmts, bndrs) = do { exp <- dsInnerMonadComp stmts bndrs mono_ret_op
783                                  ; return (exp, tup_ty) }
784        where
785          mono_ret_op = HsWrap (WpTyApp tup_ty) return_op
786          tup_ty      = mkBigCoreVarTupTy bndrs
787
788 dsMcStmt stmt _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt)
789
790
791 matchTuple :: [Id] -> CoreExpr -> DsM CoreExpr
792 -- (matchTuple [a,b,c] body)
793 --       returns the Core term
794 --  \x. case x of (a,b,c) -> body
795 matchTuple ids body
796   = do { us <- newUniqueSupply
797        ; tup_id <- newSysLocalDs (mkBigCoreVarTupTy ids)
798        ; return (Lam tup_id $ mkTupleCase us ids body tup_id (Var tup_id)) }
799
800 -- general `rhs' >>= \pat -> stmts` desugaring where `rhs'` is already a
801 -- desugared `CoreExpr`
802 dsMcBindStmt :: LPat Id
803              -> CoreExpr        -- ^ the desugared rhs of the bind statement
804              -> SyntaxExpr Id
805              -> SyntaxExpr Id
806              -> [LStmt Id]
807              -> DsM CoreExpr
808 dsMcBindStmt pat rhs' bind_op fail_op stmts
809   = do  { body     <- dsMcStmts stmts
810         ; bind_op' <- dsExpr bind_op
811         ; var      <- selectSimpleMatchVarL pat
812         ; let bind_ty = exprType bind_op'       -- rhs -> (pat -> res1) -> res2
813               res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
814         ; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
815                                   res1_ty (cantFailMatchResult body)
816         ; match_code <- handle_failure pat match fail_op
817         ; return (mkApps bind_op' [rhs', Lam var match_code]) }
818
819   where
820     -- In a monad comprehension expression, pattern-match failure just calls
821     -- the monadic `fail` rather than throwing an exception
822     handle_failure pat match fail_op
823       | matchCanFail match
824         = do { fail_op' <- dsExpr fail_op
825              ; fail_msg <- mkStringExpr (mk_fail_msg pat)
826              ; extractMatchResult match (App fail_op' fail_msg) }
827       | otherwise
828         = extractMatchResult match (error "It can't fail")
829
830     mk_fail_msg :: Located e -> String
831     mk_fail_msg pat = "Pattern match failure in monad comprehension at " ++
832                       showSDoc (ppr (getLoc pat))
833
834 -- Desugar nested monad comprehensions, for example in `then..` constructs
835 --    dsInnerMonadComp quals [a,b,c] ret_op
836 -- returns the desugaring of
837 --       [ (a,b,c) | quals ]
838
839 dsInnerMonadComp :: [LStmt Id]
840                  -> [Id]        -- Return a tuple of these variables
841                  -> HsExpr Id   -- The monomorphic "return" operator
842                  -> DsM CoreExpr
843 dsInnerMonadComp stmts bndrs ret_op
844   = dsMcStmts (stmts ++ [noLoc (LastStmt (mkBigLHsVarTup bndrs) ret_op)])
845
846 -- The `unzip` function for `GroupStmt` in a monad comprehensions
847 --
848 --   unzip :: m (a,b,..) -> (m a,m b,..)
849 --   unzip m_tuple = ( liftM selN1 m_tuple
850 --                   , liftM selN2 m_tuple
851 --                   , .. )
852 --
853 --   mkMcUnzipM fmap ys [t1, t2]
854 --     = ( fmap (selN1 :: (t1, t2) -> t1) ys
855 --       , fmap (selN2 :: (t1, t2) -> t2) ys )
856
857 mkMcUnzipM :: TransForm
858            -> SyntaxExpr TcId   -- fmap
859            -> Id                -- Of type n (a,b,c)
860            -> [Type]            -- [a,b,c]
861            -> DsM CoreExpr      -- Of type (n a, n b, n c)
862 mkMcUnzipM ThenForm _ ys _
863   = return (Var ys) -- No unzipping to do
864
865 mkMcUnzipM _ fmap_op ys elt_tys
866   = do { fmap_op' <- dsExpr fmap_op
867        ; xs       <- mapM newSysLocalDs elt_tys
868        ; let tup_ty = mkBigCoreTupTy elt_tys
869        ; tup_xs   <- newSysLocalDs tup_ty
870
871        ; let mk_elt i = mkApps fmap_op'  -- fmap :: forall a b. (a -> b) -> n a -> n b
872                            [ Type tup_ty, Type (elt_tys !! i)
873                            , mk_sel i, Var ys]
874
875              mk_sel n = Lam tup_xs $
876                         mkTupleSelector xs (xs !! n) tup_xs (Var tup_xs)
877
878        ; return (mkBigCoreTup (map mk_elt [0..length elt_tys - 1])) }
879 \end{code}