fix extraction of command stack of arguments of arrow "forms" (fixes #4236)
[ghc.git] / compiler / typecheck / TcArrows.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5 Typecheck arrow notation
6
7 \begin{code}
8 module TcArrows ( tcProc ) where
9
10 import {-# SOURCE #-}   TcExpr( tcMonoExpr, tcInferRho )
11
12 import HsSyn
13 import TcHsSyn
14
15 import TcMatches
16
17 import TcType
18 import TcMType
19 import TcBinds
20 import TcSimplify
21 import TcPat
22 import TcUnify
23 import TcRnMonad
24 import Coercion
25 import Inst
26 import Name
27 import TysWiredIn
28 import VarSet 
29 import TysPrim
30
31 import SrcLoc
32 import Outputable
33 import FastString
34 import Util
35
36 import Control.Monad
37 \end{code}
38
39 %************************************************************************
40 %*                                                                      *
41                 Proc    
42 %*                                                                      *
43 %************************************************************************
44
45 \begin{code}
46 tcProc :: InPat Name -> LHsCmdTop Name          -- proc pat -> expr
47        -> BoxyRhoType                           -- Expected type of whole proc expression
48        -> TcM (OutPat TcId, LHsCmdTop TcId, CoercionI)
49
50 tcProc pat cmd exp_ty
51   = newArrowScope $
52     do  { ((exp_ty1, res_ty), coi) <- boxySplitAppTy exp_ty 
53         ; ((arr_ty, arg_ty), coi1) <- boxySplitAppTy exp_ty1
54         ; let cmd_env = CmdEnv { cmd_arr = arr_ty }
55         ; (pat', cmd') <- tcPat ProcExpr pat arg_ty res_ty $
56                           tcCmdTop cmd_env cmd []
57         ; let res_coi = mkTransCoI coi (mkAppTyCoI exp_ty1 coi1 res_ty IdCo)
58         ; return (pat', cmd', res_coi) 
59         }
60 \end{code}
61
62
63 %************************************************************************
64 %*                                                                      *
65                 Commands
66 %*                                                                      *
67 %************************************************************************
68
69 \begin{code}
70 type CmdStack = [TcTauType]
71 data CmdEnv
72   = CmdEnv {
73         cmd_arr         :: TcType -- arrow type constructor, of kind *->*->*
74     }
75
76 mkCmdArrTy :: CmdEnv -> TcTauType -> TcTauType -> TcTauType
77 mkCmdArrTy env t1 t2 = mkAppTys (cmd_arr env) [t1, t2]
78
79 ---------------------------------------
80 tcCmdTop :: CmdEnv 
81          -> LHsCmdTop Name
82          -> CmdStack
83          -> TcTauType   -- Expected result type; always a monotype
84                              -- We know exactly how many cmd args are expected,
85                              -- albeit perhaps not their types; so we can pass 
86                              -- in a CmdStack
87         -> TcM (LHsCmdTop TcId)
88
89 tcCmdTop env (L loc (HsCmdTop cmd _ _ names)) cmd_stk res_ty
90   = setSrcSpan loc $
91     do  { cmd'   <- tcGuardedCmd env cmd cmd_stk res_ty
92         ; names' <- mapM (tcSyntaxName ProcOrigin (cmd_arr env)) names
93         ; return (L loc $ HsCmdTop cmd' cmd_stk res_ty names') }
94
95
96 ----------------------------------------
97 tcGuardedCmd :: CmdEnv -> LHsExpr Name -> CmdStack
98              -> TcTauType -> TcM (LHsExpr TcId)
99 -- A wrapper that deals with the refinement (if any)
100 tcGuardedCmd env expr stk res_ty
101   = do  { body <- tcCmd env expr (stk, res_ty)
102         ; return body 
103         }
104
105 tcCmd :: CmdEnv -> LHsExpr Name -> (CmdStack, TcTauType) -> TcM (LHsExpr TcId)
106         -- The main recursive function
107 tcCmd env (L loc expr) res_ty
108   = setSrcSpan loc $ do
109         { expr' <- tc_cmd env expr res_ty
110         ; return (L loc expr') }
111
112 tc_cmd :: CmdEnv -> HsExpr Name -> (CmdStack, TcTauType) -> TcM (HsExpr TcId)
113 tc_cmd env (HsPar cmd) res_ty
114   = do  { cmd' <- tcCmd env cmd res_ty
115         ; return (HsPar cmd') }
116
117 tc_cmd env (HsLet binds (L body_loc body)) res_ty
118   = do  { (binds', body') <- tcLocalBinds binds         $
119                              setSrcSpan body_loc        $
120                              tc_cmd env body res_ty
121         ; return (HsLet binds' (L body_loc body')) }
122
123 tc_cmd env in_cmd@(HsCase scrut matches) (stk, res_ty)
124   = addErrCtxt (cmdCtxt in_cmd) $ do
125       (scrut', scrut_ty) <- tcInferRho scrut 
126       matches' <- tcMatchesCase match_ctxt scrut_ty matches res_ty
127       return (HsCase scrut' matches')
128   where
129     match_ctxt = MC { mc_what = CaseAlt,
130                       mc_body = mc_body }
131     mc_body body res_ty' = tcGuardedCmd env body stk res_ty'
132
133 tc_cmd env (HsIf pred b1 b2) res_ty
134   = do  { pred' <- tcMonoExpr pred boolTy
135         ; b1'   <- tcCmd env b1 res_ty
136         ; b2'   <- tcCmd env b2 res_ty
137         ; return (HsIf pred' b1' b2')
138     }
139
140 -------------------------------------------
141 --              Arrow application
142 --          (f -< a)   or   (f -<< a)
143
144 tc_cmd env cmd@(HsArrApp fun arg _ ho_app lr) (cmd_stk, res_ty)
145   = addErrCtxt (cmdCtxt cmd)    $
146     do  { arg_ty <- newFlexiTyVarTy openTypeKind
147         ; let fun_ty = mkCmdArrTy env (foldl mkPairTy arg_ty cmd_stk) res_ty
148
149         ; fun' <- select_arrow_scope (tcMonoExpr fun fun_ty)
150
151         ; arg' <- tcMonoExpr arg arg_ty
152
153         ; return (HsArrApp fun' arg' fun_ty ho_app lr) }
154   where
155         -- Before type-checking f, use the environment of the enclosing
156         -- proc for the (-<) case.  
157         -- Local bindings, inside the enclosing proc, are not in scope 
158         -- inside f.  In the higher-order case (-<<), they are.
159     select_arrow_scope tc = case ho_app of
160         HsHigherOrderApp -> tc
161         HsFirstOrderApp  -> escapeArrowScope tc
162
163 -------------------------------------------
164 --              Command application
165
166 tc_cmd env cmd@(HsApp fun arg) (cmd_stk, res_ty)
167   = addErrCtxt (cmdCtxt cmd)    $
168     do  { arg_ty <- newFlexiTyVarTy openTypeKind
169
170         ; fun' <- tcCmd env fun (arg_ty:cmd_stk, res_ty)
171
172         ; arg' <- tcMonoExpr arg arg_ty
173
174         ; return (HsApp fun' arg') }
175
176 -------------------------------------------
177 --              Lambda
178
179 tc_cmd env cmd@(HsLam (MatchGroup [L mtch_loc (match@(Match pats _maybe_rhs_sig grhss))] _))
180        (cmd_stk, res_ty)
181   = addErrCtxt (pprMatchInCtxt match_ctxt match)        $
182
183     do  {       -- Check the cmd stack is big enough
184         ; checkTc (lengthAtLeast cmd_stk n_pats)
185                   (kappaUnderflow cmd)
186
187                 -- Check the patterns, and the GRHSs inside
188         ; (pats', grhss') <- setSrcSpan mtch_loc                        $
189                              tcPats LambdaExpr pats cmd_stk res_ty      $
190                              tc_grhss grhss
191
192         ; let match' = L mtch_loc (Match pats' Nothing grhss')
193         ; return (HsLam (MatchGroup [match'] res_ty))
194         }
195
196   where
197     n_pats     = length pats
198     stk'       = drop n_pats cmd_stk
199     match_ctxt = (LambdaExpr :: HsMatchContext Name)    -- Maybe KappaExpr?
200     pg_ctxt    = PatGuard match_ctxt
201
202     tc_grhss (GRHSs grhss binds) res_ty
203         = do { (binds', grhss') <- tcLocalBinds binds $
204                                    mapM (wrapLocM (tc_grhs res_ty)) grhss
205              ; return (GRHSs grhss' binds') }
206
207     tc_grhs res_ty (GRHS guards body)
208         = do { (guards', rhs') <- tcStmts pg_ctxt tcGuardStmt guards res_ty $
209                                   tcGuardedCmd env body stk'
210              ; return (GRHS guards' rhs') }
211
212 -------------------------------------------
213 --              Do notation
214
215 tc_cmd env cmd@(HsDo do_or_lc stmts body _ty) (cmd_stk, res_ty)
216   = do  { checkTc (null cmd_stk) (nonEmptyCmdStkErr cmd)
217         ; (stmts', body') <- tcStmts do_or_lc tc_stmt stmts res_ty $
218                              tcGuardedCmd env body []
219         ; return (HsDo do_or_lc stmts' body' res_ty) }
220   where
221     tc_stmt = tcMDoStmt tc_rhs
222     tc_rhs rhs = do { ty <- newFlexiTyVarTy liftedTypeKind
223                     ; rhs' <- tcCmd env rhs ([], ty)
224                     ; return (rhs', ty) }
225
226
227 -----------------------------------------------------------------
228 --      Arrow ``forms''       (| e c1 .. cn |)
229 --
230 --      G      |-b  c : [s1 .. sm] s
231 --      pop(G) |-   e : forall w. b ((w,s1) .. sm) s
232 --                              -> a ((w,t1) .. tn) t
233 --      e \not\in (s, s1..sm, t, t1..tn)
234 --      ----------------------------------------------
235 --      G |-a  (| e c |)  :  [t1 .. tn] t
236
237 tc_cmd env cmd@(HsArrForm expr fixity cmd_args) (cmd_stk, res_ty)       
238   = addErrCtxt (cmdCtxt cmd)    $
239     do  { cmds_w_tys <- zipWithM new_cmd_ty cmd_args [1..]
240         ; [w_tv]     <- tcInstSkolTyVars ArrowSkol [alphaTyVar]
241         ; let w_ty = mkTyVarTy w_tv     -- Just a convenient starting point
242
243                 --  a ((w,t1) .. tn) t
244         ; let e_res_ty = mkCmdArrTy env (foldl mkPairTy w_ty cmd_stk) res_ty
245
246                 --   b ((w,s1) .. sm) s
247                 --   -> a ((w,t1) .. tn) t
248         ; let e_ty = mkFunTys [mkAppTys b [tup,s] | (_,_,b,tup,s) <- cmds_w_tys] 
249                               e_res_ty
250
251                 -- Check expr
252         ; (expr', lie) <- escapeArrowScope (getLIE (tcMonoExpr expr e_ty))
253         ; loc <- getInstLoc (SigOrigin ArrowSkol)
254         ; inst_binds <- tcSimplifyCheck loc [w_tv] [] lie
255
256                 -- Check that the polymorphic variable hasn't been unified with anything
257                 -- and is not free in res_ty or the cmd_stk  (i.e.  t, t1..tn)
258         ; checkSigTyVarsWrt (tyVarsOfTypes (res_ty:cmd_stk)) [w_tv] 
259
260                 -- OK, now we are in a position to unscramble 
261                 -- the s1..sm and check each cmd
262         ; cmds' <- mapM (tc_cmd w_tv) cmds_w_tys
263
264         ; return (HsArrForm (noLoc $ HsWrap (WpTyLam w_tv) 
265                                                (unLoc $ mkHsDictLet inst_binds expr')) 
266                              fixity cmds')
267         }
268   where
269         -- Make the types       
270         --      b, ((e,s1) .. sm), s
271     new_cmd_ty :: LHsCmdTop Name -> Int
272                -> TcM (LHsCmdTop Name, Int, TcType, TcType, TcType)
273     new_cmd_ty cmd i
274           = do  { b_ty   <- newFlexiTyVarTy arrowTyConKind
275                 ; tup_ty <- newFlexiTyVarTy liftedTypeKind
276                         -- We actually make a type variable for the tuple
277                         -- because we don't know how deeply nested it is yet    
278                 ; s_ty   <- newFlexiTyVarTy liftedTypeKind
279                 ; return (cmd, i, b_ty, tup_ty, s_ty)
280                 }
281
282     tc_cmd w_tv (cmd, i, b, tup_ty, s)
283       = do { tup_ty' <- zonkTcType tup_ty
284            ; let (corner_ty, arg_tys) = unscramble tup_ty'
285
286                 -- Check that it has the right shape:
287                 --      ((w,s1) .. sn)
288                 -- where the si do not mention w
289            ; checkTc (corner_ty `tcEqType` mkTyVarTy w_tv && 
290                       not (w_tv `elemVarSet` tyVarsOfTypes arg_tys))
291                      (badFormFun i tup_ty')
292
293            ; tcCmdTop (env { cmd_arr = b }) cmd arg_tys s }
294
295     unscramble :: TcType -> (TcType, [TcType])
296     -- unscramble ((w,s1) .. sn)        =  (w, [s1..sn])
297     unscramble ty = unscramble' ty []
298
299     unscramble' ty ss
300        = case tcSplitTyConApp_maybe ty of
301             Just (tc, [t,s]) | tc == pairTyCon 
302                ->  unscramble' t (s:ss)
303             _ -> (ty, ss)
304
305 -----------------------------------------------------------------
306 --              Base case for illegal commands
307 -- This is where expressions that aren't commands get rejected
308
309 tc_cmd _ cmd _
310   = failWithTc (vcat [ptext (sLit "The expression"), nest 2 (ppr cmd), 
311                       ptext (sLit "was found where an arrow command was expected")])
312 \end{code}
313
314
315 %************************************************************************
316 %*                                                                      *
317                 Helpers
318 %*                                                                      *
319 %************************************************************************
320
321
322 \begin{code}
323 mkPairTy :: Type -> Type -> Type
324 mkPairTy t1 t2 = mkTyConApp pairTyCon [t1,t2]
325
326 arrowTyConKind :: Kind          --  *->*->*
327 arrowTyConKind = mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind
328 \end{code}
329
330
331 %************************************************************************
332 %*                                                                      *
333                 Errors
334 %*                                                                      *
335 %************************************************************************
336
337 \begin{code}
338 cmdCtxt :: HsExpr Name -> SDoc
339 cmdCtxt cmd = ptext (sLit "In the command:") <+> ppr cmd
340
341 nonEmptyCmdStkErr :: HsExpr Name -> SDoc
342 nonEmptyCmdStkErr cmd
343   = hang (ptext (sLit "Non-empty command stack at command:"))
344          4 (ppr cmd)
345
346 kappaUnderflow :: HsExpr Name -> SDoc
347 kappaUnderflow cmd
348   = hang (ptext (sLit "Command stack underflow at command:"))
349          4 (ppr cmd)
350
351 badFormFun :: Int -> TcType -> SDoc
352 badFormFun i tup_ty'
353  = hang (ptext (sLit "The type of the") <+> speakNth i <+> ptext (sLit "argument of a command form has the wrong shape"))
354         4 (ptext (sLit "Argument type:") <+> ppr tup_ty')
355 \end{code}