module Gibbon.Passes.Unariser
  (unariser, unariserExp) where

import qualified Data.Map as M
import qualified Data.List as L

import Gibbon.Common
import Gibbon.L1.Syntax
import Gibbon.L3.Syntax
import Gibbon.Passes.Flatten()


-- | This pass gets ready for Lower by converting most uses of
-- projection and tuple-construction into finer-grained bindings.
--
-- OUTPUT INVARIANTS:
--
-- (1) only flat tuples as function arguments (no nesting), all
-- arguments immediately present, e.g. `AppE "f" (MkProdE [x,y,z])`
-- rather than `AppE "f" (MkProdE [x,MkProdE [y,z]])`
--
-- (2) The only MkProdE allowed outside of function operands is within
-- return/tail position (of a function or If branch).
--
-- (3) Primitives are allowed to return tuples, but are let-bound
-- (these will turn into LetPrimCall).  The references to these tuples
-- are all of the form `ProjE i (VarE v)` and they are then
-- transformed to varrefs in lower.
--
-- [Aditya Gupta, Oct 2021]
-- NOTE: Flattening is limited to intermediate expressions: the tail value of
-- the main expression is a terminal expression and is not flattened.
-- Terminality is propagated recursively according to the kind of expression.
-- This way all intermediate expressions benefit from flattening, while
-- terminal expressions keep the same output as before.
-- Alternatively, terminal values could be recovered by a separate function
-- after unarising, but that function would not have the env2/ddefs values and
-- could not be fused into the unariser's cases. A separate function would make
-- it harder to miss a case, but there are only a few such cases, so recovering
-- terminal expressions inside the unariser seems best.

-- | Entry point of the pass: unarise every function definition and the main
-- expression of a program.
--
-- The main expression is unarised as a /terminal/ expression (its result keeps
-- its original, possibly nested, tuple shape and its type is left unchanged).
-- Function bodies are unarised as intermediate expressions, and each function's
-- argument and return types are flattened with 'flattenTy'.
unariser :: Prog3 -> PassM Prog3
unariser Prog{ddefs,fundefs,mainExp} = do
  mn <- case mainExp of
          -- type should remain same and main output is a terminal expression
          Just (m,t) -> do
            m' <- unariserExp True ddefs [] (Env2 M.empty funEnv) m
            return $ Just (m', t)
          Nothing -> return Nothing
  fds' <- mapM unariserFun fundefs
  return $ Prog ddefs fds' mn

  where
    -- Types of all top-level functions; used as the function part of the
    -- type environment while unarising.
    funEnv = M.map funTy fundefs

    -- Modifies function to satisfy output invariant (1): argument and return
    -- types are flattened, and the body is unarised as an intermediate
    -- expression.
    unariserFun :: FunDef3 -> PassM FunDef3
    unariserFun f@FunDef{funTy,funArgs,funBody} = do
      let in_tys  = inTys funTy
          in_tys' = map flattenTy in_tys
          out_ty' = flattenTy (outTy funTy)
          -- The environment records the *un-flattened* argument types.
          -- Projections on tuple-typed arguments are re-indexed by the
          -- ProjE case of unariserExp (via flatProjIdx), so the body itself
          -- is passed through unchanged. (An earlier version also built a
          -- flattenExp-substituted body here, but it was never used.)
          env2 = Env2 (M.fromList $ zip funArgs in_tys) funEnv
      -- all function bodies are intermediate expressions
      bod <- unariserExp False ddefs [] env2 funBody
      return $ f { funTy = (in_tys', out_ty'), funBody = bod }


-- | A projection stack can be viewed as a list of ProjE operations to
-- perform, from left to right.
type ProjStack = [Int]

-- | Unarise one expression (see the module header for the output invariants).
--
-- Arguments:
--
-- * @isTerminal@ -- 'True' when the expression is in terminal (result)
--   position. Terminal tuples are rebuilt in their original nested shape by
--   @recover@ / @recover0@; intermediate tuples are flattened.
-- * @ddfs@ -- data definitions, used for type recovery and for the field
--   types of data constructors bound in case patterns.
-- * @stk@ -- pending projections, reified by @discharge@.
-- * @env2@ -- type environment; variables are recorded with their
--   un-flattened types (see the 'LetE' and 'CaseE' cases).
--
-- Value-producing sub-positions (let right-hand sides, call/primitive/
-- constructor arguments, if conditions, case scrutinees) are always treated
-- as intermediate. Let bodies, if/case branches and the bodies of TimeIt,
-- WithArenaE, LetAvail and RetE inherit @isTerminal@.
unariserExp :: Bool -> DDefs Ty3 -> ProjStack -> Env2 Ty3 -> Exp3 -> PassM Exp3
unariserExp isTerminal ddfs stk env2 ex =
  case ex of
    LetE (v,locs,ty,rhs) bod ->
      LetE . (v,locs, flattenTy ty,)
        <$> go False env2 rhs
        <*> go isTerminal (extendVEnv v ty env2) bod

    MkProdE es ->
      -- if terminal, don't flatten product
      if isTerminal
        then pure $ recover0 ex
        else
          case stk of
            [] -> flattenProd ddfs stk env2 ex
            (ix:s') -> unariserExp False ddfs s' env2 (es ! ix)

    -- When projecting a value out of a nested tuple, we have to update the index
    -- to match the flattened representation. And if the ith projection was a
    -- product before, we have to reconstruct it here, since it will be flattened
    -- after this pass.
    --
    -- if it's a terminal expression, then ith projection should be a terminal,
    -- we can reuse reconstruction logic.
    ProjE i e ->
      case e of
        MkProdE ls -> go isTerminal env2 (ls ! i)
        _ -> do
          let ety = gRecoverType ddfs env2 e -- type before flattening
              j   = flatProjIdx i ety        -- index in flattened type
              ity = projTy i ety             -- ith index type before flattening
              fty = flattenTy ity            -- jth index type in flattened
          -- recursively unarise; since this is an intermediate value, we can flatten it
          e' <- go False env2 e
          case e' of
            -- If we get a (flat) product after flattening, select the
            -- component's elements directly. A tuple-typed component spans
            -- the flattened positions j .. j+length tys-1, mirroring the
            -- ProjE-based reconstruction below.
            MkProdE xs ->
              case fty of
                ProdTy tys -> return $ MkProdE (take (length tys) (drop j xs))
                _          -> return (xs ! j)
            _ ->
              -- otherwise, check jth element
              case fty of
                -- reconstruct, in case of nested tuple (whether terminal or not)
                ProdTy tys ->
                  return $ MkProdE (map (\k -> ProjE k e') [j .. j + length tys - 1])
                -- if not a tuple, take projection
                _ -> return $ ProjE j e'

    -- Straightforward recursion
    VarE{} -> return . (if isTerminal then recover0 else id) $ discharge stk ex

    -- Literals cannot be projected from.
    LitE{}    -> atom "LitE"
    CharE{}   -> atom "CharE"
    FloatE{}  -> atom "FloatE"
    LitSymE{} -> atom "LitSymE"

    -- For function output, we need to recover the application onto the arguments
    AppE v locs args -> do
      exp0 <- discharge stk <$> (AppE v locs <$> mapM (go False env2) args)
      if isTerminal
        then do
          -- bind the (flattened) call result, then rebuild its nested shape
          tmp <- gensym "tmp_app"
          let ty' = gRecoverType ddfs env2 exp0
          return $ LetE (tmp, [], flattenTy ty', exp0) (recover (VarE tmp) ty')
        else
          pure exp0

    PrimAppE pr args -> discharge stk <$> (PrimAppE pr <$> mapM (go False env2) args)

    -- condition is an intermediate expression, we only care about then and else branches as terminal expressions
    IfE a b c -> IfE <$> go False env2 a <*> go isTerminal env2 b <*> go isTerminal env2 c

    CaseE e ls -> do
      -- Add pattern matched vars to the environment (with the constructor's
      -- field types). The scrutinee is intermediate; each branch inherits
      -- the terminality of the whole case.
      let extendPat dcon vlocs =
            extendsVEnv (M.fromList $ zip (map fst vlocs) (lookupDataCon ddfs dcon)) env2
      CaseE <$> go False env2 e
            <*> sequence [ (k,ls',) <$> go isTerminal (extendPat k ls') x
                         | (k,ls',x) <- ls ]

    DataConE loc dcon args ->
      case stk of
        -- data constructor arguments are unarised as intermediate values
        [] -> discharge stk <$>
              (DataConE loc dcon <$> mapM (go False env2) args)
        _  -> error $ "Impossible. Non-empty projection stack on DataConE "++show stk

    TimeIt e ty b -> do
      tmp <- gensym $ toVar "timed"
      e'  <- go isTerminal env2 e
      return $ LetE (tmp,[],flattenTy ty, TimeIt e' ty b) (VarE tmp)

    WithArenaE v e -> WithArenaE v <$> go isTerminal env2 e

    SpawnE v locs args -> discharge stk <$>
                            (SpawnE v locs <$> mapM (go False env2) args)

    SyncE -> pure SyncE

    Ext (RetE ls) -> do
      -- NOTE(review): this pattern assumes the unarised tuple is still a
      -- MkProdE; if recover0/mkProd can collapse a one-element tuple to its
      -- element in terminal position, this match fails -- confirm.
      (MkProdE ls1) <- go isTerminal env2 (MkProdE ls)
      pure $ Ext $ RetE ls1

    Ext (LetAvail vs bod) -> do
      bod' <- go isTerminal env2 bod
      return $ Ext $ LetAvail vs bod'
    Ext{}   -> return ex
    MapE{}  -> error "unariserExp: MapE TODO"
    FoldE{} -> error "unariserExp: FoldE TODO"

  where
    -- Recurse with the same data definitions and projection stack.
    go :: Bool -> Env2 Ty3 -> Exp3 -> PassM Exp3
    go isTerminal' = unariserExp isTerminal' ddfs stk

    -- Shared handling of literal leaves: returned unchanged, and it is an
    -- internal error for projections to be pending on them.
    atom :: String -> PassM Exp3
    atom con =
      case stk of
        [] -> return ex
        _  -> error $ "Impossible. Non-empty projection stack on " ++ con ++ " " ++ show stk

    -- | Reify a stack of projections.
    discharge :: [Int] -> Exp3 -> Exp3
    discharge [] e = e
    discharge (ix:rst) (MkProdE ls) = discharge rst (ls ! ix)
    discharge (ix:rst) e = discharge rst (ProjE ix e)

    -- Rebuild a terminal value in the shape of its recovered type.
    recover0 :: Exp3 -> Exp3
    recover0 ex0 = recover ex0 (gRecoverType ddfs env2 ex0)

    -- Rebuild a value of (possibly nested) tuple type from its flattened
    -- form; non-tuple types are returned unchanged.
    recover :: Exp3 -> Ty3 -> Exp3
    recover ex0 (ProdTy tys) = mkProd . fst $ recover' 0 ex0 tys
    recover ex0 _ = ex0

    -- recover' idx ex0 tys: build one expression per type in tys, projecting
    -- leaves out of ex0 starting at flat index idx, and return the next
    -- unused flat index. A literal MkProdE whose arity matches tys is
    -- recovered element-wise instead; its index result is never demanded,
    -- because that equation can only fire on the top-level call made by
    -- 'recover', which discards it with 'fst'.
    recover' :: Int -> Exp3 -> [Ty3] -> ([Exp3], Int)
    recover' idx _ [] = ([], idx)
    recover' _ ex0@(MkProdE xs) tys =
      if length xs == length tys
        then (zipWith recover xs tys, undefined)
        else error $ "recover': unmatched expression " ++ sdoc ex0 ++ " for type " ++ sdoc tys
    recover' idx ex0 (ty:tys) =
      case ty of
        ProdTy tys' ->
          let (res, idx')   = recover' idx ex0 tys'
              (res', idx'') = recover' idx' ex0 tys
          in  (mkProd res : res', idx'')
        _ ->
          let (res, idx') = recover' (idx+1) ex0 tys
          in  (mkProj idx ex0 : res, idx')

    -- Bounds-checked list indexing with a descriptive error. Valid indices
    -- are 0 .. length ls - 1.
    ls ! i = if i >= 0 && i < length ls
             then ls !! i
             else error $ "unariserExp: attempt to project index "++show i++" of list:\n "++sdoc ls


-- | Flatten nested tuples.
--
-- Only accepts a 'MkProdE'; any other expression is an 'error'.
-- Flattening happens in two steps:
--
--   1. Structural: nested 'MkProdE's are spliced into one list of components.
--   2. Type-directed: each component's type is recovered with 'gRecoverType'
--      and flattened with 'flattenTy'.
--      * A 'VarE' of product type is replaced by one 'ProjE' per flattened
--        component.
--      * A 'ProjE' component is handed to 'unariserExp' (with its Bool flag
--        set to False).
--      * Everything else is kept as-is.
--
-- Example: let v = [1,2,3]
--              w = [v,4]
--
-- `w` is rewritten as:
--
-- let v = [1,2,3]
--     w = [proj 0 v, proj 1 v, proj 2 v, 4]
--
-- NOTE(review): step 2 indexes into the flattened type of a variable, so it
-- presumably assumes that variable is already bound to a flat tuple -- confirm.
flattenProd :: DDefs Ty3 -> ProjStack -> Env2 Ty3 -> Exp3 -> PassM Exp3
flattenProd ddfs stk env2 ex =
  case ex of
    MkProdE{} -> do
      let components   = spliceProds ex
          componentTys = map (flattenTy . gRecoverType ddfs env2) components
      MkProdE <$> expandAll componentTys components
    oth -> error $ "flattenProd: Unexpected expression: " ++ sdoc oth
  where
    -- Splice nested MkProdE's into a single list of non-MkProdE components.
    spliceProds :: Exp3 -> [Exp3]
    spliceProds (MkProdE js) = concatMap spliceProds js
    spliceProds e            = [e]

    -- Expand every component using its (already flattened) type.
    -- Both lists always have the same length here, since the types are
    -- computed by mapping over the components.
    expandAll :: [Ty3] -> [Exp3] -> PassM [Exp3]
    expandAll [] [] = pure []
    expandAll (t:ts) (e:es) = (++) <$> expandOne t e <*> expandAll ts es
    expandAll ts es =
      error $ "Unexpected input: " ++ sdoc ts ++ " " ++ sdoc es

    -- Expand a single component into one or more flat components.
    expandOne :: Ty3 -> Exp3 -> PassM [Exp3]
    expandOne (ProdTy tys) e@VarE{} = pure [ ProjE i e | (i, _) <- zip [0 ..] tys ]
    expandOne _ e@ProjE{}           = (: []) <$> unariserExp False ddfs stk env2 e
    expandOne _ e                   = pure [e]


-- | Map a top-level component index of a (possibly nested) product type to
-- the index where that component starts once the type is flattened with
-- 'flattenTy'.
--
-- The result is the number of non-product leaves in the first @n@
-- components. Nested empty products (@ProdTy []@) contribute no leaves,
-- which matches 'flattenTy'.
--
-- >>> flatProjIdx 1 (ProdTy [ProdTy [IntTy, IntTy, IntTy], IntTy])
-- 3
--
-- Calls 'error' when the given type is not a 'ProdTy'.
flatProjIdx :: Int -> Ty3 -> Int
flatProjIdx n ty =
  case ty of
    ProdTy tys -> sum (map leafCount (take n tys))
    _ -> error $ "flatProjIdx: non-product type given: " ++ show ty
  where
    -- How many components 'flattenTy' produces for this type. This avoids
    -- the partial `let ProdTy tys' = ...` pattern match.
    leafCount :: Ty3 -> Int
    leafCount (ProdTy ts) = sum (map leafCount ts)
    leafCount _           = 1


-- | Flatten nested tuple types into a single-level 'ProdTy'.
--
-- Non-product types are returned unchanged. Nested empty products
-- (@ProdTy []@) contribute no components.
--
-- Example:
--
-- ProdTy [IntTy, ProdTy [IntTy, IntTy, IntTy, ProdTy [IntTy, IntTy]]] =>
-- ProdTy [IntTy, IntTy, IntTy, IntTy, IntTy, IntTy]
--
flattenTy :: Ty3 -> Ty3
flattenTy (ProdTy tys) = ProdTy (concatMap leaves tys)
  where
    -- Every non-product type reachable through nested products, left to right.
    leaves :: Ty3 -> [Ty3]
    leaves (ProdTy inner) = concatMap leaves inner
    leaves other          = [other]
flattenTy ty = ty
ty']

-- | Flatten nested tuples in a type-safe way.
--
-- @ty@ is the type of variable @v@. When it is a product type, every chain
-- of projections in @bod@ that reaches a non-product component of @v@ is
-- replaced by a single projection. The new index is that component's
-- position in the flattened tuple, consistent with 'flattenTy' and
-- 'flatProjIdx'.
--
-- For example, with @v : ProdTy [ProdTy [IntTy, IntTy], IntTy]@:
--
--   * @ProjE 1 (ProjE 0 (VarE v))@ becomes @ProjE 1 (VarE v)@
--   * @ProjE 1 (VarE v)@ becomes @ProjE 2 (VarE v)@
--
-- For any other type, @bod@ is returned unchanged.
flattenExp :: Var -> Ty3 -> Exp3 -> Exp3
flattenExp v ty bod =
  case ty of
    ProdTy _ ->
      let
          -- Projection paths to the non-product leaves of a tuple type, in
          -- left-to-right order. Each path is stored innermost index first,
          -- so `foldr ProjE (VarE v)` rebuilds the nested projection.
          --
          -- Examples:
          -- (1) ProdTy [IntTy, ProdTy [IntTy, IntTy, IntTy]] ==
          --     [[0],[0,1],[1,1],[2,1]]
          --
          -- (2) ProdTy [IntTy, ProdTy [IntTy, IntTy, IntTy, ProdTy [IntTy, IntTy]]]
          --     [[0],[0,1],[1,1],[2,1],[0,3,1],[1,3,1]]
          --
          projections :: Ty3 -> ProjStack -> [ProjStack]
          projections (ProdTy tys) acc =
            concatMap (\(ty', i) -> projections ty' (i : acc)) (zip tys [0 ..])
          projections _ acc = [acc]

          vref :: Exp3
          vref = VarE v

          -- Each leaf path paired with its position in the flattened tuple.
          -- The position is the enumeration index, NOT the sum of the path's
          -- indices: the sum is wrong whenever a nested tuple is not the last
          -- component. Top-level leaves that keep their index are dropped,
          -- since rewriting them would be a no-op.
          leaves :: [(ProjStack, Int)]
          leaves = [ (ps, i) | (ps, i) <- zip (projections ty []) [0 ..], ps /= [i] ]

          -- A temporary projection with a negative index, which never appears
          -- in real code. The rewrite goes through these markers so that the
          -- output of one substitution can never be matched again by another.
          -- Such chaining could happen when unit components shift later
          -- indices down.
          marker :: Int -> Exp3
          marker i = ProjE (negate i - 1) vref

          toMarkers :: [(Exp3, Exp3)]
          toMarkers = [ (foldr ProjE vref ps, marker i) | (ps, i) <- leaves ]

          fromMarkers :: [(Exp3, Exp3)]
          fromMarkers = [ (marker i, ProjE i vref) | (_, i) <- leaves ]

          -- FIXME: still one substE traversal per rewritten leaf and phase.
          substAll :: [(Exp3, Exp3)] -> Exp3 -> Exp3
          substAll pairs e = foldr (\(from, to) acc -> substE from to acc) e pairs
      in substAll fromMarkers (substAll toMarkers bod)
    _ -> bod