module Gibbon.Passes.RearrangeFree
  ( rearrangeFree ) where

import Prelude hiding (tail)
import Gibbon.DynFlags
import Gibbon.Common
import Gibbon.L4.Syntax

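-- Ensure that any calls to `free` are the last thing in each function body
-- and in the main expression: every `FreeBuffer` call is moved to just before
-- the point where the tail returns its values.
-- TODO: We should figure out a way to do this in Lower. Also `withTail`
-- _does_ end up duplicating some calls to `free`. Need to fix that.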
-- | Run 'rearrangeFreeExp' over every function body and over the main
-- expression, if the program has one. Function bodies are rewritten by
-- 'rearrangeFreeFn'. The main expression is the only tail rewritten with
-- @is_main = True@. Both start with no pending frees.
rearrangeFree :: Prog -> PassM Prog
rearrangeFree (Prog info_tbl sym_tbl fundefs mainExp) = do
  newFuns <- mapM rearrangeFreeFn fundefs
  newMain <- traverse rewriteMain mainExp
  pure (Prog info_tbl sym_tbl newFuns newMain)
  where
    -- Rewrite the tail printed by the main expression.
    rewriteMain :: MainExp -> PassM MainExp
    rewriteMain me = case me of
      PrintExp body -> PrintExp <$> rearrangeFreeExp True Nothing body

-- | Run 'rearrangeFreeExp' over a single function's body. There are no
-- pending frees at function entry. A function body is never the main
-- expression, hence @is_main = False@. Every other field of the
-- declaration is kept as is.
rearrangeFreeFn :: FunDecl -> PassM FunDecl
rearrangeFreeFn fd@FunDecl{funBody = body} =
  (\body' -> fd { funBody = body' }) <$> rearrangeFreeExp False Nothing body

-- | Rewrite a tail so that every 'FreeBuffer' call is deferred until the tail
-- returns.
--
-- * @is_main@: 'True' only for the main expression. When it is set and
--   region counting is enabled ('Opt_CountAllRegions' or
--   'Opt_CountParRegions'), a 'PrintRegionCount' call is inserted before the
--   returned 'RetValsT'.
-- * @frees@: the 'FreeBuffer' calls collected so far on the current path.
--   They are stored as a function that prefixes them, in their original
--   order, onto a tail.
--
-- Collected frees are re-emitted in front of every 'RetValsT' reached on the
-- current path. On the other terminators ('EndOfMain', 'Goto', 'AssnValsT',
-- 'ErrT', 'TailCall') the tail is returned unchanged, so pending frees are
-- not emitted there.
-- NOTE(review): for 'TailCall' this means pending frees never run on that
-- path; confirm that is intended.
rearrangeFreeExp :: Bool -> Maybe (Tail -> Tail) -> Tail -> PassM Tail
rearrangeFreeExp is_main frees tl =
  case tl of
    EndOfMain -> return tl

    LetPrimCallT binds prim rands bod ->
      case prim of
        -- Remove the free from its current position and remember it.
        FreeBuffer ->
          let freeCall = LetPrimCallT binds prim rands
              -- Earlier frees stay outermost, so the source order is kept.
              frees'   = maybe freeCall (. freeCall) frees
          in rearrangeFreeExp is_main (Just frees') bod
        _ -> LetPrimCallT binds prim rands <$> go bod

    RetValsT _ -> do
      dflags <- getDynFlags
      let countRegions = gopt Opt_CountAllRegions dflags
                         || gopt Opt_CountParRegions dflags
          print_reg_count
            | is_main && countRegions = LetPrimCallT [] PrintRegionCount []
            | otherwise               = id
          -- No pending frees means nothing to prepend.
          emitFrees = maybe id id frees
      return $ emitFrees (print_reg_count tl)

    -- Each alternative, including the optional default branch, is a separate
    -- path to a return. Each one therefore gets the pending frees.
    Switch lbl trv alts bod_maybe -> do
      alts' <- case alts of
                 TagAlts ls -> TagAlts <$> mapM (\(x, alt) -> (,) x <$> go alt) ls
                 IntAlts ls -> IntAlts <$> mapM (\(x, alt) -> (,) x <$> go alt) ls
      bod_maybe' <- traverse go bod_maybe
      return $ Switch lbl trv alts' bod_maybe'

    Goto{}      -> return tl
    AssnValsT{} -> return tl

    LetCallT async binds rator rands bod ->
      LetCallT async binds rator rands <$> go bod

    LetTrivT bnd bod -> LetTrivT bnd <$> go bod

    -- The branches compute the values bound by 'binds', and control then
    -- continues into 'bod'. Pending frees (and the region-count print)
    -- therefore belong at the end of 'bod' only. Passing them into the
    -- branches would emit them twice on every path. Frees that occur inside a
    -- branch are still moved to the end of that branch.
    LetIfT binds (trv, tl1, tl2) bod -> do
      tl1' <- rearrangeSub tl1
      tl2' <- rearrangeSub tl2
      LetIfT binds (trv, tl1', tl2') <$> go bod

    LetUnpackT binds ptr bod -> LetUnpackT binds ptr <$> go bod
    LetAllocT lhs vals bod   -> LetAllocT lhs vals <$> go bod
    LetAvailT vs bod         -> LetAvailT vs <$> go bod

    IfT tst con els -> IfT tst <$> go con <*> go els

    ErrT{} -> return tl

    -- Only the continuation is rewritten; the timed sub-tail is left as is.
    LetTimedT isIter binds timed bod ->
      LetTimedT isIter binds timed <$> go bod

    LetArenaT v bod -> LetArenaT v <$> go bod

    TailCall{} -> return tl

  where
    -- Continue along the same path, keeping the pending frees.
    go :: Tail -> PassM Tail
    go = rearrangeFreeExp is_main frees

    -- Rewrite a nested, value-producing sub-tail in isolation.
    rearrangeSub :: Tail -> PassM Tail
    rearrangeSub = rearrangeFreeExp False Nothing