module Gibbon.Passes.AddTraversals
  (addTraversals, needsTraversalCase) where

import Control.Monad ( forM, when )
import qualified Data.List as L
import Data.Map as M
import Data.Set as S

import Gibbon.Common
import Gibbon.DynFlags
import Gibbon.Passes.InferEffects ( inferExp )
import Gibbon.L1.Syntax hiding (StartOfPkdCursor)
import Gibbon.L2.Syntax as L2

--------------------------------------------------------------------------------

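{- Note [Adding dummy traversals]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This pass inserts explicit ("dummy") traversals of packed values into case
branches. A function is rewritten only if it has parallelism, or if its
declared effects do not traverse all of its input locations.

For each case branch, 'needsTraversalCase' uses InferEffects to find the
packed pattern-bound fields that the branch body does not traverse. If the
body uses any field after the first packed one, 'genTravBinds' adds bindings
in front of the body. These bindings call the traversal functions for the
first (n-1) of those untraversed fields.

TODO: expand this note with a worked example.

-}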

-- | A map between location variables; presumably location dependencies
-- (cf. the dependency argument of 'inferExp') -- TODO confirm.
-- NOTE(review): not referenced anywhere else in this module.
type Deps = M.Map LocVar LocVar

-- | Maps a location variable to the variable naming the region it lives in.
type RegEnv = M.Map LocVar Var

-- | Entry point of the pass.
--
-- Every function definition is rewritten with 'addTraversalsFn', and the
-- optional main expression is rewritten with 'addTraversalsExp'. The main
-- expression starts from empty type and region environments. Data
-- definitions are passed through unchanged.
addTraversals :: Prog2 -> PassM Prog2
addTraversals prg@Prog{ddefs, fundefs, mainExp} = do
  newFuns <- traverse (addTraversalsFn ddefs fundefs) fundefs
  newMain <- traverse rewriteMain mainExp
  pure prg { fundefs = newFuns, mainExp = newMain }
  where
    -- Rewrite the main expression; its type annotation is kept as-is.
    rewriteMain (body, ty) = do
      body' <- addTraversalsExp ddefs fundefs emptyEnv2 M.empty "mainExp" body
      pure (body', ty)

-- | Rewrite one function definition.
--
-- The definition is returned untouched when it has no parallelism AND its
-- declared effects already traverse every input location. Otherwise its body
-- is rewritten with 'addTraversalsExp'. The environments used for that are:
--
-- * argument types from the function's input types;
-- * function types from all definitions;
-- * a location-to-region map from the function's location variables.
addTraversalsFn :: DDefs Ty2 -> FunDefs2 -> FunDef2 -> PassM FunDef2
addTraversalsFn ddefs fundefs f@FunDef{funName, funArgs, funTy, funBody}
  | everyInputTraversed && not (hasParallelism funTy) = pure f
  | otherwise = do
      newBody <- addTraversalsExp ddefs fundefs typeEnv regionEnv (fromVar funName) funBody
      pure f { funBody = newBody }
  where
    -- Locations the function's effects say it traverses.
    traversedLocs = S.map (\(Traverse lv) -> lv) (arrEffs funTy)
    everyInputTraversed =
      S.null (S.fromList (inLocVars funTy) `S.difference` traversedLocs)

    argTypes  = M.fromList (fragileZip funArgs (inTys funTy))
    typeEnv   = Env2 argTypes (initFunEnv fundefs)
    regionEnv = M.fromList [ (lrmLoc lrm, regionToVar (lrmReg lrm)) | lrm <- locVars funTy ]

-- | Generate traversals for the first (n-1) packed elements.
--
-- Walks an expression. For every case branch where 'needsTraversalCase'
-- asks for it, the branch body is prefixed with traversal bindings built by
-- 'genTravBinds'.
--
-- Arguments, in order:
--
-- 1. data definitions;
-- 2. all function definitions;
-- 3. the typing environment;
-- 4. a map from location variables to their region variables;
-- 5. a context string (the enclosing function's name), used only for debug
--    output;
-- 6. the expression to rewrite.
addTraversalsExp :: DDefs Ty2 -> FunDefs2 -> Env2 Ty2 -> RegEnv -> String -> Exp2 -> PassM Exp2
addTraversalsExp ddefs fundefs env2 renv context ex =
  case ex of
    CaseE scrt@(VarE sv) brs -> do
        -- Region of the scrutinee; every location bound by a pattern lives
        -- in it. This is an ordinary lazy let-binding, so the error below
        -- only fires if 'reg' is actually demanded (same as the original
        -- irrefutable pattern, but with a useful message).
        let reg = case lookupVEnv sv env2 of
                    PackedTy _tycon tyloc -> renv # tyloc
                    ty -> error $ "addTraversalsExp: scrutinee " ++ sdoc sv
                                  ++ " does not have a packed type: " ++ sdoc ty
        CaseE scrt <$> mapM (docase reg) brs

    CaseE scrt _ -> error $ "addTraversalsExp: Scrutinee is not flat " ++ sdoc scrt

    -- standard recursion here
    VarE{}    -> return ex
    LitE{}    -> return ex
    CharE{}   -> return ex
    FloatE{}  -> return ex
    LitSymE{} -> return ex
    AppE f locs args -> AppE f locs <$> mapM go args
    PrimAppE f args  -> PrimAppE f <$> mapM go args
    WithArenaE v e ->
      WithArenaE v <$>
        addTraversalsExp ddefs fundefs (extendVEnv v ArenaTy env2) renv context e
    LetE (v,loc,ty,rhs) bod -> do
      rhs' <- go rhs
      bod' <- addTraversalsExp ddefs fundefs (extendVEnv v ty env2) renv context bod
      pure $ LetE (v,loc,ty,rhs') bod'
    IfE a b c  -> IfE <$> go a <*> go b <*> go c
    MkProdE xs -> MkProdE <$> mapM go xs
    ProjE i e  -> ProjE i <$> go e
    DataConE{} -> return ex
    TimeIt e ty b -> do
      e' <- go e
      return $ TimeIt e' ty b
    SpawnE{} -> pure ex -- error "addTraversalsExp: Cannot compile SpawnE"
    SyncE    -> pure ex -- error "addTraversalsExp: Cannot compile SyncE"
    Ext ext ->
      case ext of
        LetRegionE reg sz ty bod    -> Ext . LetRegionE reg sz ty <$> go bod
        LetParRegionE reg sz ty bod -> Ext . LetParRegionE reg sz ty <$> go bod
        L2.StartOfPkdCursor cur     -> pure $ Ext $ L2.StartOfPkdCursor cur
        LetLocE loc FreeLE bod ->
          Ext <$> LetLocE loc FreeLE <$>
            addTraversalsExp ddefs fundefs env2 renv context bod
        LetLocE loc locexp bod ->
          -- Record the region of the newly bound location for later lookups.
          let reg = case locexp of
                      StartOfRegionLE r      -> regionToVar r
                      InRegionLE r           -> regionToVar r
                      AfterConstantLE _ lc   -> renv # lc
                      AfterVariableLE _ lc _ -> renv # lc
                      FromEndLE lc           -> renv # lc -- TODO: This needs to be fixed
          in Ext <$> LetLocE loc locexp <$>
               addTraversalsExp ddefs fundefs env2 (M.insert loc reg renv) context bod
        _ -> return ex
    MapE{}  -> error "addTraversalsExp: TODO MapE"
    FoldE{} -> error "addTraversalsExp: TODO FoldE"

  where
    go :: Exp2 -> PassM Exp2
    go = addTraversalsExp ddefs fundefs env2 renv context

    -- Process one case branch. 'reg' is the region of the scrutinee. The
    -- branch body is always recursed into, so nested case expressions get
    -- the same treatment whether or not this branch needs traversals.
    docase :: Var -> (DataCon, [(Var, LocVar)], Exp2) -> PassM (DataCon, [(Var, LocVar)], Exp2)
    docase reg (dcon,vlocs,rhs) = do
      let (vars,locs) = unzip vlocs
          env21 = extendPatternMatchEnv dcon ddefs vars locs env2
          renv1 = L.foldr (\lc acc -> M.insert lc reg acc) renv locs
          needs_traversal = needsTraversalCase ddefs fundefs env21 (dcon,vlocs,rhs)
      case needs_traversal of
        Nothing ->
          -- No traversals needed here, but nested cases in the body may
          -- still need them.
          (dcon,vlocs,) <$> addTraversalsExp ddefs fundefs env21 renv1 context rhs
        Just ls -> do
          dump_op <- dopt Opt_D_Dump_Repair <$> getDynFlags
          when dump_op $
            dbgTrace 2 ("Adding traversals: " ++ sdoc context) (return ())
          -- Generate traversals: assuming that InferLocs has already generated
          -- the traversal functions, we only use it here.
          trav_binds <- genTravBinds (L.map (\(p_var,_p_loc) -> (VarE p_var, lookupVEnv p_var env21)) ls)
          (dcon,vlocs,) <$> mkLets trav_binds <$>
            addTraversalsExp ddefs fundefs env21 renv1 context rhs


-- | Collect all non-static items that need to be traversed (uses InferEffects).
--
-- If we cannot unpack all the pattern matched variables:
-- (1) Everything after the first packed element should be unused in the RHS
-- (2) Otherwise, we must traverse the first (n-1) packed elements
--
-- For a branch @(dcon, vlocs, rhs)@ the result is 'Nothing' when any of the
-- following holds:
--
--   * @dcon@ is an absolute or relative random-access (RAN) data constructor;
--   * the constructor has no field of packed type;
--   * every field containing packed data is already traversed by @rhs@,
--     according to the effects computed by 'inferExp';
--   * no field after the first packed one occurs in @rhs@.
--
-- Otherwise the result is @Just@ the @(pattern variable, location)@ pairs to
-- traverse: all but the last of the untraversed packed fields, in field
-- order. When exactly one field is untraversed, this is @Just []@.
needsTraversalCase :: DDefs Ty2 -> FunDefs2 -> Env2 Ty2 -> (DataCon, [(Var, LocVar)], Exp2) -> Maybe [(Var, LocVar)]
needsTraversalCase ddefs fundefs env2 (dcon,vlocs,rhs)
  | isAbsRANDataCon dcon || isRelRANDataCon dcon = Nothing
  | otherwise =
      case L.findIndex isPackedTy tys of
        Nothing -> Nothing
        Just i
          | L.null not_traversed -> Nothing
            -- If the problematic elements are unused, we don't need to add
            -- traversals.
          | not (occurs (should_be_unused i) rhs) -> Nothing
          | otherwise -> Just trav
  where
    vars   = L.map fst vlocs
    tys    = lookupDataCon ddefs dcon
    tyenv  = M.fromList (zip vars tys)
    funenv = M.map funTy fundefs
    dps    = makeDps (reverse (L.map snd vlocs))
    (eff,_) = inferExp ddefs funenv (M.union tyenv (vEnv env2)) dps rhs

    -- Note: Do not use Data.Map.filter. It changes the order sometimes.
    packedOnly = L.filter (\(_,b) -> hasPacked b) (zip vars tys)

    -- NOTE(review): each step pairs two elements and then skips both
    -- (v0->v1, v2->v3, ...); it does not build a chain (v0->v1, v1->v2, ...).
    -- Confirm this is what 'inferExp' expects for its dependency map.
    makeDps []        = M.empty
    makeDps [_]       = M.empty
    makeDps (v:v':vs) = M.insert v v' (makeDps vs)

    effToLoc (Traverse loc_var) = loc_var

    -- Locations of all non-static fields which the RHS does not traverse,
    -- in field order. 'L.\\' is used instead of Data.Set to preserve that
    -- order. The empty case is kept separate so that 'eff' is not demanded
    -- when there are no packed fields.
    not_traversed
      | L.null packedOnly = []
      | otherwise =
          let locenv     = M.fromList vlocs
              packedlocs = [ Traverse (locenv # a) | (a,_) <- packedOnly ]
          in L.map effToLoc (packedlocs L.\\ S.toList eff)

    -- Pattern variables strictly after the first packed field (which is at
    -- 0-based index i).
    should_be_unused i = S.fromList (L.drop (i+1) vars)

    -- LocVar -> Var
    loc_var_mp = M.fromList [ (b,a) | (a,b) <- vlocs ]

    ls = [ (loc_var_mp # a, a) | a <- not_traversed ]

    -- POLICY: traverse the first (n-1) untraversed packed elements. When
    -- n == 1 nothing is traversed, and the result is @Just []@.
    -- [2020.05.01]: CSK, I haven't thought about this change too much.
    -- Maybe we need to revisit this.
    trav = L.take (length ls - 1) ls

-- | Build let-bindings that traverse each given expression of packed type.
--
-- * A 'PackedTy' value gets a unit-typed binding that calls the traversal
--   function named by 'mkTravFunName' for its type constructor. That function
--   is assumed to have been generated by an earlier pass.
-- * A 'ProdTy' value is first bound to a fresh temporary. Every field that
--   contains packed data (including nested products) is then projected out
--   and traversed recursively.
-- * Any other type needs no traversal.
genTravBinds :: [(Exp2, Ty2)] -> PassM [(Var, [LocVar], Ty2, Exp2)]
genTravBinds ls = concat <$> forM ls travOne
  where
    travOne (e,ty) =
      case ty of
        PackedTy tycon loc1 -> do
          w <- gensym "trav"
          let fn_name = mkTravFunName tycon
          return [(w,[],ProdTy [], AppE fn_name [loc1] [e])]
        -- TODO: Write a testcase for this path.
        ProdTy tys -> do
          -- So that we don't have to make assumptions about the 'e' being a VarE
          tmp <- gensym "tmp_trav"
          -- Use 'hasPacked' (not 'isPackedTy') so that nested products holding
          -- packed data reach the recursive ProdTy case. Fields are collected
          -- in reverse index order, as before.
          let fields = L.foldl' (\acc (ty1,idx) ->
                                   if hasPacked ty1
                                   then (mkProj idx (VarE tmp), ty1) : acc
                                   else acc)
                                []
                                (zip tys [0..])
          proj_binds <- genTravBinds fields
          return $ (tmp,[],ty,e) : proj_binds
        _ -> return []