module Vectorise.Exp
  (   
    vectTopExpr
  , vectTopExprs
  , vectScalarFun
  , vectScalarDFun
  )
where
#include "HsVersions.h"
import GhcPrelude
import Vectorise.Type.Type
import Vectorise.Var
import Vectorise.Convert
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils
import CoreUtils
import MkCore
import CoreSyn
import CoreFVs
import Class
import DataCon
import TyCon
import TcType
import Type
import TyCoRep
import Var
import VarEnv
import VarSet
import NameSet
import Id
import BasicTypes( isStrongLoopBreaker )
import Literal
import TysPrim
import Outputable
import FastString
import DynFlags
import Util
import Control.Monad
import Data.Maybe
import Data.List
-- | Vectorise the right-hand side @expr@ of the top-level binding @var@.
--
-- The expression is first annotated by the vectorisation-avoidance analysis
-- and then scalar subexpressions are encapsulated.  If the result as a whole is
-- encapsulated ('VIEncaps') there is nothing to vectorise and 'Nothing' is
-- returned.  Otherwise the result carries whether the expression is parallel
-- ('VIParr'), its inlining information and the vectorised expression.
vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr))
vectTopExpr var expr = do
  annotated    <- vectAvoidInfo emptyVarSet (freeVars expr)
  encapsulated <- encapsulateScalars annotated
  if isVIEncaps encapsulated
    then return Nothing
    else do
      vExpr  <- closedV . inBind var $ vectAnnPolyExpr False encapsulated
      inline <- computeInline encapsulated
      return (Just (isVIParr encapsulated, inline, vectorised vExpr))
-- | Inlining information for a vectorised binding.
--
-- Dictionary expressions ('VIDict') are never inlined.  Ticks are looked
-- through.  A lambda is inlined with the 'polyArity' of its leading type
-- binders.  Anything else is not inlined.
computeInline :: CoreExprWithVectInfo -> VM Inline
computeInline aexpr = case aexpr of
    ((_, VIDict), _)    -> pure DontInline
    (_, AnnTick _ body) -> computeInline body
    (_, AnnLam _ _)     -> Inline <$> polyArity tyBndrs
    _                   -> pure DontInline
  where
    (tyBndrs, _) = collectAnnTypeBinders aexpr
-- | Vectorise a group of top-level bindings.
--
-- Returns 'Nothing' if every right-hand side is fully encapsulated.  Otherwise:
--
-- * if any right-hand side is parallel, the analysis is rerun with all binders
--   of the group treated as parallel variables;
-- * each binding is vectorised, honouring strong loop breakers, and paired
--   with its inlining information.
--
-- The 'Bool' in the result records whether any right-hand side was parallel.
vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)]))
vectTopExprs binds = do
  firstPass <- mapM (analyse emptyVarSet) rhss
  if all isVIEncaps firstPass
    then return Nothing
    else do
      let anyParr = any isVIParr firstPass
      -- with a parallel member, redo the analysis with the group's binders as parallel vars
      finalPass <- if anyParr
                     then mapM (analyse (mkVarSet bndrs)) rhss
                     else return firstPass
      vectorisedBinds <- zipWithM vectBind bndrs finalPass
      return (Just (anyParr, vectorisedBinds))
  where
    (bndrs, rhss) = unzip binds

    -- vectorisation-avoidance analysis followed by scalar encapsulation
    analyse pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars

    vectBind bndr annRhs = do
      vRhs   <- closedV . inBind bndr $
                  vectAnnPolyExpr (isStrongLoopBreaker (idOccInfo bndr)) annRhs
      inline <- computeInline annRhs
      return (inline, vectorised vRhs)
-- | Vectorise a possibly polymorphic annotated expression.
--
-- The first argument says whether the enclosing binding is a loop breaker.
--
-- * Ticks are re-applied around the vectorised result.
-- * Dictionary expressions ('VIDict') are only vectorised; their lifted
--   component is 'undefined'.
-- * Otherwise the leading type binders are abstracted with 'polyAbstract',
--   and the monomorphic body is vectorised with 'vectFnExpr'.
vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr
vectAnnPolyExpr isLoopBreaker aexpr = case aexpr of
    (_, AnnTick tickish inner) ->
      vTick tickish <$> vectAnnPolyExpr isLoopBreaker inner
    _ | isVIDict aexpr ->
          (, undefined) <$> vectDictExpr (deAnnotate aexpr)
      | otherwise ->
          polyAbstract tyBndrs $ \extraBndrs ->
            mapVect (mkLams (tyBndrs ++ extraBndrs))
              <$> vectFnExpr False isLoopBreaker monoBody
  where
    (tyBndrs, monoBody) = collectAnnTypeBinders aexpr
-- | Replace 'VISimple' subexpressions with encapsulated (lambda-lifted) code
-- via 'liftSimpleAndCase', keeping the rest of the expression structure.
--
-- * Types, literals, coercions and variables not annotated 'VISimple' are
--   returned unchanged.  A 'VISimple' variable goes to 'liftSimpleAndCase'.
-- * A 'VISimple' lambda, case or let is encapsulated as a whole only if
--   aggressive vectorisation avoidance is on and every free variable is of
--   scalar type or top-level.  For lambdas, every binder must also qualify.
-- * A 'VISimple' application with such free variables is encapsulated if
--   avoidance is aggressive or the application is simple (see
--   'isSimpleApplication').
-- * Otherwise the traversal descends into subexpressions.  Ticks and casts
--   are always descended into.
encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
encapsulateScalars ce@(_, AnnType _ty)
  = return ce
encapsulateScalars ce@((_, VISimple), AnnVar _v)
  = liftSimpleAndCase ce
encapsulateScalars ce@(_, AnnVar _v)
  = return ce
encapsulateScalars ce@(_, AnnLit _)
  = return ce
  -- Coercions, like types, are left alone.  'vectAvoidInfo' produces them
  -- (always as 'VISimple'); without this equation, a coercion argument
  -- inside a non-encapsulated application hit an 'unknown constructor' panic.
encapsulateScalars ce@(_, AnnCoercion _)
  = return ce
encapsulateScalars ((fvs, vi), AnnTick tck expr)
  = do
    { encExpr <- encapsulateScalars expr
    ; return ((fvs, vi), AnnTick tck encExpr)
    }
encapsulateScalars ce@((fvs, vi), AnnLam bndr expr)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
        -- the binders of the whole lambda group (collectAnnBndrs)
        -- must also be scalar or top-level
    ; bndrsS    <- allScalarVarType bndrs
    ; case (vi, vectAvoid && varsS && bndrsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encExpr <- encapsulateScalars expr
                            ; return ((fvs, vi), AnnLam bndr encExpr)
                            }
    }
  where
    (bndrs, _) = collectAnnBndrs ce
encapsulateScalars ce@((fvs, vi), AnnApp ce1 ce2)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, (vectAvoid || isSimpleApplication ce) && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encCe1 <- encapsulateScalars ce1
                            ; encCe2 <- encapsulateScalars ce2
                            ; return ((fvs, vi), AnnApp encCe1 encCe2)
                            }
    }
  where
    -- A simple application is (modulo ticks and casts) a simple expression,
    -- or an application whose function part is simple and whose argument
    -- is itself a simple application.
    isSimpleApplication :: CoreExprWithVectInfo -> Bool
    isSimpleApplication (_, AnnTick _ ce)                 = isSimpleApplication ce
    isSimpleApplication (_, AnnCast ce _)                 = isSimpleApplication ce
    isSimpleApplication ce                  | isSimple ce = True
    isSimpleApplication (_, AnnApp ce1 ce2)               = isSimple ce1 && isSimpleApplication ce2
    isSimpleApplication _                                 = False

    -- types, variables and literals, possibly under ticks and casts
    isSimple :: CoreExprWithVectInfo -> Bool
    isSimple (_, AnnType {})   = True
    isSimple (_, AnnVar  {})   = True
    isSimple (_, AnnLit  {})   = True
    isSimple (_, AnnTick _ ce) = isSimple ce
    isSimple (_, AnnCast ce _) = isSimple ce
    isSimple _                 = False
encapsulateScalars ce@((fvs, vi), AnnCase scrut bndr ty alts)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, vectAvoid && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encScrut <- encapsulateScalars scrut
                            ; encAlts  <- mapM encAlt alts
                            ; return ((fvs, vi), AnnCase encScrut bndr ty encAlts)
                            }
    }
  where
    encAlt (con, bndrs, expr) = (con, bndrs,) <$> encapsulateScalars expr
encapsulateScalars ce@((fvs, vi), AnnLet (AnnNonRec bndr expr1) expr2)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, vectAvoid && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encExpr1 <- encapsulateScalars expr1
                            ; encExpr2 <- encapsulateScalars expr2
                            ; return ((fvs, vi), AnnLet (AnnNonRec bndr encExpr1) encExpr2)
                            }
    }
encapsulateScalars ce@((fvs, vi), AnnLet (AnnRec binds) expr)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, vectAvoid && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encBinds <- mapM encBind binds
                            ; encExpr  <- encapsulateScalars expr
                            ; return ((fvs, vi), AnnLet (AnnRec encBinds) encExpr)
                            }
    }
  where
    encBind (bndr, expr) = (bndr,) <$> encapsulateScalars expr
encapsulateScalars ((fvs, vi), AnnCast expr coercion)
  = do
    { encExpr <- encapsulateScalars expr
    ; return ((fvs, vi), AnnCast encExpr coercion)
    }
-- | Encapsulate an expression with 'liftSimple', treating case expressions
-- specially.  If the scrutinee's type is not 'VISimple' (per
-- 'vectAvoidInfoTypeOf'), the case itself is kept and only its alternatives
-- are encapsulated, recursively.
liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimpleAndCase aexpr@((fvs, _vi), AnnCase expr bndr t alts)
  = do
    { vi <- vectAvoidInfoTypeOf expr   -- avoidance info of the scrutinee type
    ; if (vi == VISimple)
      then
        liftSimple aexpr  -- scalar scrutinee: encapsulate the whole case
      else do
      { alts' <- mapM (\(ac, bndrs, aexpr) -> (ac, bndrs,) <$> liftSimpleAndCase aexpr) alts
        -- NOTE(review): the rebuilt case is annotated with 'vi' (the info of the
        -- scrutinee type) rather than its original annotation '_vi'; confirm intended.
      ; return ((fvs, vi), AnnCase expr bndr t alts')
      }
    }
liftSimpleAndCase aexpr = liftSimple aexpr
-- | Encapsulate a 'VISimple' expression.
--
-- The expression is abstracted over its free variables that are not top-level
-- ('isToplevel'), and the resulting lambdas are immediately applied to those
-- same variables.  The body and the introduced lambdas are annotated
-- 'VIEncaps'; the introduced applications are annotated 'VISimple'.
--
-- A variable that is in its own free-variable set and is not top-level is
-- returned unchanged.  Any other expression not annotated 'VISimple' causes a
-- panic.
liftSimple :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimple ((fvs, vi), AnnVar v)
  | v `elemDVarSet` fvs               -- v is among its own free variables
  && not (isToplevel v)               -- and is not a top-level binding
  = return $ ((fvs, vi), AnnVar v)
liftSimple aexpr@((fvs_orig, VISimple), expr)
  = do
    { let liftedExpr = mkAnnApps (mkAnnLams (reverse vars) fvs expr) vars
    ; traceVt "encapsulate:" $ ppr (deAnnotate aexpr) $$ text "==>" $$ ppr (deAnnotate liftedExpr)
    ; return $ liftedExpr
    }
  where
    vars = dVarSetElems fvs
    -- top-level variables are not abstracted over
    fvs  = filterDVarSet (not . isToplevel) fvs_orig
    -- Wrap the expression in one lambda per variable.  The list is given
    -- innermost-binder-first (hence the 'reverse' above), and each variable is
    -- removed from the free-variable set once bound, so the set must be empty
    -- when the list runs out.
    mkAnnLams :: [Var] -> DVarSet -> AnnExpr' Var (DVarSet, VectAvoidInfo) -> CoreExprWithVectInfo
    mkAnnLams []     fvs expr = ASSERT(isEmptyDVarSet fvs)
                                ((emptyDVarSet, VIEncaps), expr)
    mkAnnLams (v:vs) fvs expr = mkAnnLams vs (fvs `delDVarSet` v) (AnnLam v ((fvs, VIEncaps), expr))
    -- apply the lambdas to the variables, left to right
    mkAnnApps :: CoreExprWithVectInfo -> [Var] -> CoreExprWithVectInfo
    mkAnnApps aexpr []     = aexpr
    mkAnnApps aexpr (v:vs) = mkAnnApps (mkAnnApp aexpr v) vs
    mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo
    mkAnnApp aexpr@((fvs, _vi), _expr) v
      = ((fvs `extendDVarSet` v, VISimple), AnnApp aexpr ((unitDVarSet v, VISimple), AnnVar v))
liftSimple aexpr
  = pprPanic "Vectorise.Exp.liftSimple: not simple" $ ppr (deAnnotate aexpr)
-- | Whether a variable is a top-level binding, judged from its unfolding.
--
-- * A 'CoreUnfolding' records this in 'uf_is_top'.
-- * 'OtherCon' and 'DFunUnfolding' count as top-level.
-- * 'NoUnfolding', 'BootUnfolding' and non-'Id' variables do not.
isToplevel :: Var -> Bool
isToplevel v
  | not (isId v) = False
  | otherwise    = unfoldingIsTop (realIdUnfolding v)
  where
    unfoldingIsTop unf = case unf of
      CoreUnfolding {uf_is_top = top} -> top
      DFunUnfolding {}                -> True
      OtherCon      {}                -> True
      BootUnfolding                   -> False
      NoUnfolding                     -> False
-- | Vectorise an annotated expression, producing its vectorised and lifted
-- forms.
--
-- Encapsulated ('VIEncaps') expressions are handled first.  Those of function
-- type go through 'vectFnExpr'; all others are treated as constants via
-- 'vectConst'.  The remaining equations dispatch on the shape of the
-- expression, and anything not covered is rejected with 'cantVectorise'.
vectExpr :: CoreExprWithVectInfo -> VM VExpr
vectExpr aexpr
    -- encapsulated expression of function type
  | (isFunTy . annExprType $ aexpr) && isVIEncaps aexpr
  = vectFnExpr True False aexpr
    -- any other encapsulated expression is treated as a constant
  | isVIEncaps aexpr
  = traceVt "vectExpr (encapsulated constant):" (ppr . deAnnotate $ aexpr) >>
    vectConst (deAnnotate aexpr)
vectExpr (_, AnnVar v)
  = vectVar v
vectExpr (_, AnnLit lit)
  = vectConst $ Lit lit
vectExpr aexpr@(_, AnnLam _ _)
  = traceVt "vectExpr [AnnLam]:" (ppr . deAnnotate $ aexpr) >>
    vectFnExpr True False aexpr
  -- Special case for 'pAT_ERROR_ID' applied to a type and a message: only the
  -- type argument is vectorised/lifted, and the message 'err' is reused as is.
  -- NOTE(review): the vectorised application passes a RuntimeRep argument but
  -- the lifted one does not; confirm against the type of 'pAT_ERROR_ID'.
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
  | v == pAT_ERROR_ID
  = do
    { (vty, lty) <- vectAndLiftType ty
    ; return (mkCoreApps (Var v) [Type (getRuntimeRep vty), Type vty, err'],
              mkCoreApps (Var v) [Type lty, err'])
    }
  where
    err' = deAnnotate err
  -- application to a type argument
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectPolyApp e
  -- a data constructor applied to a literal: the vectorised form is the
  -- application itself, and the lifted form is obtained with 'liftPD'
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just _con <- isDataConId_maybe v
  = do
    { let vexpr = App (Var v) (Lit lit)
    ; lexpr <- liftPD vexpr
    ; return (vexpr, lexpr)
    }
  -- general application
vectExpr e@(_, AnnApp fn arg)
  | isPredTy arg_ty   -- dictionary argument
  = vectPolyApp e
  | otherwise         -- value argument: build a closure application
  = do
    {   -- vectorise the argument and result types of the function
    ; varg_ty <- vectType arg_ty
    ; vres_ty <- vectType res_ty
        -- vectorise the function and the argument
    ; vfn  <- vectExpr fn
    ; varg <- vectExpr arg
        -- apply the vectorised closure
    ; mkClosureApp varg_ty vres_ty vfn varg
    }
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
  -- case expressions are only supported on algebraic scrutinee types
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  | otherwise
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $
        ppr scrut_ty
    }
  where
    scrut_ty = exprType (deAnnotate scrut)
  -- non-recursive let: the right-hand side is vectorised in a local scope,
  -- then the body with the vectorised binder in scope
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
    { traceVt "let binding (non-recursive)" Outputable.empty
    ; vrhs <- localV $
                inBind bndr $
                  vectAnnPolyExpr False rhs
    ; traceVt "let body (non-recursive)" Outputable.empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vLet (vNonRec vbndr vrhs) vbody
    }
  -- recursive let: all binders are in scope for the right-hand sides and the
  -- body; strong loop breakers are flagged to 'vectAnnPolyExpr'
vectExpr (_, AnnLet (AnnRec bs) body)
  = do
    { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ do
                                  { traceVt "let bindings (recursive)" Outputable.empty
                                  ; vrhss <- zipWithM vect_rhs bndrs rhss
                                  ; traceVt "let body (recursive)" Outputable.empty
                                  ; vbody <- vectExpr body
                                  ; return (vrhss, vbody)
                                  }
    ; return $ vLet (vRec vbndrs vrhss) vbody
    }
  where
    (bndrs, rhss) = unzip bs
    vect_rhs bndr rhs = localV $
                          inBind bndr $
                            vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhs
vectExpr (_, AnnTick tickish expr)
  = vTick tickish <$> vectExpr expr
vectExpr (_, AnnType ty)
  = vType <$> vectType ty
  -- everything else (for example casts and coercions) is unsupported
vectExpr e
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e)
    }
-- | Vectorise an expression of function type.
--
-- * Lambdas over a dictionary binder are vectorised binder by binder.
-- * Encapsulated lambdas become scalar functions ('vectScalarFun').
-- * Other value lambdas are fully vectorised with 'vectLam'.
-- * Type lambdas are rejected.
-- * A non-lambda becomes a scalar function if it is encapsulated and of
--   function type; otherwise it is passed on to 'vectExpr'.
vectFnExpr :: Bool                  -- ^ passed on to 'vectLam': whether the function it
                                    --   hoists is marked 'Inline'
           -> Bool                  -- ^ whether the binding is a loop breaker
           -> CoreExprWithVectInfo  -- ^ expression to vectorise
           -> VM VExpr
vectFnExpr inline loop_breaker aexpr@(_ann, AnnLam bndr body)
    -- dictionary abstraction: vectorise the binder and the body, then wrap the
    -- vectorised component in a lambda over the vectorised binder
  | isId bndr
    && isPredTy (idType bndr)
  = do
    { vBndr <- vectBndr bndr
    ; vbody <- vectFnExpr inline loop_breaker body
    ; return $ mapVect (mkLams [vectorised vBndr]) vbody
    }
    -- encapsulated value lambda: treat as a scalar function
  | isId bndr && isVIEncaps aexpr
  = vectScalarFun . deAnnotate $ aexpr
    -- any other value lambda is fully vectorised
  | isId bndr
  = vectLam inline loop_breaker aexpr
  | otherwise
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Vectorise.Exp.vectFnExpr: Unexpected type lambda" $
        ppr (deAnnotate aexpr)
    }
vectFnExpr _ _ aexpr
    -- encapsulated non-lambda of function type: treat as a scalar function
  | (isFunTy . annExprType $ aexpr) && isVIEncaps aexpr
  = vectScalarFun . deAnnotate $ aexpr
  | otherwise
    -- not a lambda and not an encapsulated function: vectorise as an
    -- ordinary expression
  = vectExpr aexpr
-- | Vectorise an application of a variable to type and dictionary arguments.
--
-- The application is decomposed as
-- @var \@tysInner dictsInner \@tysOuter dictsOuter@.  The type arguments are
-- vectorised with 'vectType' and the dictionary arguments with 'vectDictExpr'.
--
-- * A local variable is re-applied to the outer arguments in both its
--   vectorised and lifted forms; it must have no inner dictionaries.
-- * A global class-method selector or dictionary function is re-applied to
--   all arguments.
-- * Any other global is re-applied to the outer arguments; it must have no
--   inner dictionaries.
--
-- In both global cases the result is treated as a constant ('vectConst').  A
-- head that is not a variable is reported as unsupported (higher-rank types).
vectPolyApp :: CoreExprWithVectInfo -> VM VExpr
vectPolyApp e0
  = case e4 of
      (_, AnnVar var)
        -> do {   -- the head is a variable: look up its vectorised counterpart
              ; vVar <- lookupVar var
              ; traceVt "vectPolyApp of" (ppr var)
                  -- vectorise all dictionary and type arguments
              ; vDictsOuter <- mapM vectDictExpr (map deAnnotate dictsOuter)
              ; vDictsInner <- mapM vectDictExpr (map deAnnotate dictsInner)
              ; vTysOuter   <- mapM vectType     tysOuter
              ; vTysInner   <- mapM vectType     tysInner
              ; let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter
              ; case vVar of
                  Local (vv, lv)
                    -> do { MASSERT( null dictsInner )    -- no inner dictionaries expected for locals
                          ; traceVt "  LOCAL" (text "")
                          ; (,) <$> reconstructOuter (Var vv) <*> reconstructOuter (Var lv)
                          }
                  Global vv
                    | isDictComp var                      -- class method selector or dfun
                    -> do {   -- with no inner dictionaries the types are applied
                              -- directly; otherwise apply the inner arguments
                              -- first, then the outer ones via 'polyApply'
                          ; ve <- if null dictsInner
                                  then
                                    return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                                  else
                                    reconstructOuter
                                      (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                          ; traceVt "  GLOBAL (dict):" (ppr ve)
                          ; vectConst ve
                          }
                    | otherwise                           -- any other global
                    -> do { MASSERT( null dictsInner )
                          ; ve <- reconstructOuter (Var vv)
                          ; traceVt "  GLOBAL (non-dict):" (ppr ve)
                          ; vectConst ve
                          }
              }
      _ -> pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- peel off, from the outside in: dictionary args, type args, dictionary args, type args
    (e1, dictsOuter) = collectAnnDictArgs e0
    (e2, tysOuter)   = collectAnnTypeArgs e1
    (e3, dictsInner) = collectAnnDictArgs e2
    (e4, tysInner)   = collectAnnTypeArgs e3
    -- class method selectors and dictionary functions
    isDictComp var = (isJust . isClassOpId_maybe $ var) || isDFunId var
-- | Vectorise a dictionary computation.
--
-- Variables are replaced by their vectorised counterparts where one exists.
-- Types and the data constructors of case alternatives are vectorised, and
-- every other construct is traversed structurally.  Literals cause a panic;
-- casts and coercions are reported as unsupported.
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr expr = case expr of
    Var var -> do
      mbScope <- lookupVar_maybe var
      case mbScope of
        Nothing                -> return (Var var)
        Just (Local (vVar, _)) -> return (Var vVar)
        Just (Global vVar)     -> return (Var vVar)
    Lit lit ->
      pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)
    Lam bndr body ->
      Lam bndr <$> vectDictExpr body
    App fn arg ->
      App <$> vectDictExpr fn <*> vectDictExpr arg
    Case scrut bndr ty alts ->
      Case <$> vectDictExpr scrut <*> pure bndr <*> vectType ty <*> mapM vectAlt alts
    Let bind body ->
      Let <$> vectBind bind <*> vectDictExpr body
    Cast {} ->
      pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr expr)
    Tick tickish body ->
      Tick tickish <$> vectDictExpr body
    Type ty ->
      Type <$> vectType ty
    Coercion coe ->
      pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
  where
    vectAlt (con, bs, rhs) = (,,) <$> vectAltCon con <*> pure bs <*> vectDictExpr rhs

    vectAltCon (DataAlt dc) =
      DataAlt <$> maybeV (text "Cannot vectorise data constructor:" <+> ppr dc) (lookupDataCon dc)
    vectAltCon (LitAlt lit) = return (LitAlt lit)
    vectAltCon DEFAULT      = return DEFAULT

    vectBind (NonRec bndr rhs) = NonRec bndr <$> vectDictExpr rhs
    vectBind (Rec pairs)       = Rec <$> mapM (\(bndr, rhs) -> (bndr,) <$> vectDictExpr rhs) pairs
-- | Vectorise a function-typed expression as a scalar function.  The argument
-- and result types are read off the expression's type, and the work is done
-- by 'mkScalarFun'.
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr = do
  traceVt "vectScalarFun:" (ppr expr)
  uncurry mkScalarFun (splitFunTys (exprType expr)) expr
-- | Build the vectorised/lifted pair for a scalar function with the given
-- argument and result types.
--
-- An expression with a predicate (dictionary) result type is only vectorised,
-- with 'vectDictExpr'; its lifted component is an 'error' thunk.  Otherwise:
--
-- 1. the function is hoisted with 'hoistExpr';
-- 2. it is packed into a scalar closure together with its zipped version
--    ('zipScalars');
-- 3. that closure is hoisted too;
-- 4. the lifted form is obtained with 'liftPD'.
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty = do
      vDict <- vectDictExpr expr
      return (vDict, noLiftedForm)
  | otherwise = do
      traceVt "mkScalarFun: " $
        ppr expr $$ text "  ::" <+> ppr (mkFunTys arg_tys res_ty)
      fnVar      <- hoistExpr (fsLit "fn") expr DontInline
      zipped     <- zipScalars arg_tys res_ty
      closure    <- scalarClosure arg_tys res_ty (Var fnVar) (zipped `App` Var fnVar)
      closureVar <- hoistExpr (fsLit "clo") closure DontInline
      lifted     <- liftPD (Var closureVar)
      return (Var closureVar, lifted)
  where
    noLiftedForm = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
-- | Vectorise a dictionary function @var@, treating every class method and
-- superclass selector of its instance as a scalar function.
--
-- The result abstracts over the dfun's type variables and over fresh binders
-- for the vectorised instance-context dictionaries.  Inside:
--
-- * those dictionaries are converted back with 'unVectDict' and let-bound;
-- * the original dfun is applied to them;
-- * each selector of the class is applied to the resulting dictionary and
--   vectorised with 'vectScalarFun';
-- * the results are assembled with the vectorised class data constructor.
vectScalarDFun :: Var        -- ^ the dictionary function to vectorise
               -> VM CoreExpr
vectScalarDFun var
  = do {   -- define the dfun's type variables locally
       ; mapM_ defLocalTyVar tvs
           -- binders for the vectorised context dictionaries (the extra lambda binders)
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr
           -- unvectorised context dictionaries, rebuilt from the vectorised ones;
           -- then every selector is applied to the original dfun applied to them
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps
           -- build the vectorised dictionary with the vectorised data constructor.
           -- NOTE(review): a 'Nothing' here is handled by the VM monad's 'fail';
           -- other lookups in this module use 'maybeV' with a diagnostic instead.
       ; Just vDataCon <- lookupDataCon dataCon
       ; vTys          <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)
       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                   = varType var
    (tvs, theta, pty)    = tcSplitSigmaTy  ty        -- type variables, instance context, head
    (cls, tys)           = tcSplitDFunHead pty       -- the class and its argument types
    selIds               = classAllSelIds cls
    dataCon              = classDataCon cls
-- | Convert a vectorised dictionary @e@ back into an ordinary dictionary of the
-- class-instance type @ty@.
--
-- Each selector of the class is applied to the vectorised class arguments and
-- to @e@.  The results are converted with 'fromVect' to the field types of the
-- unvectorised class data constructor, which is then re-applied to the
-- original type arguments and the converted fields.
--
-- Panics, with the offending type, if the type constructor of @ty@ is not a
-- data product type or not a class.
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
  = do { vTys <- mapM vectType tys
       ; let meths = map (\sel -> Var sel `mkTyApps` vTys `mkApps` [e]) selIds
       ; scOps <- zipWithM fromVect methTys meths
       ; return $ mkCoreConApps dataCon (map Type tys ++ scOps)
       }
  where
    (tycon, tys) = splitTyConApp ty
    -- informative panics instead of irrefutable-pattern failures
    dataCon      = fromMaybe (pprPanic "Vectorise.Exp.unVectDict: not a data product type" (ppr ty))
                             (isDataProductTyCon_maybe tycon)
    cls          = fromMaybe (pprPanic "Vectorise.Exp.unVectDict: not a class type constructor" (ppr ty))
                             (tyConClass_maybe tycon)
    methTys      = dataConInstArgTys dataCon tys
    selIds       = classAllSelIds cls
-- | Fully vectorise a lambda abstraction into a closure.
--
-- The closure captures those free variables that have a local
-- vectorised/lifted pair; free variables without one are not captured.
-- Dictionary-typed captured variables are passed separately from the others.
-- The worker is hoisted with 'hoistPolyVExpr'.  Its lifted body is guarded
-- against an empty lifting context when the binding is a loop breaker (see
-- 'break_loop').
vectLam :: Bool                 -- ^ whether the hoisted function is marked 'Inline'
        -> Bool                 -- ^ whether the binding is a loop breaker
        -> CoreExprWithVectInfo -- ^ the lambda expression to vectorise
        -> VM VExpr
vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _)
 = do { traceVt "fully vectorise a lambda expression" (ppr . deAnnotate $ expr)
      ; let (bndrs, body) = collectAnnValBinders expr
          -- type variables of the local environment
      ; tyvars <- localTyVars
          -- free variables paired with their local (vectorised, lifted) variables;
          -- free variables without a local entry are dropped
      ; vfvs <- readLEnv $ \env ->
                  [ (var, vv)
                  | var     <- dVarSetElems fvs
                  , Just vv <- [lookupVarEnv (local_vars env) var]
                  ]
          -- dictionary-typed free variables are handled separately from the rest
      ; let (vvs_dict, vvs_nondict)     = partition (isPredTy . varType . fst) vfvs
            vfvs_dict                   = map snd vvs_dict
            (fvs_nondict, vfvs_nondict) = unzip vvs_nondict
          -- vectorised argument and result types of the closure
      ; arg_tys <- mapM (vectType . idType) bndrs
      ; res_ty  <- vectType (exprType $ deAnnotate body)
      ; let arity      = length fvs_nondict + length bndrs
            vfvs_dict' = map vectorised vfvs_dict
      ; buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
        . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
        $ do {   -- the worker abstracts over the lifting context, the captured
                 -- non-dictionary variables and the lambda binders
             ; lc              <- builtin liftingContext
             ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) $ vectExpr body
             ; vbody' <- break_loop lc res_ty vbody
             ; return $ vLams lc vbndrs vbody'
             }
      }
  where
    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- For a loop breaker, the lifted body scrutinises the lifting context 'lc'
    -- (an Int#, per 'intPrimTy') and yields an empty PData when it is 0, so the
    -- lifted body is not evaluated then.  The vectorised body is unchanged.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do { dflags      <- getDynFlags
           ; empty_pdata <- emptyPD ty
           ; lty         <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [(DEFAULT, [], le),
                            (LitAlt (mkMachInt dflags 0), [], empty_pdata)])
           }
      | otherwise = return (ve, le)
vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
-- | Vectorise a case expression whose scrutinee has an algebraic type.
--
-- * A lone DEFAULT alternative, or a lone constructor alternative without
--   field binders, is vectorised with 'vCaseDEFAULT'.
-- * A lone constructor alternative with binders gives:
--   - for the vectorised form, a case on the vectorised data constructor;
--   - for the lifted form, a case on the PData constructor returned by
--     'pdataUnwrapScrut'.
-- * In the general case:
--   - alternatives are sorted by constructor tag, with DEFAULT first;
--   - each alternative is vectorised under a selector, and its free local
--     variables are repacked by tag;
--   - the lifted bodies are combined with 'combinePD'.
--
-- The scrutinee is let-bound to the vectorised case binder, which is a fresh
-- variable named "scrut" when the original binder is dead.
vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithVectInfo)]
            -> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = do
    { traceVt "scrutinee (DEFAULT only)" Outputable.empty
    ; vscrut         <- vectExpr scrut
    ; (vty, lty)     <- vectAndLiftType ty
    ; traceVt "alternative body (DEFAULT only)" Outputable.empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
    }
  -- a single constructor without field binders is handled like DEFAULT
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = do
    { traceVt "scrutinee (one shot w/o binders)" Outputable.empty
    ; vscrut         <- vectExpr scrut
    ; (vty, lty)     <- vectAndLiftType ty
    ; traceVt "alternative body (one shot w/o binders)" Outputable.empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
    }
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
    { traceVt "scrutinee (one shot w/ binders)" Outputable.empty
    ; vexpr      <- vectExpr scrut
    ; (vty, lty) <- vectAndLiftType ty
    ; traceVt "alternative body (one shot w/ binders)" Outputable.empty
    ; (vbndr, (vbndrs, (vect_body, lift_body)))
        <- vect_scrut_bndr
         . vectBndrsIn bndrs
         $ vectExpr body
    ; let (vect_bndrs, lift_bndrs) = unzip vbndrs
        -- scrutinee forms and PData data constructor for the case binder
    ; (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
    ; vect_dc <- maybeV dataConErr (lookupDataCon dc)
    ; let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
    ; return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
    }
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr
    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
  -- general case: several alternatives
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
    { traceVt "scrutinee (general case)" Outputable.empty
    ; vexpr <- vectExpr scrut
    ; vect_tc     <- vectTyCon tycon
    ; (vty, lty)  <- vectAndLiftType ty
        -- the selector has one tag per data constructor of the vectorised tycon
    ; let arity = length (tyConDataCons vect_tc)
    ; sel_ty <- builtin (selTy arity)
    ; sel_bndr <- newLocalVar (fsLit "sel") sel_ty
    ; let sel = Var sel_bndr
    ; traceVt "alternatives' body (general case)" Outputable.empty
    ; (vbndr, valts) <- vect_scrut_bndr
                      $ mapM (proc_alt arity sel vty lty) alts'
    ; let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
    ; (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
    ; let (vect_bodies, lift_bodies) = unzip vbodies
    ; vdummy <- newDummyVar (exprType vect_scrut)
    ; ldummy <- newDummyVar (exprType lift_scrut)
        -- vectorised form: an ordinary case on the vectorised data constructors
    ; let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
        -- lifted form: bind the selector and all lifted field binders from the
        -- PData constructor, then combine the lifted bodies by selector
    ; lc <- builtin liftingContext
    ; lbody <- combinePD vty (Var lc) sel lift_bodies
    ; let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]
    ; return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
    }
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr
    -- order alternatives by constructor tag, DEFAULT first
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"
    -- Vectorise one constructor alternative.  Only 'DataAlt' is handled:
    -- NOTE(review): a DEFAULT alternative reaching this general case panics below.
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body@((fvs_body, _), _))
      = do
          dflags <- getDynFlags
          vect_dc <- maybeV dataConErr (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag dflags vect_dc
              fvs  = fvs_body `delDVarSetList` bndrs
          sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
          lc        <- builtin liftingContext
          elems     <- builtin (selElements arity ntag)
          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
               { binds    <- mapM (pack_var (Var lc) sel_tags tag)
                           . filter isLocalId
                           $ dVarSetElems fvs
               ; traceVt "case alternative:" (ppr . deAnnotate $ body)
               ; (ve, le) <- vectExpr body
               ; return (ve, Case (elems `App` sel) lc lty
                             [(DEFAULT, [], (mkLets (concat binds) le))])
               }
                 -- The lifted body rebinds the lifting-context variable 'lc'
                 -- to 'selElements arity ntag' applied to the selector, and
                 -- runs under the let-bindings of the free local variables
                 -- repacked by 'pack_var'.  The repacking runs inside
                 -- 'localV' so the environment updates stay local.
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
      where
        dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
      -- For a free variable 'v' with a local (vectorised, lifted) pair:
      -- clone the lifted variable, bind the clone to the lifted value packed by
      -- tag ('packByTagPD'), and make the local environment use the clone.
      -- Variables without a local pair yield no bindings.
    pack_var len tags t v
      = do
        { r <- lookupVar_maybe v
        ; case r of
            Just (Local (vv, lv)) ->
              do
              { lv'  <- cloneVar lv
              ; expr <- packByTagPD (idType vv) (Var lv) len tags t
              ; updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') })
              ; return [(NonRec lv' expr)]
              }
            _ -> return []
        }
-- | Result of the vectorisation-avoidance analysis ('vectAvoidInfo') for a
-- subexpression.
data VectAvoidInfo
  = VIParr      -- ^ parallel: mentions a parallel variable or has a parallel type
  | VISimple    -- ^ scalar type and no parallel component
  | VIComplex   -- ^ non-scalar type and no parallel component
  | VIEncaps    -- ^ encapsulated by 'liftSimple'
  | VIDict      -- ^ a dictionary (predicate-typed) computation
  deriving (Show, Eq)
-- | Core expressions annotated with their free variables and their
-- vectorisation-avoidance information.
type CoreExprWithVectInfo = AnnExpr Var (DVarSet, VectAvoidInfo)
-- | The type of an annotated expression, computed on its de-annotated form.
annExprType :: AnnExpr Var ann -> Type
annExprType aexpr = exprType (deAnnotate aexpr)
-- | The vectorisation-avoidance information stored in an expression's annotation.
vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo
vectAvoidInfoOf = snd . fst
-- | Whether the expression is annotated 'VIParr'.
isVIParr :: CoreExprWithVectInfo -> Bool
isVIParr aexpr = vectAvoidInfoOf aexpr == VIParr

-- | Whether the expression is annotated 'VIEncaps'.
isVIEncaps :: CoreExprWithVectInfo -> Bool
isVIEncaps aexpr = vectAvoidInfoOf aexpr == VIEncaps

-- | Whether the expression is annotated 'VIDict'.
isVIDict :: CoreExprWithVectInfo -> Bool
isVIDict aexpr = vectAvoidInfoOf aexpr == VIDict
-- | @vi `unlessVIParr` other@ is 'VIParr' if @other@ is 'VIParr', and @vi@
-- otherwise.  A parallel component overrides any other information.
unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo
unlessVIParr vi other
  | other == VIParr = VIParr
  | otherwise       = vi

-- | 'unlessVIParr' taking the second piece of information from the annotation
-- of an expression.  It is left-associative so that it can be chained.
unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo
infixl 9 `unlessVIParrExpr`
unlessVIParrExpr vi aexpr = unlessVIParr vi (vectAvoidInfoOf aexpr)
-- | Annotate an expression (already annotated with its free variables) with
-- vectorisation-avoidance information.
--
-- @pvs@ is the set of local variables known to be parallel.  An expression's
-- information is normally derived from its type ('vectAvoidInfoTypeOf'), but
-- becomes 'VIParr' when a relevant component is parallel.  When the right-hand
-- side or scrutinee is parallel:
--
-- * let-bound variables of non-scalar type are added to @pvs@;
-- * case-alternative binders are added to @pvs@ unless all of them are scalar
--   or top-level.
vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo
  -- a variable is parallel if it is a local or global parallel variable;
  -- otherwise its information comes from its type
vectAvoidInfo pvs ce@(_, AnnVar v)
  = do
    { gpvs <- globalParallelVars
    ; vi <- if v `elemVarSet` pvs || v `elemDVarSet` gpvs
            then return VIParr
            else vectAvoidInfoTypeOf ce
    ; viTrace ce vi []
    ; when (vi == VIParr) $
        traceVt "  reason:" $ if v `elemVarSet` pvs  then text "local"  else
                              if v `elemDVarSet` gpvs then text "global" else text "parallel type"
    ; return ((fvs, vi), AnnVar v)
    }
  where
    fvs = freeVarsOf ce
vectAvoidInfo _pvs ce@(_, AnnLit lit)
  = do
    { vi <- vectAvoidInfoTypeOf ce
    ; viTrace ce vi []
    ; return ((fvs, vi), AnnLit lit)
    }
  where
    fvs = freeVarsOf ce
  -- an application is parallel if the function or the argument is
vectAvoidInfo pvs ce@(_, AnnApp e1 e2)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; eVI1 <- vectAvoidInfo pvs e1
    ; eVI2 <- vectAvoidInfo pvs e2
    ; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2
    
    ; return ((fvs, vi), AnnApp eVI1 eVI2)
    }
  where
    fvs = freeVarsOf ce
  -- a lambda takes the information of its body, unless the binder type is parallel
vectAvoidInfo pvs ce@(_, AnnLam var body)
  = do
    { bodyVI <- vectAvoidInfo pvs body
    ; varVI  <- vectAvoidInfoType $ varType var
    ; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI
    
    ; return ((fvs, vi), AnnLam var bodyVI)
    }
  where
    fvs = freeVarsOf ce
  -- Non-recursive let.  A parallel right-hand side bound to a non-scalar
  -- variable makes that variable parallel in the body, and the whole let
  -- 'VIParr'.  Otherwise the let takes its type information unless the body
  -- is parallel.
vectAvoidInfo pvs ce@(_, AnnLet (AnnNonRec var e) body)
  = do
    { ceVI       <- vectAvoidInfoTypeOf ce
    ; eVI        <- vectAvoidInfo pvs e
    ; isScalarTy <- isScalar $ varType var
    ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy
        then do -- parallel, non-scalar binder
        { bodyVI <- vectAvoidInfo (pvs `extendVarSet` var) body
        ; return (bodyVI, VIParr)
        }
        else do -- binder not parallel (or of scalar type)
        { bodyVI <- vectAvoidInfo pvs body
        ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI)
        }
    
    ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI)
    }
  where
    fvs = freeVarsOf ce
  -- Recursive let.  The bindings are first analysed with @pvs@.  If some
  -- non-scalar binder has a parallel right-hand side, those binders are added
  -- to @pvs@, the bindings and body are re-analysed, and the whole let is
  -- 'VIParr'.
vectAvoidInfo pvs ce@(_, AnnLet (AnnRec bnds) body)
  = do
    { ceVI         <- vectAvoidInfoTypeOf ce
    ; bndsVI       <- mapM (vectAvoidInfoBnd pvs) bnds
    ; parrBndrs    <- map fst <$> filterM isVIParrBnd bndsVI
    ; if not . null $ parrBndrs
      then do         -- some binders are parallel
        {   -- NOTE(review): 'parrBndrs' already have non-scalar types (isVIParrBnd),
            -- so this filter appears redundant; kept as is
          new_pvs <- filterM ((not <$>) . isScalar . varType) parrBndrs
        ; let extendedPvs = pvs `extendVarSetList` new_pvs
        ; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds
        ; bodyVI <- vectAvoidInfo extendedPvs body
        
        ; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI)
        }
      else do         -- no parallel binders: type information unless the body is parallel
        { bodyVI <- vectAvoidInfo pvs body
        ; let vi = ceVI `unlessVIParrExpr` bodyVI
        
        ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI)
        }
    }
  where
    fvs = freeVarsOf ce
    vectAvoidInfoBnd pvs (var, e) = (var,) <$> vectAvoidInfo pvs e
    -- a binding is parallel if its right-hand side is and its binder is not scalar
    isVIParrBnd (var, eVI)
      = do
        { isScalarTy <- isScalar (varType var)
        ; return $ isVIParr eVI && not isScalarTy
        }
  -- Case expression: parallel if the scrutinee or any alternative is.  With a
  -- parallel scrutinee, all binders of an alternative become parallel unless
  -- every one of them is scalar or top-level.
vectAvoidInfo pvs ce@(_, AnnCase e var ty alts)
  = do
    { ceVI           <- vectAvoidInfoTypeOf ce
    ; eVI            <- vectAvoidInfo pvs e
    ; altsVI         <- mapM (vectAvoidInfoAlt (isVIParr eVI)) alts
    ; let alteVIs = [eVI | (_, _, eVI) <- altsVI]
          vi      =  foldl unlessVIParrExpr ceVI (eVI:alteVIs)  -- scrutinee and all alternatives
    
    ; return ((fvs, vi), AnnCase eVI var ty altsVI)
    }
  where
    fvs = freeVarsOf ce
    vectAvoidInfoAlt scrutIsPar (con, bndrs, e)
      = do
        { allScalar <- allScalarVarType bndrs
        ; let altPvs | scrutIsPar && not allScalar = pvs `extendVarSetList` bndrs
                     | otherwise                   = pvs
        ; (con, bndrs,) <$> vectAvoidInfo altPvs e
        }
  -- A cast takes the information of its body; the coercion's annotation is
  -- converted with 'freeVarsOfAnn' and marked 'VISimple'.
vectAvoidInfo pvs ce@(_, AnnCast e (fvs_ann, ann))
  = do
    { eVI <- vectAvoidInfo pvs e
    ; return ((fvs, vectAvoidInfoOf eVI), AnnCast eVI ((freeVarsOfAnn fvs_ann, VISimple), ann))
    }
  where
    fvs = freeVarsOf ce
vectAvoidInfo pvs ce@(_, AnnTick tick e)
  = do
    { eVI <- vectAvoidInfo pvs e
    ; return ((fvs, vectAvoidInfoOf eVI), AnnTick tick eVI)
    }
  where
    fvs = freeVarsOf ce
  -- types and coercions are always 'VISimple'
vectAvoidInfo _pvs ce@(_, AnnType ty)
  = return ((fvs, VISimple), AnnType ty)
  where
    fvs = freeVarsOf ce
vectAvoidInfo _pvs ce@(_, AnnCoercion coe)
  = return ((fvs, VISimple), AnnCoercion coe)
  where
    fvs = freeVarsOf ce
-- | Vectorisation-avoidance information derived from a type alone.
--
-- * Predicate types are 'VIDict'.
-- * A function type is:
--   - 'VISimple' if argument and result are both 'VISimple';
--   - 'VIDict' if the result is 'VIDict';
--   - otherwise 'VIComplex', unless a parallel argument or result makes it
--     'VIParr'.
-- * Any other type is 'VIParr' if it may be parallel ('maybeParrTy'),
--   'VISimple' if it is scalar, and 'VIComplex' otherwise.
vectAvoidInfoType :: Type -> VM VectAvoidInfo
vectAvoidInfoType ty
  | isPredTy ty
  = return VIDict
  | Just (argTy, resTy) <- splitFunTy_maybe ty
  = do argVI <- vectAvoidInfoType argTy
       resVI <- vectAvoidInfoType resTy
       return (combineFun argVI resVI)
  | otherwise
  = do isParr <- maybeParrTy ty
       if isParr
         then return VIParr
         else do isScalarTy <- isScalar ty
                 return (if isScalarTy then VISimple else VIComplex)
  where
    combineFun VISimple VISimple = VISimple
    combineFun _        VIDict   = VIDict
    combineFun argVI    resVI    = VIComplex `unlessVIParr` argVI `unlessVIParr` resVI
-- | 'vectAvoidInfoType' applied to the type of an annotated expression.
vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo
vectAvoidInfoTypeOf aexpr = vectAvoidInfoType (annExprType aexpr)
-- | Whether a type may involve parallel data.
--
-- * If 'coreView' can expand the type, the expansion is checked with
--   'vectAvoidInfoType'.
-- * A type-constructor application is parallel if the constructor is one of
--   the global parallel type constructors, or if any argument may be parallel.
-- * Foralls are looked through.
-- * Anything else is not parallel.
maybeParrTy :: Type -> VM Bool
maybeParrTy ty = case coreView ty of
    Just expanded -> (== VIParr) <$> vectAvoidInfoType expanded
    Nothing       -> case splitTyConApp_maybe ty of
      Just (tc, tyArgs) -> do
        parallelTcs <- globalParallelTyCons
        if tyConName tc `elemNameSet` parallelTcs
          then return True
          else or <$> mapM maybeParrTy tyArgs
      Nothing -> case ty of
        ForAllTy _ body -> maybeParrTy body
        _               -> return False
-- | Whether every variable is top-level ('isToplevel') or of scalar type.
allScalarVarType :: [Var] -> VM Bool
allScalarVarType vars = do
  flags <- mapM scalarOrTopLevel vars
  return (and flags)
  where
    scalarOrTopLevel var
      | isToplevel var = return True
      | otherwise      = isScalar (varType var)
-- | 'allScalarVarType' applied to the elements of a deterministic variable set.
allScalarVarTypeSet :: DVarSet -> VM Bool
allScalarVarTypeSet fvs = allScalarVarType (dVarSetElems fvs)
-- | Trace the avoidance information @vi@ computed for @ce@.  The message also
-- lists the information of the given annotated subexpressions, each followed
-- by a space, and is followed by the de-annotated expression.
viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM ()
viTrace ce vi vTs
  = traceVt ("vect info: " ++ show vi ++ "[" ++
             concatMap ((++ " ") . show . vectAvoidInfoOf) vTs ++ "]")
            (ppr $ deAnnotate ce)