Use implication constraints to improve type inference
[ghc.git] / compiler / typecheck / TcArrows.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5 Typecheck arrow notation
6
7 \begin{code}
8 module TcArrows ( tcProc ) where
9
10 #include "HsVersions.h"
11
12 import {-# SOURCE #-}   TcExpr( tcMonoExpr, tcInferRho )
13
14 import HsSyn
15 import TcHsSyn
16
17 import TcMatches
18
19 import TcType
20 import TcMType
21 import TcBinds
22 import TcSimplify
23 import TcGadt
24 import TcPat
25 import TcUnify
26 import TcRnMonad
27 import Inst
28 import Name
29 import TysWiredIn
30 import VarSet 
31 import TysPrim
32 import Type
33
34 import SrcLoc
35 import Outputable
36 import Util
37 \end{code}
38
39 %************************************************************************
40 %*                                                                      *
41                 Proc    
42 %*                                                                      *
43 %************************************************************************
44
45 \begin{code}
46 tcProc :: InPat Name -> LHsCmdTop Name          -- proc pat -> expr
47        -> BoxyRhoType                           -- Expected type of whole proc expression
48        -> TcM (OutPat TcId, LHsCmdTop TcId)
49
50 tcProc pat cmd exp_ty
51   = newArrowScope $
52     do  { (exp_ty1, res_ty) <- boxySplitAppTy exp_ty 
53         ; (arr_ty, arg_ty)  <- boxySplitAppTy exp_ty1
54         ; let cmd_env = CmdEnv { cmd_arr = arr_ty }
55         ; (pat', cmd') <- tcLamPat pat arg_ty (emptyRefinement, res_ty) $
56                           tcCmdTop cmd_env cmd []
57         ; return (pat', cmd') }
58 \end{code}
59
60
61 %************************************************************************
62 %*                                                                      *
63                 Commands
64 %*                                                                      *
65 %************************************************************************
66
67 \begin{code}
68 type CmdStack = [TcTauType]
69 data CmdEnv
70   = CmdEnv {
71         cmd_arr         :: TcType -- arrow type constructor, of kind *->*->*
72     }
73
74 mkCmdArrTy :: CmdEnv -> TcTauType -> TcTauType -> TcTauType
75 mkCmdArrTy env t1 t2 = mkAppTys (cmd_arr env) [t1, t2]
76
77 ---------------------------------------
78 tcCmdTop :: CmdEnv 
79          -> LHsCmdTop Name
80          -> CmdStack
81          -> (Refinement, TcTauType)     -- Expected result type; always a monotype
82                                         -- We know exactly how many cmd args are expected,
83                                         -- albeit perhaps not their types; so we can pass 
84                                         -- in a CmdStack
85         -> TcM (LHsCmdTop TcId)
86
87 tcCmdTop env (L loc (HsCmdTop cmd _ _ names)) cmd_stk reft_res_ty@(_,res_ty)
88   = setSrcSpan loc $
89     do  { cmd'   <- tcGuardedCmd env cmd cmd_stk reft_res_ty
90         ; names' <- mapM (tcSyntaxName ProcOrigin (cmd_arr env)) names
91         ; return (L loc $ HsCmdTop cmd' cmd_stk res_ty names') }
92
93
94 ----------------------------------------
95 tcGuardedCmd :: CmdEnv -> LHsExpr Name -> CmdStack
96              -> (Refinement, TcTauType) -> TcM (LHsExpr TcId)
97 -- A wrapper that deals with the refinement (if any)
98 tcGuardedCmd env expr stk (reft, res_ty)
99   = do  { let (co, res_ty') = refineResType reft res_ty
100         ; body <- tcCmd env expr (stk, res_ty')
101         ; return (mkLHsWrap co body) }
102
103 tcCmd :: CmdEnv -> LHsExpr Name -> (CmdStack, TcTauType) -> TcM (LHsExpr TcId)
104         -- The main recursive function
105 tcCmd env (L loc expr) res_ty
106   = setSrcSpan loc $ do
107         { expr' <- tc_cmd env expr res_ty
108         ; return (L loc expr') }
109
110 tc_cmd env (HsPar cmd) res_ty
111   = do  { cmd' <- tcCmd env cmd res_ty
112         ; return (HsPar cmd') }
113
114 tc_cmd env (HsLet binds (L body_loc body)) res_ty
115   = do  { (binds', body') <- tcLocalBinds binds         $
116                              setSrcSpan body_loc        $
117                              tc_cmd env body res_ty
118         ; return (HsLet binds' (L body_loc body')) }
119
120 tc_cmd env in_cmd@(HsCase scrut matches) (stk, res_ty)
121   = addErrCtxt (cmdCtxt in_cmd)         $
122     addErrCtxt (caseScrutCtxt scrut)    (
123       tcInferRho scrut 
124     )                                                   `thenM` \ (scrut', scrut_ty) ->
125     tcMatchesCase match_ctxt scrut_ty matches res_ty    `thenM` \ matches' ->
126     returnM (HsCase scrut' matches')
127   where
128     match_ctxt = MC { mc_what = CaseAlt,
129                       mc_body = mc_body }
130     mc_body body res_ty' = tcGuardedCmd env body stk res_ty'
131
132 tc_cmd env (HsIf pred b1 b2) res_ty
133   = do  { pred' <- tcMonoExpr pred boolTy
134         ; b1'   <- tcCmd env b1 res_ty
135         ; b2'   <- tcCmd env b2 res_ty
136         ; return (HsIf pred' b1' b2')
137     }
138
139 -------------------------------------------
140 --              Arrow application
141 --          (f -< a)   or   (f -<< a)
142
143 tc_cmd env cmd@(HsArrApp fun arg _ ho_app lr) (cmd_stk, res_ty)
144   = addErrCtxt (cmdCtxt cmd)    $
145     do  { arg_ty <- newFlexiTyVarTy openTypeKind
146         ; let fun_ty = mkCmdArrTy env (foldl mkPairTy arg_ty cmd_stk) res_ty
147
148         ; fun' <- select_arrow_scope (tcMonoExpr fun fun_ty)
149
150         ; arg' <- tcMonoExpr arg arg_ty
151
152         ; return (HsArrApp fun' arg' fun_ty ho_app lr) }
153   where
154         -- Before type-checking f, use the environment of the enclosing
155         -- proc for the (-<) case.  
156         -- Local bindings, inside the enclosing proc, are not in scope 
157         -- inside f.  In the higher-order case (-<<), they are.
158     select_arrow_scope tc = case ho_app of
159         HsHigherOrderApp -> tc
160         HsFirstOrderApp  -> escapeArrowScope tc
161
162 -------------------------------------------
163 --              Command application
164
165 tc_cmd env cmd@(HsApp fun arg) (cmd_stk, res_ty)
166   = addErrCtxt (cmdCtxt cmd)    $
167 -- gaw 2004 FIX?
168     do  { arg_ty <- newFlexiTyVarTy openTypeKind
169
170         ; fun' <- tcCmd env fun (arg_ty:cmd_stk, res_ty)
171
172         ; arg' <- tcMonoExpr arg arg_ty
173
174         ; return (HsApp fun' arg') }
175
176 -------------------------------------------
177 --              Lambda
178
179 tc_cmd env cmd@(HsLam (MatchGroup [L mtch_loc (match@(Match pats maybe_rhs_sig grhss))] _))
180        (cmd_stk, res_ty)
181   = addErrCtxt (matchCtxt match_ctxt match)     $
182
183     do  {       -- Check the cmd stack is big enough
184         ; checkTc (lengthAtLeast cmd_stk n_pats)
185                   (kappaUnderflow cmd)
186
187                 -- Check the patterns, and the GRHSs inside
188         ; (pats', grhss') <- setSrcSpan mtch_loc                $
189                              tcLamPats pats cmd_stk res_ty      $
190                              tc_grhss grhss
191
192         ; let match' = L mtch_loc (Match pats' Nothing grhss')
193         ; return (HsLam (MatchGroup [match'] res_ty))
194         }
195
196   where
197     n_pats     = length pats
198     stk'       = drop n_pats cmd_stk
199     match_ctxt = LambdaExpr     -- Maybe KappaExpr?
200     pg_ctxt    = PatGuard match_ctxt
201
202     tc_grhss (GRHSs grhss binds) res_ty
203         = do { (binds', grhss') <- tcLocalBinds binds $
204                                    mapM (wrapLocM (tc_grhs res_ty)) grhss
205              ; return (GRHSs grhss' binds') }
206
207     tc_grhs res_ty (GRHS guards body)
208         = do { (guards', rhs') <- tcStmts pg_ctxt tcGuardStmt guards res_ty $
209                                   tcGuardedCmd env body stk'
210              ; return (GRHS guards' rhs') }
211
212 -------------------------------------------
213 --              Do notation
214
215 tc_cmd env cmd@(HsDo do_or_lc stmts body ty) (cmd_stk, res_ty)
216   = do  { checkTc (null cmd_stk) (nonEmptyCmdStkErr cmd)
217         ; (stmts', body') <- tcStmts do_or_lc tc_stmt stmts (emptyRefinement, res_ty) $
218                              tcGuardedCmd env body []
219         ; return (HsDo do_or_lc stmts' body' res_ty) }
220   where
221     tc_stmt = tcMDoStmt tc_rhs
222     tc_rhs rhs = do { ty <- newFlexiTyVarTy liftedTypeKind
223                     ; rhs' <- tcCmd env rhs ([], ty)
224                     ; return (rhs', ty) }
225
226
227 -----------------------------------------------------------------
228 --      Arrow ``forms''       (| e c1 .. cn |)
229 --
230 --      G      |-b  c : [s1 .. sm] s
231 --      pop(G) |-   e : forall w. b ((w,s1) .. sm) s
232 --                              -> a ((w,t1) .. tn) t
233 --      e \not\in (s, s1..sm, t, t1..tn)
234 --      ----------------------------------------------
235 --      G |-a  (| e c |)  :  [t1 .. tn] t
236
237 tc_cmd env cmd@(HsArrForm expr fixity cmd_args) (cmd_stk, res_ty)       
238   = addErrCtxt (cmdCtxt cmd)    $
239     do  { cmds_w_tys <- zipWithM new_cmd_ty cmd_args [1..]
240         ; span       <- getSrcSpanM
241         ; [w_tv]     <- tcInstSkolTyVars ArrowSkol [alphaTyVar]
242         ; let w_ty = mkTyVarTy w_tv     -- Just a convenient starting point
243
244                 --  a ((w,t1) .. tn) t
245         ; let e_res_ty = mkCmdArrTy env (foldl mkPairTy w_ty cmd_stk) res_ty
246
247                 --   b ((w,s1) .. sm) s
248                 --   -> a ((w,t1) .. tn) t
249         ; let e_ty = mkFunTys [mkAppTys b [tup,s] | (_,_,b,tup,s) <- cmds_w_tys] 
250                               e_res_ty
251
252                 -- Check expr
253         ; (expr', lie) <- escapeArrowScope (getLIE (tcMonoExpr expr e_ty))
254         ; loc <- getInstLoc (SigOrigin ArrowSkol)
255         ; inst_binds <- tcSimplifyCheck loc [w_tv] [] lie
256
257                 -- Check that the polymorphic variable hasn't been unified with anything
258                 -- and is not free in res_ty or the cmd_stk  (i.e.  t, t1..tn)
259         ; checkSigTyVarsWrt (tyVarsOfTypes (res_ty:cmd_stk)) [w_tv] 
260
261                 -- OK, now we are in a position to unscramble 
262                 -- the s1..sm and check each cmd
263         ; cmds' <- mapM (tc_cmd w_tv) cmds_w_tys
264
265         ; returnM (HsArrForm (noLoc $ HsWrap (WpTyLam w_tv) 
266                                                (unLoc $ mkHsDictLet inst_binds expr')) 
267                              fixity cmds')
268         }
269   where
270         -- Make the types       
271         --      b, ((e,s1) .. sm), s
272     new_cmd_ty :: LHsCmdTop Name -> Int
273                -> TcM (LHsCmdTop Name, Int, TcType, TcType, TcType)
274     new_cmd_ty cmd i
275           = do  { b_ty   <- newFlexiTyVarTy arrowTyConKind
276                 ; tup_ty <- newFlexiTyVarTy liftedTypeKind
277                         -- We actually make a type variable for the tuple
278                         -- because we don't know how deeply nested it is yet    
279                 ; s_ty   <- newFlexiTyVarTy liftedTypeKind
280                 ; return (cmd, i, b_ty, tup_ty, s_ty)
281                 }
282
283     tc_cmd w_tv (cmd, i, b, tup_ty, s)
284       = do { tup_ty' <- zonkTcType tup_ty
285            ; let (corner_ty, arg_tys) = unscramble tup_ty'
286
287                 -- Check that it has the right shape:
288                 --      ((w,s1) .. sn)
289                 -- where the si do not mention w
290            ; checkTc (corner_ty `tcEqType` mkTyVarTy w_tv && 
291                       not (w_tv `elemVarSet` tyVarsOfTypes arg_tys))
292                      (badFormFun i tup_ty')
293
294            ; tcCmdTop (env { cmd_arr = b }) cmd arg_tys (emptyRefinement, s) }
295
296     unscramble :: TcType -> (TcType, [TcType])
297     -- unscramble ((w,s1) .. sn)        =  (w, [s1..sn])
298     unscramble ty
299        = case tcSplitTyConApp_maybe ty of
300             Just (tc, [t,s]) | tc == pairTyCon 
301                ->  let 
302                       (w,ss) = unscramble t  
303                    in (w, s:ss)
304                                     
305             other -> (ty, [])
306
307 -----------------------------------------------------------------
308 --              Base case for illegal commands
309 -- This is where expressions that aren't commands get rejected
310
311 tc_cmd env cmd _
312   = failWithTc (vcat [ptext SLIT("The expression"), nest 2 (ppr cmd), 
313                       ptext SLIT("was found where an arrow command was expected")])
314 \end{code}
315
316
317 %************************************************************************
318 %*                                                                      *
319                 Helpers
320 %*                                                                      *
321 %************************************************************************
322
323
324 \begin{code}
325 mkPairTy t1 t2 = mkTyConApp pairTyCon [t1,t2]
326
327 arrowTyConKind :: Kind          --  *->*->*
328 arrowTyConKind = mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind
329 \end{code}
330
331
332 %************************************************************************
333 %*                                                                      *
334                 Errors
335 %*                                                                      *
336 %************************************************************************
337
338 \begin{code}
339 cmdCtxt cmd = ptext SLIT("In the command:") <+> ppr cmd
340
341 caseScrutCtxt cmd
342   = hang (ptext SLIT("In the scrutinee of a case command:")) 4 (ppr cmd)
343
344 nonEmptyCmdStkErr cmd
345   = hang (ptext SLIT("Non-empty command stack at command:"))
346          4 (ppr cmd)
347
348 kappaUnderflow cmd
349   = hang (ptext SLIT("Command stack underflow at command:"))
350          4 (ppr cmd)
351
352 badFormFun i tup_ty'
353  = hang (ptext SLIT("The type of the") <+> speakNth i <+> ptext SLIT("argument of a command form has the wrong shape"))
354         4 (ptext SLIT("Argument type:") <+> ppr tup_ty')
355 \end{code}