module MkGraph
  ( CmmAGraph, CmmAGraphScoped, CgStmt(..)
  , (<*>), catAGraphs
  , mkLabel, mkMiddle, mkLast, outOfLine
  , lgraphOfAGraph, labelAGraph
  , stackStubExpr
  , mkNop, mkAssign, mkStore
  , mkUnsafeCall, mkFinalCall, mkCallReturnsTo
  , mkJumpReturnsTo
  , mkJump, mkJumpExtra
  , mkRawJump
  , mkCbranch, mkSwitch
  , mkReturn, mkComment, mkCallEntry, mkBranch
  , mkUnwind
  , copyInOflow, copyOutOflow
  , noExtraStack
  , toCall, Transfer(..)
  )
where
import GhcPrelude (($),Int,Bool,Eq(..)) 
import BlockId
import Cmm
import CmmCallConv
import CmmSwitch (SwitchTargets)
import Hoopl.Block
import Hoopl.Graph
import Hoopl.Label
import DynFlags
import FastString
import ForeignCall
import OrdList
import SMRep (ByteOff)
import UniqSupply
import Control.Monad
import Data.List
import Data.Maybe
#include "HsVersions.h"
-- | An abstract graph: an ordered list of 'CgStmt's that has not yet been
-- split into basic blocks.
type CmmAGraph = OrdList CgStmt
-- | An abstract graph paired with a tick scope.
type CmmAGraphScoped = (CmmAGraph, CmmTickScope)
-- | One element of an abstract graph ('CmmAGraph').
data CgStmt
  = CgLabel BlockId CmmTickScope           -- ^ a label together with a tick scope
  | CgStmt  (CmmNode O O)                  -- ^ an open/open (middle) node
  | CgLast  (CmmNode O C)                  -- ^ an open/closed (last) node
  | CgFork  BlockId CmmAGraph CmmTickScope -- ^ a nested graph with its own label and tick scope
-- | Turn an abstract graph into a 'CmmGraph' whose entry label is @id@.
--
-- The statement list is cut into closed/closed blocks: the first block is
-- headed by a 'CmmEntry' for @id@, a 'CgLast' ends the current block, and a
-- 'CgLabel' starts a new one.  Each 'CgFork' is flattened into blocks of its
-- own.
flattenCmmAGraph :: BlockId -> CmmAGraphScoped -> CmmGraph
flattenCmmAGraph id (stmts_t, tscope) =
    CmmGraph { g_entry = id,
               g_graph = GMany NothingO body NothingO }
  where
  body = foldr addBlock emptyBody $ flatten id stmts_t tscope []
  --
  -- flatten: given an entry label, a graph and a tick scope, open a fresh
  -- block headed by 'CmmEntry' and hand the statements to flatten1.
  -- The 'blocks' argument is the list of blocks that follows whatever this
  -- call produces.  The bang pattern forces the initial block before
  -- recursing.
  --
  flatten :: Label -> CmmAGraph -> CmmTickScope -> [Block CmmNode C C]
          -> [Block CmmNode C C]
  flatten id g tscope blocks
      = flatten1 (fromOL g) block' blocks
      where !block' = blockJoinHead (CmmEntry id tscope) emptyBlock
  --
  -- flatten0: there is no current block (the previous one was closed by a
  -- CgLast).  A CgLabel opens a new block, a CgFork is flattened on its own,
  -- and any CgStmt / CgLast seen here is dropped since it has no block.
  --
  flatten0 :: [CgStmt] -> [Block CmmNode C C] -> [Block CmmNode C C]
  flatten0 [] blocks = blocks
  flatten0 (CgLabel id tscope : stmts) blocks
    = flatten1 stmts block blocks
    where !block = blockJoinHead (CmmEntry id tscope) emptyBlock
  flatten0 (CgFork fork_id stmts_t tscope : rest) blocks
    = flatten fork_id stmts_t tscope $ flatten0 rest blocks
  flatten0 (CgLast _ : stmts) blocks = flatten0 stmts blocks
  flatten0 (CgStmt _ : stmts) blocks = flatten0 stmts blocks
  --
  -- flatten1: there is a partial (closed/open) block.  Statements are
  -- appended to it until a last node closes it; flatten0 then handles the
  -- rest.
  --
  flatten1 :: [CgStmt] -> Block CmmNode C O
           -> [Block CmmNode C C] -> [Block CmmNode C C]
  --
  -- Statements ran out while the block is still open: close it with a
  -- branch to its own entry label so that it has a terminator.
  -- NOTE(review): presumably such a block is unreachable and the self-loop
  -- is only a placeholder -- confirm.
  --
  flatten1 [] block blocks
    = blockJoinTail block (CmmBranch (entryLabel block)) : blocks
  flatten1 (CgLast stmt : stmts) block blocks
    = block' : flatten0 stmts blocks
    where !block' = blockJoinTail block stmt
  flatten1 (CgStmt stmt : stmts) block blocks
    = flatten1 stmts block' blocks
    where !block' = blockSnoc block stmt
  flatten1 (CgFork fork_id stmts_t tscope : rest) block blocks
    = flatten fork_id stmts_t tscope $ flatten1 rest block blocks
  --
  -- A label inside an open block: close the current block with a branch to
  -- the label (fall-through) and open a new block at that label.
  flatten1 (CgLabel id tscp : stmts) block blocks
    = blockJoinTail block (CmmBranch id) :
      flatten1 stmts (blockJoinHead (CmmEntry id tscp) emptyBlock) blocks
-- | Sequence two abstract graphs: the statements of the first are followed
-- by the statements of the second.
(<*>) :: CmmAGraph -> CmmAGraph -> CmmAGraph
first <*> second = first `appOL` second
-- | Concatenate a list of abstract graphs, in order.
catAGraphs :: [CmmAGraph] -> CmmAGraph
catAGraphs graphs = concatOL graphs
-- | A graph consisting of a single label with the given tick scope.
mkLabel :: BlockId -> CmmTickScope -> CmmAGraph
mkLabel lbl scope = unitOL $ CgLabel lbl scope
-- | A graph consisting of a single open/open node.
mkMiddle :: CmmNode O O -> CmmAGraph
mkMiddle node = unitOL $ CgStmt node
-- | A graph consisting of a single open/closed node.
mkLast :: CmmNode O C -> CmmAGraph
mkLast node = unitOL $ CgLast node
-- | Wrap a scoped graph as a single 'CgFork' statement under the given
-- label.
outOfLine :: BlockId -> CmmAGraphScoped -> CmmAGraph
outOfLine lbl (graph, scope) = unitOL $ CgFork lbl graph scope
-- | Flatten an abstract graph into a 'CmmGraph', using a fresh unique to
-- make the entry label.
lgraphOfAGraph :: CmmAGraphScoped -> UniqSM CmmGraph
lgraphOfAGraph g =
  getUniqueM >>= \u -> return (labelAGraph (mkBlockId u) g)
-- | Flatten an abstract graph into a 'CmmGraph' whose entry block carries
-- the supplied label.
labelAGraph :: BlockId -> CmmAGraphScoped -> CmmGraph
labelAGraph = flattenCmmAGraph
-- | The empty graph: contains no statements.
mkNop        :: CmmAGraph
mkNop         = nilOL
-- | A comment node.  Only emitted when compiled with @DEBUG@ defined;
-- otherwise the result is the empty graph.
mkComment :: FastString -> CmmAGraph
#if defined(DEBUG)
mkComment text = mkMiddle (CmmComment text)
#else
mkComment _ = mkNop
#endif
-- | Assign an expression to a register.  Assigning a register to itself
-- yields the empty graph instead of a node.
mkAssign :: CmmReg -> CmmExpr -> CmmAGraph
mkAssign dst rhs =
  case rhs of
    CmmReg src | dst == src -> mkNop
    _                       -> mkMiddle (CmmAssign dst rhs)
-- | A 'CmmStore' node built from the two expressions, in the order given.
mkStore :: CmmExpr -> CmmExpr -> CmmAGraph
mkStore addr val = mkMiddle (CmmStore addr val)
-- | Jump to @e@ with the given actuals, using the 'Old' area and the
-- 'Jump' transfer.  The call node has no continuation and a result-space
-- argument of 0.
mkJump :: DynFlags -> Convention -> CmmExpr
       -> [CmmExpr]
       -> UpdFrameOffset
       -> CmmAGraph
mkJump dflags conv e actuals updfr_off =
  lastWithArgs dflags Jump Old conv actuals updfr_off jumpNode
  where
    -- Awaits the outgoing argument space and the live registers.
    jumpNode = toCall e Nothing updfr_off 0
-- | Jump to @e@ passing no arguments.  The register list computed by the
-- argument-copying code is discarded; the call node records the
-- caller-supplied @vols@ instead.
mkRawJump :: DynFlags -> CmmExpr -> UpdFrameOffset -> [GlobalReg]
          -> CmmAGraph
mkRawJump dflags e updfr_off vols =
  lastWithArgs dflags Jump Old NativeNodeCall [] updfr_off rawJump
  where
    rawJump arg_space _computed_regs =
      toCall e Nothing updfr_off 0 arg_space vols
-- | Like 'mkJump', but additionally passes @extra_stack@ expressions on to
-- 'lastWithArgsAndExtraStack'.
mkJumpExtra :: DynFlags -> Convention -> CmmExpr -> [CmmExpr]
            -> UpdFrameOffset -> [CmmExpr]
            -> CmmAGraph
mkJumpExtra dflags conv e actuals updfr_off extra_stack =
  lastWithArgsAndExtraStack dflags Jump Old conv actuals updfr_off
                            extra_stack jumpNode
  where
    jumpNode = toCall e Nothing updfr_off 0
-- | A 'CmmCondBranch' node built from the condition, the two target
-- blocks and the optional 'Bool', passed through in that order.
mkCbranch :: CmmExpr -> BlockId -> BlockId -> Maybe Bool -> CmmAGraph
mkCbranch cond ifso ifnot likely =
  mkLast $ CmmCondBranch cond ifso ifnot likely
-- | A 'CmmSwitch' node on the expression with the given targets.
mkSwitch :: CmmExpr -> SwitchTargets -> CmmAGraph
mkSwitch scrut targets = mkLast (CmmSwitch scrut targets)
-- | Transfer to @e@ with the given actuals using the 'Ret' transfer, the
-- 'NativeReturn' convention and the 'Old' area.
mkReturn :: DynFlags -> CmmExpr -> [CmmExpr] -> UpdFrameOffset
         -> CmmAGraph
mkReturn dflags e actuals updfr_off =
  lastWithArgs dflags Ret Old NativeReturn actuals updfr_off retNode
  where
    retNode = toCall e Nothing updfr_off 0
-- | An unconditional branch to the given block.
mkBranch :: BlockId -> CmmAGraph
mkBranch target = mkLast $ CmmBranch target
-- | A call to @f@ with no continuation, using the 'Call' transfer and the
-- 'Old' area.  The 'CCallConv' argument is ignored: 'NativeDirectCall' is
-- always used.
mkFinalCall :: DynFlags
            -> CmmExpr -> CCallConv -> [CmmExpr] -> UpdFrameOffset
            -> CmmAGraph
mkFinalCall dflags f _ actuals updfr_off =
  lastWithArgs dflags Call Old NativeDirectCall actuals updfr_off callNode
  where
    callNode = toCall f Nothing updfr_off 0
-- | A call to @f@ whose continuation is @ret_lbl@.  Arguments are copied
-- into the 'Young' area of @ret_lbl@ with the 'Call' transfer; @ret_off@
-- is passed to 'toCall' as the result space.
mkCallReturnsTo :: DynFlags -> CmmExpr -> Convention -> [CmmExpr]
                -> BlockId
                -> ByteOff
                -> UpdFrameOffset
                -> [CmmExpr]
                -> CmmAGraph
mkCallReturnsTo dflags f callConv actuals ret_lbl ret_off updfr_off extra_stack =
  lastWithArgsAndExtraStack dflags Call (Young ret_lbl) callConv actuals
                            updfr_off extra_stack callNode
  where
    callNode = toCall f (Just ret_lbl) updfr_off ret_off
-- | Like 'mkCallReturnsTo' but uses the 'JumpRet' transfer and takes no
-- extra stack arguments.
mkJumpReturnsTo :: DynFlags -> CmmExpr -> Convention -> [CmmExpr]
                -> BlockId
                -> ByteOff
                -> UpdFrameOffset
                -> CmmAGraph
mkJumpReturnsTo dflags f callConv actuals ret_lbl ret_off updfr_off =
  lastWithArgs dflags JumpRet (Young ret_lbl) callConv actuals updfr_off
               jumpNode
  where
    jumpNode = toCall f (Just ret_lbl) updfr_off ret_off
-- | A 'CmmUnsafeForeignCall' node for the target, with the given formals
-- and actuals.
mkUnsafeCall :: ForeignTarget -> [CmmFormal] -> [CmmActual] -> CmmAGraph
mkUnsafeCall target formals actuals =
  mkMiddle (CmmUnsafeForeignCall target formals actuals)
-- | A 'CmmUnwind' node with a single entry pairing the register with
-- @Just@ the expression.
mkUnwind :: GlobalReg -> CmmExpr -> CmmAGraph
mkUnwind reg expr = mkMiddle (CmmUnwind [(reg, Just expr)])
-- | The integer literal 0 at the given width.
stackStubExpr :: Width -> CmmExpr
stackStubExpr width = CmmLit $ CmmInt 0 width
-- | Wrapper around 'copyIn' that returns the generated nodes as a
-- 'CmmAGraph' rather than a list; the offset and register list are passed
-- through unchanged.
copyInOflow :: DynFlags -> Convention -> Area
            -> [CmmFormal]
            -> [CmmFormal]
            -> (Int, [GlobalReg], CmmAGraph)
copyInOflow dflags conv area formals extra_stk =
  let (offset, gregs, nodes) = copyIn dflags conv area formals extra_stk
  in  (offset, gregs, catAGraphs (map mkMiddle nodes))
-- | Build the nodes that move incoming parameters into local registers.
--
-- Returns the stack size reported by 'assignArgumentsPos', the global
-- registers that carry register-passed formals, and one 'CmmAssign' per
-- parameter: those for @extra_stk@ first, then those for @formals@.
copyIn :: DynFlags -> Convention -> Area
       -> [CmmFormal]
       -> [CmmFormal]
       -> (ByteOff, [GlobalReg], [CmmNode O O])
copyIn dflags conv area formals extra_stk
  = (stk_size, param_regs, map load (stk_args ++ args))
  where
     -- Global registers used by the register-passed formals.
     param_regs = [r | (_, RegisterParam r) <- args]

     -- Copy one parameter from its assigned location into its local register.
     load (reg, loc) = case loc of
        RegisterParam r ->
          CmmAssign dst (CmmReg (CmmGlobal r))
        StackParam off ->
          CmmAssign dst (CmmLoad (CmmStackSlot area off) (localRegType reg))
       where dst = CmmLocal reg

     -- Stack assignment starts one machine word in.
     -- NOTE(review): presumably this word is reserved for the return
     -- address / info table pointer -- confirm.
     init_offset = widthInBytes (wordWidth dflags)

     -- Extra stack parameters are laid out first ...
     (stk_off, stk_args) = assignStack dflags init_offset localRegType extra_stk

     -- ... then the ordinary formals, starting where those ended.
     (stk_size, args) = assignArgumentsPos dflags stk_off conv
                                           localRegType formals
-- | The kind of control transfer being generated.  'copyOutOflow' inspects
-- it to decide whether a return address is stored and how much initial
-- stack space is reserved.
data Transfer = Call | JumpRet | Jump | Ret deriving Eq
-- | Build the code that places outgoing arguments in registers and stack
-- slots of @area@.
--
-- Returns the stack size reported by 'assignArgumentsPos', the global
-- registers that were assigned, and the graph of assignments and stores.
copyOutOflow :: DynFlags -> Convention -> Transfer -> Area -> [CmmExpr]
             -> UpdFrameOffset
             -> [CmmExpr] -- extra stack arguments
             -> (Int, [GlobalReg], CmmAGraph)
copyOutOflow dflags conv transfer area actuals updfr_off extra_stack_stuff
  = (stk_size, regs, graph)
  where
    -- Fold right-to-left so the emitted code follows list order: return
    -- address store (if any), then the actuals, then the extra stack items.
    (regs, graph) = foldr co ([], mkNop) (setRA ++ args ++ stack_params)
    -- Register parameter: record the register and assign to it.
    co (v, RegisterParam r) (rs, ms)
       = (r:rs, mkAssign (CmmGlobal r) v <*> ms)
    -- Stack parameter: store into the slot at that offset of the area.
    co (v, StackParam off)  (rs, ms)
       = (rs, mkStore (CmmStackSlot area off) v <*> ms)
    -- setRA is the optional return-address store; init_offset is where
    -- stack assignment begins.  The Call case mentions init_offset inside
    -- the very tuple that defines it, so this depends on the pattern
    -- binding being lazy.
    (setRA, init_offset) =
      case area of
            Young id ->  -- Young area: behaviour depends on the kind of
                         -- transfer being made.
                  case transfer of
                     -- Store the block label as return address and reserve
                     -- one word for it.
                     Call ->
                       ([(CmmLit (CmmBlock id), StackParam init_offset)],
                       widthInBytes (wordWidth dflags))
                     -- Reserve one word but store nothing.
                     JumpRet ->
                       ([],
                       widthInBytes (wordWidth dflags))
                     _other ->
                       ([], 0)
            -- Old area: start at the update frame offset.
            Old -> ([], updfr_off)
    -- Extra stack items are laid out first, from init_offset.
    (extra_stack_off, stack_params) =
       assignStack dflags init_offset (cmmExprType dflags) extra_stack_stuff
    args :: [(CmmExpr, ParamLocation)]   -- each actual with its location
    (stk_size, args) = assignArgumentsPos dflags extra_stack_off conv
                                          (cmmExprType dflags) actuals
-- | 'copyInOflow' specialised to the 'Old' area.
mkCallEntry :: DynFlags -> Convention -> [CmmFormal] -> [CmmFormal]
            -> (Int, [GlobalReg], CmmAGraph)
mkCallEntry dflags conv = copyInOflow dflags conv Old
-- | 'lastWithArgsAndExtraStack' with no extra stack arguments
-- ('noExtraStack').
lastWithArgs :: DynFlags -> Transfer -> Area -> Convention -> [CmmExpr]
             -> UpdFrameOffset
             -> (ByteOff -> [GlobalReg] -> CmmAGraph)
             -> CmmAGraph
lastWithArgs dflags transfer area conv actuals updfr_off =
  lastWithArgsAndExtraStack dflags transfer area conv actuals updfr_off
                            noExtraStack
-- | Emit the argument-copying code from 'copyOutOflow', followed by the
-- graph produced by the supplied function.  That function receives the
-- outgoing argument size and the list of registers that 'copyOutOflow'
-- returned.
lastWithArgsAndExtraStack :: DynFlags
             -> Transfer -> Area -> Convention -> [CmmExpr]
             -> UpdFrameOffset -> [CmmExpr]
             -> (ByteOff -> [GlobalReg] -> CmmAGraph)
             -> CmmAGraph
lastWithArgsAndExtraStack dflags transfer area conv actuals updfr_off
                          extra_stack mkTransfer =
  let (outArgs, regs, copies) =
        copyOutOflow dflags conv transfer area actuals updfr_off extra_stack
  in  copies <*> mkTransfer outArgs regs
-- | The empty list of extra stack arguments; 'lastWithArgs' passes it to
-- 'lastWithArgsAndExtraStack'.
noExtraStack :: [CmmExpr]
noExtraStack = []
-- | Build a 'CmmCall' last node.
--
-- The parameter order differs from the constructor's argument order: the
-- node is built as @CmmCall e cont regs arg_space res_space updfr_off@.
toCall :: CmmExpr -> Maybe BlockId -> UpdFrameOffset -> ByteOff
       -> ByteOff -> [GlobalReg]
       -> CmmAGraph
toCall e cont updfr_off res_space arg_space regs =
  mkLast (CmmCall e cont regs arg_space res_space updfr_off)