module Core.CaseTree(CaseDef(..), SC, SC'(..), CaseAlt, CaseAlt'(..),
Phase(..), CaseTree,
simpleCase, small, namesUsed, findCalls, findUsedArgs) where
import Core.TT
import Control.Monad.State
import Data.Maybe
import Data.List hiding (partition)
import Debug.Trace
-- | A compiled definition: the names given to its arguments, the case tree
-- over those names, and a list of left-hand-side terms collected by 'match'
-- (clauses that were still pending at a leaf where an earlier clause had
-- already matched completely; a clause may be listed even if it is reachable
-- along some other path).
data CaseDef
  = CaseDef [Name] !SC [Term]
  deriving Show
-- | A case tree whose leaves hold values of type @t@.
-- The constructor order below determines the derived 'Ord' instance.
data SC' t
  = Case Name [CaseAlt' t]    -- ^ branch on the value of a named variable
  | ProjCase t [CaseAlt' t]   -- ^ branch on a term (produced by 'prune' for projections)
  | STerm !t                  -- ^ leaf: a right-hand side
  | UnmatchedCase String      -- ^ no clause applies; carries a message
  | ImpossibleCase            -- ^ a branch marked as impossible
  deriving (Eq, Ord, Functor)

-- | Case trees over core terms.
type SC = SC' Term
-- | One alternative of a case split. The constructor order below fixes the
-- derived 'Ord', so when alternatives are sorted 'DefaultCase' comes last.
data CaseAlt' t
  = ConCase Name Int [Name] !(SC' t)  -- ^ constructor name, tag, names bound to its arguments, body
  | FnCase Name [Name] !(SC' t)       -- ^ reflected name pattern, names bound to its arguments, body
  | ConstCase Const !(SC' t)          -- ^ constant to match, body
  | SucCase Name !(SC' t)             -- ^ successor pattern: name bound to the predecessor, body
  | DefaultCase !(SC' t)              -- ^ fall-through body
  deriving (Show, Eq, Ord, Functor)

-- | Alternatives over core terms.
type CaseAlt = CaseAlt' Term
-- | Renders a case tree for debugging as nested @case ... of@ blocks, with
-- one alternative per line, indented one space per nesting level.
instance Show t => Show (SC' t) where
  show sc = show' 1 sc
    where
      show' i (Case n alts) = "case " ++ show n ++ " of\n" ++ showAlts i alts
      -- Same layout as 'Case': the first alternative starts on its own
      -- indented line instead of directly after "of".
      show' i (ProjCase tm alts) = "case " ++ show tm ++ " of\n" ++ showAlts i alts
      show' _ (STerm tm) = show tm
      show' _ (UnmatchedCase str) = "error " ++ show str
      show' _ ImpossibleCase = "impossible"

      -- The alternatives at depth i, each on its own line indented by i spaces.
      showAlts i alts = indent i ++ showSep ("\n" ++ indent i) (map (showA i) alts)

      indent i = replicate i ' '

      showA i (ConCase n _ args rhs)
        = show n ++ "(" ++ showSep ", " (map show args) ++ ") => "
          ++ show' (i + 1) rhs
      showA i (FnCase n args rhs)
        = "FN " ++ show n ++ "(" ++ showSep ", " (map show args) ++ ") => "
          ++ show' (i + 1) rhs
      showA i (ConstCase c rhs) = show c ++ " => " ++ show' (i + 1) rhs
      showA i (SucCase n rhs) = show n ++ "+1 => " ++ show' (i + 1) rhs
      showA i (DefaultCase rhs) = "_ => " ++ show' (i + 1) rhs
-- | The case tree produced for a definition.
type CaseTree = SC

-- | A clause during compilation: the patterns still to be matched, paired
-- with the original (left-hand side, right-hand side) terms.
type Clause = ([Pat], (Term, Term))

-- | State threaded through 'match': collected left-hand-side terms, and a
-- counter used by 'getVar' to build fresh names.
type CS = ([Term], Int)
-- | Size of a case tree: the size of its alternatives for 'Case' and
-- 'ProjCase' (the projected term itself is not counted), the size of the
-- term for a leaf, and 1 for anything else.
instance TermSize SC where
  termsize n tree = case tree of
    Case _ alts     -> termsize n alts
    ProjCase _ alts -> termsize n alts
    STerm t         -> termsize n t
    _               -> 1
-- | Size of an alternative: the size of its body tree.
instance TermSize CaseAlt where
  termsize n alt = case alt of
    ConCase _ _ _ body -> termsize n body
    FnCase _ _ body    -> termsize n body
    ConstCase _ body   -> termsize n body
    SucCase _ body     -> termsize n body
    DefaultCase body   -> termsize n body
-- | A case tree is "small" when no argument in @args@ is used more than once
-- (as counted by 'findAllUsedArgs') and its 'termsize' is below 10.
small :: Name -> [Name] -> SC -> Bool
small n args t = eachArgUsedOnce && termsize n t < 10
  where
    used = findAllUsedArgs t args
    eachArgUsedOnce = length used == length (nub used)
-- | Names referenced by a case tree, without duplicates.
--
-- Collects the names of 'P' nodes in leaf terms and projected terms, skipping
-- names bound by constructor/function alternatives and by binders inside
-- terms. A 'Case' scrutinee is removed from the names gathered under it.
namesUsed :: SC -> [Name]
namesUsed sc = nub (inTree [] sc)
  where
    -- bound: names in scope that must not be reported.
    inTree bound tree = case tree of
      Case n alts      -> nub (concatMap (inAlt bound) alts) \\ [n]
      ProjCase t alts  -> nub (inTerm bound t ++ concatMap (inAlt bound) alts)
      STerm t          -> nub (inTerm bound t)
      _                -> []

    inAlt bound alt = case alt of
      ConCase _ _ args body -> nub (inTree (bound ++ args) body) \\ args
      FnCase _ args body    -> nub (inTree (bound ++ args) body) \\ args
      ConstCase _ body      -> inTree bound body
      SucCase _ body        -> inTree bound body
      DefaultCase body      -> inTree bound body

    inTerm bound tm = case tm of
      P _ n _
        | n `elem` bound -> []
        | otherwise      -> [n]
      App f a               -> inTerm bound f ++ inTerm bound a
      Proj t _              -> inTerm bound t
      Bind n (Let _ v) body -> inTerm bound v ++ inTerm (n : bound) body
      Bind n _ body         -> inTerm (n : bound) body
      _                     -> []
-- | References to global names ('P' nodes with 'Ref') in a case tree,
-- without duplicates.
--
-- For an application headed by such a name, the name is paired with one list
-- per argument: the top-level arguments (@topargs@) that 'directUse' finds in
-- that argument. A name that is not applied is paired with @[]@. Names in
-- scope (top-level arguments, scrutinees, pattern-bound and binder-bound names)
-- are not reported. Applications headed by a type constructor contribute
-- nothing.
findCalls :: SC -> [Name] -> [(Name, [[Name]])]
findCalls sc topargs = nub (callsIn topargs sc)
  where
    callsIn bound tree = case tree of
      Case n alts     -> nub (concatMap (callsInAlt (n : bound)) alts)
      ProjCase t alts -> nub (callsInTerm bound t ++ concatMap (callsInAlt bound) alts)
      STerm t         -> nub (callsInTerm bound t)
      _               -> []

    callsInAlt bound alt = case alt of
      ConCase _ _ args body -> nub (callsIn (bound ++ args) body)
      FnCase _ args body    -> nub (callsIn (bound ++ args) body)
      ConstCase _ body      -> callsIn bound body
      SucCase _ body        -> callsIn bound body
      DefaultCase body      -> callsIn bound body

    callsInTerm bound tm = case tm of
      P Ref n _
        | n `elem` bound -> []
        | otherwise      -> [(n, [])]
      App f a
        | (P Ref n _, args) <- unApply tm ->
            if n `elem` bound
              then callsInTerm bound f ++ callsInTerm bound a
              else (n, map argNames args) : concatMap (callsInTerm bound) args
        | (P (TCon _ _) _ _, _) <- unApply tm -> []
        | otherwise -> callsInTerm bound f ++ callsInTerm bound a
      Bind n (Let _ v) body -> callsInTerm bound v ++ callsInTerm (n : bound) body
      Proj t _              -> callsInTerm bound t
      Bind n _ body         -> callsInTerm (n : bound) body
      _                     -> []

    -- Top-level arguments occurring directly in an argument term, in the
    -- order they appear in topargs.
    argNames tm = let direct = directUse tm in
                  filter (`elem` direct) topargs
-- | Names that occur in a term other than inside an application headed by a
-- 'Ref' name (such an application contributes nothing).
--
-- For binders, the name bound is removed from the body's names and the
-- binder's type (and, for 'Let', its value) is included.
directUse :: Eq n => TT n -> [n]
directUse tm = case tm of
  P _ n _ -> [n]
  Bind n (Let t v) body ->
    nub $ directUse v ++ (directUse body \\ [n]) ++ directUse t
  Bind n b body ->
    nub $ directUse (binderTy b) ++ (directUse body \\ [n])
  App f a
    | (P Ref _ _, _) <- unApply tm -> []
    | otherwise -> nub $ directUse f ++ directUse a
  Proj x _ -> nub (directUse x)
  _ -> []
-- | The top-level arguments used by a case tree, each listed once.
findUsedArgs :: SC -> [Name] -> [Name]
findUsedArgs sc = nub . findAllUsedArgs sc
-- | Every use of a top-level argument in a case tree, duplicates included.
-- A use is a 'Case' scrutinee or a name found by 'directUse' in a leaf or
-- projected term.
findAllUsedArgs :: SC -> [Name] -> [Name]
findAllUsedArgs sc topargs = [x | x <- usesIn sc, x `elem` topargs]
  where
    usesIn (Case n alts)     = n : concatMap (usesIn . body) alts
    usesIn (ProjCase t alts) = directUse t ++ concatMap (usesIn . body) alts
    usesIn (STerm t)         = directUse t
    usesIn _                 = []

    body (ConCase _ _ _ s) = s
    body (FnCase _ _ s)    = s
    body (ConstCase _ s)   = s
    body (SucCase _ s)     = s
    body (DefaultCase s)   = s
-- | The phase a case tree is built for. 'simpleCase' converts single-
-- constructor cases to projections, strips leading lambdas, and skips the
-- accessibility check only at 'RunTime'.
data Phase
  = CompileTime
  | RunTime
  deriving (Show, Eq)
-- | Compile pattern clauses into a 'CaseDef'.
--
-- Arguments, in order:
--
--   * @tc@: passed to 'toPats'; when True, type constructors and primitive
--     type constants are turned into constructor patterns.
--   * @cover@: when True the fall-through tree is @STerm Erased@ (which
--     'prune' drops as a default alternative); otherwise it is
--     @UnmatchedCase "Error"@.
--   * @reflect@: passed to 'toPats'; when True the accessibility check is skipped.
--   * @phase@: at 'RunTime', 'prune' turns single-constructor cases into
--     projections, leading lambdas are stripped with 'stripLambdas', and the
--     accessibility check is skipped.
--   * @fc@: location attached to any error ('At').
--   * the clauses, each as (pattern variables, left-hand side, right-hand side).
--
-- Clauses whose right-hand side is 'Impossible' are dropped first. If none
-- remain, the result is an 'UnmatchedCase' tree with no argument names.
-- Otherwise an 'Inaccessible' error is returned (when the check runs) for a
-- pattern variable that is not bound by a 'PV' pattern in its clause, looking
-- through constructor and successor patterns only.
simpleCase :: Bool -> Bool -> Bool ->
              Phase -> FC -> [([Name], Term, Term)] ->
              TC CaseDef
simpleCase tc cover reflect phase fc cs
    = sc' tc cover phase fc (filter (\(_, _, r) ->
                                       case r of
                                         Impossible -> False
                                         _ -> True) cs)
  where
    -- No clauses survived the filter.
    sc' tc cover phase fc []
        = return $ CaseDef [] (UnmatchedCase "No pattern clauses") []
    sc' tc cover phase fc cs
        = let proj = phase == RunTime
              pats = map (\ (avs, l, r) ->
                            (avs, toPats reflect tc l, (l, r))) cs
              chkPats = mapM chkAccessible pats in
          case chkPats of
              OK pats ->
                  -- The first clause's pattern count is taken as the arity.
                  let numargs = length (fst (head pats))
                      ns = take numargs args
                      (ns', ps') = order ns pats
                      -- Fresh-name counter starts at numargs, past the
                      -- argument names MN 0 "e" .. MN (numargs-1) "e".
                      (tree, st) = runState
                                       (match ns' ps' (defaultCase cover)) ([], numargs)
                      t = CaseDef ns (prune proj (depatt ns' tree)) (fst st) in
                  if proj then return (stripLambdas t) else return t
              Error err -> Error (At fc err)
      where args = map (\i -> MN i "e") [0..]
            defaultCase True = STerm Erased
            defaultCase False = UnmatchedCase "Error"
            -- Drop the pattern-variable list, checking (at compile time and
            -- without reflection) that each pattern variable is bound in l.
            chkAccessible (avs, l, c)
                | phase == RunTime || reflect = return (l, c)
                | otherwise = do
                    mapM_ (acc l) avs
                    return (l, c)
            -- acc ps n: succeeds if n appears as a PV in ps, searching inside
            -- PCon arguments and PSuc; other patterns are skipped.
            acc [] n = Error (Inaccessible n)
            acc (PV x : xs) n | x == n = OK ()
            acc (PCon _ _ ps : xs) n = acc (ps ++ xs) n
            acc (PSuc p : xs) n = acc (p : xs) n
            acc (_ : xs) n = acc xs n
-- | Patterns extracted from clause left-hand sides by 'toPat'.
data Pat
  = PCon Name Int [Pat]     -- ^ constructor: name, integer tag, argument patterns
  | PConst Const            -- ^ constant
  | PV Name                 -- ^ pattern variable (first occurrence of a bound name)
  | PSuc Pat                -- ^ successor: matches (predecessor pattern) + 1
  | PReflected Name [Pat]   -- ^ reflected application of a name (only when reflecting)
  | PAny                    -- ^ matches anything
  deriving Show
-- | One pattern per argument of a left-hand-side application, in
-- left-to-right argument order.
toPats :: Bool -> Bool -> Term -> [Pat]
toPats reflect tc f = reverse $ toPat reflect tc (spineArgs f)
  where
    -- Arguments of an application, last argument first. 'toPat' therefore
    -- processes them right to left; since it binds the first occurrence of a
    -- repeated variable it sees, this decides which occurrence becomes 'PV'.
    spineArgs (App g a) = a : spineArgs g
    spineArgs _ = []
-- | Convert argument terms to patterns, processing them in list order.
--
-- A state of already-bound variable names is threaded through, so only the
-- first occurrence of a 'Bound' variable becomes a 'PV'; later occurrences
-- become 'PAny'. @tc@ enables treating type constructors and primitive type
-- constants as constructors; @reflect@ enables 'PReflected' patterns.
-- Anything not recognised becomes 'PAny'.
toPat :: Bool -> Bool -> [Term] -> [Pat]
toPat reflect tc tms = evalState (mapM (\x -> toPat' x []) tms) []
  where
    -- toPat' tm args: tm is the head of an application and args are the
    -- arguments collected so far by the App equation below.
    -- Data constructor applications become constructor patterns
    -- (first DCon field used as the tag).
    toPat' (P (DCon t a) n _) args
        = do args' <- mapM (\x -> toPat' x []) args
             return $ PCon n t args'
    -- `prim__addBigInt p 1` becomes a successor pattern on p.
    toPat' (P _ (UN "prim__addBigInt") _)
           [p, Constant (BI 1)]
        = do p' <- toPat' p []
             return $ PSuc p'
    -- With tc, type constructor applications are matched like constructors;
    -- otherwise this equation falls through.
    toPat' (P (TCon t a) n _) args | tc
        = do args' <- mapM (\x -> toPat' x []) args
             return $ PCon n t args'
    -- With tc, primitive type constants become nullary constructor patterns
    -- with fixed tags: Int 1, Float 2, Char 3, String 4, Ptr 5, Integer 6,
    -- and fixed-width integer types from 7 + fromEnum of the width.
    toPat' (Constant (AType (ATInt ITNative))) []
        | tc = return $ PCon (UN "Int") 1 []
    toPat' (Constant (AType ATFloat)) [] | tc = return $ PCon (UN "Float") 2 []
    toPat' (Constant (AType (ATInt ITChar))) [] | tc = return $ PCon (UN "Char") 3 []
    toPat' (Constant StrType) [] | tc = return $ PCon (UN "String") 4 []
    toPat' (Constant PtrType) [] | tc = return $ PCon (UN "Ptr") 5 []
    toPat' (Constant (AType (ATInt ITBig))) []
        | tc = return $ PCon (UN "Integer") 6 []
    toPat' (Constant (AType (ATInt (ITFixed n)))) []
        | tc = return $ PCon (UN (fixedN n)) (7 + fromEnum n) []
    -- An unapplied bound variable binds on its first occurrence only.
    toPat' (P Bound n _) []
        = do ns <- get
             if n `elem` ns
                then return PAny
                else do put (n : ns)
                        return (PV n)
    -- Walk down the application spine, accumulating arguments.
    toPat' (App f a) args = toPat' f (a : args)
    toPat' (Constant x) [] = return $ PConst x
    -- When reflecting, a non-dependent function type becomes PReflected "->".
    toPat' (Bind n (Pi t) sc) [] | reflect && noOccurrence n sc
        = do t' <- toPat' t []
             sc' <- toPat' sc []
             return $ PReflected (UN "->") (t':sc':[])
    -- When reflecting, any other named head becomes a reflected pattern.
    toPat' (P _ n _) args | reflect
        = do args' <- mapM (\x -> toPat' x []) args
             return $ PReflected n args'
    toPat' t _ = return PAny

    -- Constructor names used for fixed-width integer type patterns.
    fixedN IT8 = "Bits8"
    fixedN IT16 = "Bits16"
    fixedN IT32 = "Bits32"
    fixedN IT64 = "Bits64"
-- | A maximal run of consecutive clauses whose first patterns are all
-- constructor-like ('Cons') or all variable-like ('Vars').
data Partition
  = Cons [Clause]
  | Vars [Clause]
  deriving Show
-- | True when the first pattern is a variable or a wildcard.
isVarPat :: ([Pat], a) -> Bool
isVarPat (p : _, _) = case p of
  PV _ -> True
  PAny -> True
  _    -> False
isVarPat _ = False
-- | True when the first pattern is a constructor, reflected, successor or
-- constant pattern.
isConPat :: ([Pat], a) -> Bool
isConPat (p : _, _) = case p of
  PCon _ _ _     -> True
  PReflected _ _ -> True
  PSuc _         -> True
  PConst _       -> True
  _              -> False
isConPat _ = False
-- | Split clauses into maximal runs of variable-headed or constructor-headed
-- clauses, preserving order. A clause with no patterns left is an error.
partition :: [Clause] -> [Partition]
partition [] = []
partition ms@(m : _)
  | isVarPat m = takeRun Vars isVarPat
  | isConPat m = takeRun Cons isConPat
  where
    takeRun wrap sameKind =
      let (run, rest) = span sameKind ms
      in wrap run : partition rest
partition xs = error $ "Partition " ++ show xs
-- | Reorder the match variables (pattern columns) so that columns with more
-- distinct constructor names or constants come first. The same permutation
-- is applied to every clause's patterns.
order :: [Name] -> [Clause] -> ([Name], [Clause])
order [] cs = ([], cs)
order ns [] = (ns, [])
order ns cs = (namesOf reordered, zipWith rebuild reordered cs)
  where
    -- One list per variable, pairing it with each clause's pattern for it.
    columns = [zip ns pats | (pats, _) <- cs]
    -- Descending by distinct heads. Reversing first means that, among ties,
    -- later columns come first (sortBy is stable).
    reordered = transpose (sortBy moreFirst (reverse (transpose columns)))
    moreFirst xs ys = compare (distinctHeads ys) (distinctHeads xs)

    distinctHeads col = length (nub (mapMaybe headKey (map snd col)))
    headKey (PCon n _ _) = Just (Left n)
    headKey (PConst c)   = Just (Right c)
    headKey _            = Nothing

    namesOf [] = error $ "Failed order on " ++ show (ns, cs)
    namesOf (row : _) = map fst row

    rebuild row clause = (map snd row, snd clause)
-- | Compile clauses against the given match variables. @err@ is the tree
-- used when no clause applies.
--
-- When there are no variables left and the first clause has no patterns left,
-- that clause's right-hand side is the result. The left-hand sides of the
-- remaining clauses are appended to the state's term list.
match :: [Name] -> [Clause] -> SC
      -> State CS SC
match [] (([], ret) : later) _ = do
  modify (\(ts, v) -> (ts ++ map (fst . snd) later, v))
  return $ case snd ret of
    Impossible -> ImpossibleCase
    tm -> STerm tm
match vs cs err = mixture vs (partition cs) err
-- | Compile a sequence of partitions. Each partition falls through to the
-- tree compiled from the partitions after it; the last falls through to @err@.
mixture :: [Name] -> [Partition] -> SC -> State CS SC
mixture _ [] err = return err
mixture vs (p : ps) err = do
  fallthrough <- mixture vs ps err
  case p of
    Cons ms -> conRule vs ms fallthrough
    Vars ms -> varRule vs ms fallthrough
-- | The key by which constructor-headed clauses are grouped: a constructor
-- (name and tag), a reflected name, a successor pattern, or a constant.
data ConType
  = CName Name Int
  | CFn Name
  | CSuc
  | CConst Const
  deriving (Show, Eq)
-- | Clauses sharing a head constructor. Each member pairs the head pattern's
-- sub-patterns with the rest of the clause.
data Group
  = ConGroup ConType [([Pat], Clause)]
  deriving Show
-- | Constructor rule: group clauses by the head constructor of their first
-- pattern, then build a case on the first variable. The variable list must
-- be non-empty.
conRule :: [Name] -> [Clause] -> SC -> State CS SC
conRule vars@(_ : _) cs err = groupCons cs >>= \groups -> caseGroups vars groups err
-- | Build a 'Case' on the first variable with one alternative per group,
-- plus a 'DefaultCase' holding @err@. Alternatives are sorted with the
-- derived 'Ord' on alternatives.
--
-- For each group, fresh names are allocated for the constructor's arguments
-- ('argsToAlt'), and the group's clauses are compiled against those names
-- followed by the remaining variables. Constant groups bind no names.
caseGroups :: [Name] -> [Group] -> SC -> State CS SC
caseGroups (v:vs) gs err
    = do alts <- mapM altFor gs
         return $ Case v (sort (alts ++ [DefaultCase err]))
  where
    altFor (ConGroup (CName n i) args)
        = do (newArgs, nextCs) <- argsToAlt args
             matchCs <- match (newArgs ++ vs) nextCs err
             return $ ConCase n i newArgs matchCs
    altFor (ConGroup (CFn n) args)
        = do (newArgs, nextCs) <- argsToAlt args
             matchCs <- match (newArgs ++ vs) nextCs err
             return $ FnCase n newArgs matchCs
    -- Successor groups carry exactly one sub-pattern, so exactly one fresh
    -- name is expected. An explicit case (instead of a refutable do-bind)
    -- avoids needing MonadFail, which State (over Identity) does not have.
    altFor (ConGroup CSuc args)
        = do (newArgs, nextCs) <- argsToAlt args
             case newArgs of
               [newArg] -> do
                 matchCs <- match (newArg : vs) nextCs err
                 return $ SucCase newArg matchCs
               _ -> error "caseGroups: successor group must bind exactly one variable"
    altFor (ConGroup (CConst c) args)
        = do (_, nextCs) <- argsToAlt args
             matchCs <- match vs nextCs err
             return $ ConstCase c matchCs
-- | Allocate one fresh name per sub-pattern of the group's first member.
-- Wildcard sub-patterns get an "i" name and all others an "e" name. Each
-- member's sub-patterns are then prepended to its remaining patterns.
argsToAlt :: [([Pat], Clause)] -> State CS ([Name], [Clause])
argsToAlt [] = return ([], [])
argsToAlt rs@((firstSubPats, _) : _) = do
    newArgs <- mapM freshName firstSubPats
    return (newArgs, [(subPats ++ ps, res) | (subPats, (ps, res)) <- rs])
  where
    freshName PAny = getVar "i"
    freshName _    = getVar "e"
-- | Turn a user name into a machine name numbered @i@; other names are
-- returned unchanged.
uniq i n = case n of
  UN s -> MN i s
  _    -> n
-- | A fresh machine name @MN v b@, where @v@ is the counter in the state;
-- the counter is then incremented.
getVar :: String -> State CS Name
getVar b = do
  (terms, counter) <- get
  put (terms, counter + 1)
  return (MN counter b)
-- | Group clauses by the constructor, reflected name, successor or constant
-- at the head of their first pattern.
--
-- Groups appear in order of first appearance, and clauses keep their order
-- within a group. Each member pairs the head's sub-patterns with the clause's
-- remaining patterns and terms. In this module it is only called by 'conRule'
-- on a 'Cons' partition, so every first pattern is constructor-like; any
-- other head is an internal error.
groupCons :: [Clause] -> State CS [Group]
groupCons cs = gc [] cs
  where
    gc acc [] = return acc
    gc acc ((p : ps, res) : rest) = gc (addGroup p (ps, res) acc) rest

    addGroup p clause acc = case p of
      PCon con i args    -> addg (CName con i) args clause acc
      PConst cval        -> addg (CConst cval) [] clause acc
      PSuc n             -> addg CSuc [n] clause acc
      PReflected fn args -> addg (CFn fn) args clause acc
      -- 'error' rather than 'fail': State over Identity has no MonadFail
      -- instance (base >= 4.13), and the old 'fail' was 'error' anyway.
      pat -> error $ show pat ++ " is not a constructor or constant (can't happen)"

    -- Append to the first group with an equal key, keeping that group's own
    -- key, or start a new group at the end.
    addg c conargs clause [] = [ConGroup c [(conargs, clause)]]
    addg c conargs clause (g@(ConGroup c' members) : gs)
      | c == c'   = ConGroup c' (members ++ [(conargs, clause)]) : gs
      | otherwise = g : addg c conargs clause gs
-- | Variable rule: drop the first pattern of every clause and continue with
-- the remaining variables. A 'PV' pattern's name is substituted in the
-- right-hand side with a reference to the match variable @v@.
varRule :: [Name] -> [Clause] -> SC -> State CS SC
varRule (v : vs) alts err = match vs (map bindHead alts) err
  where
    bindHead (PV p : ps, (lhs, res)) = (ps, (lhs, subst p (P Bound v Erased) res))
    bindHead (PAny : ps, clause)     = (ps, clause)
-- | In each leaf, replace an application that rebuilds the constructor
-- matched by an enclosing alternative, applied to exactly the names that
-- alternative bound, with a reference to the scrutinized variable.
-- The first argument is not used.
depatt :: [Name] -> SC -> SC
depatt _ tree = go [] tree
  where
    -- env: (scrutinee, (constructor, bound names)) for each constructor or
    -- function alternative on the path to the current node, innermost first.
    go env (STerm tm)    = STerm (rewrite env tm)
    go env (Case x alts) = Case x (map (goAlt env x) alts)
    go _ other           = other

    goAlt env x alt = case alt of
      ConCase n i args body -> ConCase n i args (go ((x, (n, args)) : env) body)
      FnCase n args body    -> FnCase n args (go ((x, (n, args)) : env) body)
      ConstCase c body      -> ConstCase c (go env body)
      SucCase n body        -> SucCase n (go env body)
      DefaultCase body      -> DefaultCase (go env body)

    -- Only application nodes are rewritten; binders and projections are
    -- left untouched.
    rewrite env tm@(App f a) = case unApply tm of
      (P nt cn pty, args) -> replaceCon env nt cn pty (map (rewrite env) args)
      _                   -> App (rewrite env f) (rewrite env a)
    rewrite _ tm = tm

    -- The first env entry with the same head name, the same arity, and
    -- arguments that are each a P reference to the matching bound name.
    replaceCon [] nt cn pty args' = mkApp (P nt cn pty) args'
    replaceCon ((x, (n, args)) : rest) nt cn pty args'
      | length args == length args' && n == cn && and (zipWith sameName args args')
      = P Ref x Erased
      | otherwise = replaceCon rest nt cn pty args'

    sameName n (P _ n' _) = n == n'
    sameName _ _          = False
-- | Simplify a case tree after matching.
--
-- Alternatives are pruned recursively. Default alternatives whose body is
-- @STerm Erased@ or 'ImpossibleCase' are removed. Then:
--
--   * no alternatives left: 'ImpossibleCase';
--   * a single constructor or successor alternative: when @proj@ is True, the
--     case is removed and each bound argument is replaced by a projection of
--     the scrutinee (see mkProj/projRep below); otherwise it is kept;
--   * a single constant alternative: replaced by its (re-pruned) body;
--   * a successor alternative plus a default: the default becomes the
--     @BI 0@ constant case.
--
-- Anything other than a 'Case' node is returned unchanged.
prune :: Bool   -- ^ True to turn single-constructor cases into projections
      -> SC -> SC
prune proj (Case n alts)
    = let alts' = filter notErased (map pruneAlt alts) in
          case alts' of
            [] -> ImpossibleCase
            as@[ConCase cn i args sc] -> if proj then mkProj n 0 args sc
                                                 else Case n as
            -- NOTE(review): predecessor projected at index 1; confirm
            -- against the runtime's Proj semantics for successors.
            as@[SucCase cn sc] -> if proj then mkProj n (1) [cn] sc
                                          else Case n as
            -- NOTE(review): the constant test itself is dropped here;
            -- presumably relies on the removed default being erased or
            -- impossible -- confirm.
            as@[ConstCase _ sc] -> prune proj sc
            [s@(SucCase _ _), DefaultCase dc]
                -> Case n [ConstCase (BI 0) dc, s]
            as -> Case n as
  where
    pruneAlt (ConCase cn i ns sc) = ConCase cn i ns (prune proj sc)
    pruneAlt (FnCase cn ns sc) = FnCase cn ns (prune proj sc)
    pruneAlt (ConstCase c sc) = ConstCase c (prune proj sc)
    pruneAlt (SucCase n sc) = SucCase n (prune proj sc)
    pruneAlt (DefaultCase sc) = DefaultCase (prune proj sc)

    -- Defaults that carry no real behaviour are dropped.
    notErased (DefaultCase (STerm Erased)) = False
    notErased (DefaultCase ImpossibleCase) = False
    notErased _ = True

    -- mkProj n i xs sc: replace the k-th name in xs with projection (i + k)
    -- of scrutinee n, then prune the result again.
    mkProj n i [] sc = prune proj sc
    mkProj n i (x : xs) sc = mkProj n (i + 1) xs (projRep x n i sc)

    -- Substitute `Proj (P Bound n Erased) i` for arg throughout a tree; a
    -- Case on arg itself becomes a ProjCase on that projection.
    projRep :: Name -> Name -> Int -> SC -> SC
    projRep arg n i (Case x alts)
        | x == arg = ProjCase (Proj (P Bound n Erased) i)
                              (map (projRepAlt arg n i) alts)
        | otherwise = Case x (map (projRepAlt arg n i) alts)
    projRep arg n i (ProjCase t alts)
        = ProjCase (projRepTm arg n i t) (map (projRepAlt arg n i) alts)
    projRep arg n i (STerm t) = STerm (projRepTm arg n i t)
    projRep arg n i c = c

    projRepAlt arg n i (ConCase cn t args rhs)
        = ConCase cn t args (projRep arg n i rhs)
    projRepAlt arg n i (FnCase cn args rhs)
        = FnCase cn args (projRep arg n i rhs)
    projRepAlt arg n i (ConstCase t rhs)
        = ConstCase t (projRep arg n i rhs)
    projRepAlt arg n i (SucCase sn rhs)
        = SucCase sn (projRep arg n i rhs)
    projRepAlt arg n i (DefaultCase rhs)
        = DefaultCase (projRep arg n i rhs)

    projRepTm arg n i t = subst arg (Proj (P Bound n Erased) i) t
prune _ t = t
-- | While the tree is a single leaf that is a lambda, move the lambda's
-- binder into the argument list and instantiate the body with a reference
-- to it. Any other definition is returned unchanged.
stripLambdas :: CaseDef -> CaseDef
stripLambdas def@(CaseDef ns tree tm) = case tree of
  STerm (Bind x (Lam _) body) ->
    stripLambdas (CaseDef (ns ++ [x]) (STerm (instantiate (P Bound x Erased) body)) tm)
  _ -> def