Skip to content
This repository has been archived by the owner on Oct 18, 2021. It is now read-only.

Shadowing related optimiser fixes #265

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .ghci
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ import qualified System.Environment as E
:def amc-trace \xs -> pure $ "E.setEnv \"AMC_TRACE\" \"" ++ xs ++ "\""

:set -fobject-code
:def l const (pure ":list")
:def c const (pure ":continue")
:def s const (pure ":step")
18 changes: 17 additions & 1 deletion src/Core/Lint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ data CoreError
| InfoMismatch CoVar VarInfo VarInfo
| InfoIllegal CoVar VarInfo VarInfo
| NoSuchVar CoVar
| Duplicate CoVar
| IllegalUnbox
| InvalidCoercion Coercion
| PatternMismatch [(CoVar, Type)] [(CoVar, Type)]
Expand Down Expand Up @@ -73,6 +74,7 @@ instance Pretty CoreError where
text " got var info" <+> string (show r) </>
text "for" <+> pretty v
pretty (NoSuchVar a) = text "No such variable" <+> pretty a
pretty (Duplicate a) = text "Duplicate declaration of" <+> pretty a
pretty IllegalUnbox = text "Illegal unboxed type"
pretty (InvalidCoercion a) = text "Illegal coercion" <+> pretty a
pretty (PatternMismatch l r) = text "Expected vars" <+> pVs l </>
Expand Down Expand Up @@ -115,8 +117,8 @@ checkStmt s (Foreign v ty b:xs) = do
es <- gatherError' . liftError $
-- Ensure we're declaring a value
unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v)))
-- And the type is well formed
*> checkType s ty
*> checkNodup v (vars s)

((Foreign v ty b, es):) <$> checkStmt (s { vars = insertVar v ty (vars s) }) xs

Expand All @@ -130,6 +132,7 @@ checkStmt s (StmtLet (One (v, ty, e)):xs) = do
_ -> pure ())
*> unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v)))
*> checkType s ty
*> checkNodup v (vars s)

((StmtLet (One (v, ty, e')), es):) <$> checkStmt (s { vars = insertVar v ty (vars s) }) xs
checkStmt s (StmtLet (Many vs):xs) = do
Expand All @@ -144,6 +147,7 @@ checkStmt s (StmtLet (Many vs):xs) = do
_ -> pure ())
*> unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v)))
*> checkType s ty
*> checkNodup v (vars s)
pure ((v, ty, e'), es)

((StmtLet (Many vs'), es):) <$> checkStmt s' xs
Expand Down Expand Up @@ -197,6 +201,7 @@ checkTerm s (Lam arg@(TermArgument a ty) bod) = do
-- Ensure type is valid and we're declaring a value
unless (varInfo a == ValueVar) (pushError (InfoIllegal (toVar a) ValueVar (varInfo a)))
*> checkType s ty
*> checkNodup a (vars s)

(bty, bod') <- checkTerm (s { vars = insertVar a ty (vars s) }) bod
pure ( ForallTy Irrelevant ty <$> bty
Expand All @@ -206,6 +211,7 @@ checkTerm s (Lam arg@(TypeArgument a ty) bod) = do
-- Ensure type is valid and we're declaring a tyvar
unless (varInfo a == TypeVar) (pushError (InfoIllegal (toVar a) TypeVar (varInfo a)))
*> checkType (s { tyVars = VarSet.insert (toVar a) (tyVars s) }) ty
*> checkNodup' a (tyVars s)

(bty, bod') <- checkTerm (s { tyVars = VarSet.insert (toVar a) (tyVars s) }) bod
pure ( ForallTy (Relevant (toVar a)) ty <$> bty
Expand All @@ -222,6 +228,7 @@ checkTerm s (Let (One (v, ty, e)) r) = do
_ -> pure ())
*> unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v)))
*> checkType s ty
*> checkNodup v (vars s)

pure ( tyr, AnnLet es (One (v, ty, e')) r')

Expand All @@ -236,6 +243,7 @@ checkTerm s (Let (Many vs) r) = do
Just ty' | ty `apart` ty' -> pushError (TypeMismatch ty ty')
_ -> pure ())
*> unless (varInfo v == ValueVar) (pushError (InfoIllegal (toVar v) ValueVar (varInfo v)))
*> checkNodup v (vars s)
*> checkType s ty
pure ((v, ty, e'), es)

Expand All @@ -257,6 +265,7 @@ checkTerm s (Match e bs) = do
_ -> pure ())
*> when (vs /= patternVars p) (pushError (PatternMismatch (first toVar <$> patternVars p) (first toVar <$> vs)))
*> checkPattern s ty p
*> traverse_ (\(x, _) -> checkNodup x (vars s)) pVars
pure ((tyr, Arm p ty r' vs tvs), es)

-- Verify the types are consistent
Expand Down Expand Up @@ -495,6 +504,13 @@ liftError m = case runErrors m of
Left e -> throwError e
Right x -> pure x

checkNodup :: IsVar a => a -> VarMap.Map b -> Errors CoreErrors ()
checkNodup v m = when (toVar v `VarMap.member` m) (pushError (Duplicate (toVar v)))

checkNodup' :: IsVar a => a -> VarSet.Set -> Errors CoreErrors ()
checkNodup' v m = when (toVar v `VarSet.member` m) (pushError (Duplicate (toVar v)))


gatherError :: MonadWriter LintResult m => ExceptT CoreErrors m b -> m (Maybe b, CoreErrors)
gatherError m = do
res <- runExceptT m
Expand Down
22 changes: 12 additions & 10 deletions src/Core/Optimise/Reduce.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import Control.Arrow hiding ((<+>))
import qualified Data.Map.Strict as Map
import qualified Data.VarMap as VarMap
import qualified Data.VarSet as VarSet
import Data.Foldable
import Data.Triple
import Data.Graph
import Data.Maybe
import Data.List

import Core.Optimise.Reduce.Pattern
import Core.Optimise.Reduce.Inline
Expand Down Expand Up @@ -471,15 +471,17 @@ reduceTermK _ (AnnMatch _ test arms) cont = do
. (armBody .~ substituteInTys tySubst body')

reduceBody :: (Term a -> m (Term a)) -> Subst a -> AnnTerm VarSet.Set (OccursVar a) -> m (Term a)
reduceBody cont subst body =
let (sub, binds) = foldr
(\(var, a) (sub, binds) ->
(VarMap.insert (toVar var) (basicDef var (Atom a)) sub,
if isTrivialAtom a
then binds
else Let (One (var, approximateAtomType a, Atom a)) . binds))
(mempty, id) subst
in binds <$> local (varScope %~ VarMap.union sub) (reduceTermK UsedOther body cont)
reduceBody cont subst body = do
(sub, binds) <- foldrM (\(var, a) (sub, binds) ->
if isTrivialAtom a
then pure ( VarMap.insert (toVar var) (basicDef var (Atom a)) sub, binds )
else do
let ty = approximateAtomType a
v <- freshFrom' var
pure ( VarMap.insert (toVar var) (basicDef var (Atom (Ref (toVar v) ty))) sub
, Let (One (v, ty, Atom a)) . binds ))
(mempty, id) subst
binds <$> local (varScope %~ VarMap.union sub) (reduceTermK UsedOther body cont)

foldVar :: [(a, Type)] -> (a, Atom) -> Maybe (VarMap.Map Type) -> Maybe (VarMap.Map Type)
foldVar _ _ Nothing = Nothing
Expand Down
11 changes: 5 additions & 6 deletions src/Core/Optimise/SAT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ import Data.Semigroup
import Data.Maybe
import Data.List


-- | Do the static argument transformation on a whole program.
staticArgsPass :: (MonadNamey m, IsVar a) => [Stmt a] -> m [Stmt a]
staticArgsPass = traverse staticArg_stmt
Expand Down Expand Up @@ -174,13 +173,13 @@ doStaticArgs the_func the_type the_body =
mkShadow worker =
let go_dynamic args = do
inside <- mkApps (Ref worker worker_ty) worker_ty args
pure $ foldr Lam inside args
refresh $ foldr Lam inside args

go (Static (TypeArgument _ k):xs) = do
x <- fromVar . mkTyvar <$> genName
x <- fresh' TypeVar
Lam (TypeArgument x k) <$> go xs
go (Static (TermArgument _ k):xs) = do
x <- fromVar . mkVal <$> genName
x <- fresh' ValueVar
Lam (TermArgument x k) <$> go xs
go [] = go_dynamic non_static_bndrs
go _ = error "NonStatic binder in static_binders"
Expand Down Expand Up @@ -229,11 +228,11 @@ isStatic _ _ = NonStatic
mkApps :: forall a m. (IsVar a, MonadNamey m) => Atom -> Type -> [Argument a] -> m (Term a)
mkApps at _ [] = pure $ Atom at
mkApps at (ForallTy Irrelevant _ t) (TermArgument x tau:xs) = do
this_app <- fromVar . mkVal <$> genName
this_app <- fresh' ValueVar
Let (One (this_app, t, App at (Ref (toVar x) tau))) <$>
mkApps (Ref (toVar this_app) t) t xs
mkApps at (ForallTy r _ t) (TypeArgument v _:xs) = do
this_app <- fromVar . mkVal <$> genName
this_app <- fresh' ValueVar
let t' =
case r of
Relevant binder -> substituteInType (VarMap.singleton binder (VarTy (toVar v))) t
Expand Down
10 changes: 10 additions & 0 deletions tests/lua/opt_sat_inline.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
do
local E = { __tag = "E" }
local function foldr_sat(zero, x)
if x.__tag ~= "T" then return zero end
local tmp = x[1]
local tmp0 = tmp._2._2
return foldr_sat(tmp._1 + foldr_sat(zero, tmp0._2), tmp0._1)
end
foldr_sat(0, E)
end
11 changes: 11 additions & 0 deletions tests/lua/opt_sat_inline.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
external val (+) : int -> int -> int = "function(x, y) return x + y end"

type sz_tree 'a =
| E
| T of 'a * int * sz_tree 'a * sz_tree 'a

let rec foldr f zero = function
| E -> zero
| T (x, _, l, r) -> foldr f (f x (foldr f zero r)) l

let _ = foldr (+) 0 E