Skip to content

Commit

Permalink
Make type variables scope like they do in GHC.
Browse files Browse the repository at this point in the history
I.e., without an explicit 'forall' type variables do not
scope over the body of a function.
  • Loading branch information
augustss committed Sep 20, 2024
1 parent 42b6317 commit 924f18d
Show file tree
Hide file tree
Showing 8 changed files with 4,537 additions and 4,517 deletions.
1 change: 1 addition & 0 deletions TODO
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
* Better naming of internal identifiers
* Add mask&co to exceptions
* Make deriving refer to identifiers that don't need to be in scope
* Add reductions for underapplied K2,K3,K4

Bugs:
* Missing IO in ccall shows wrong location
Expand Down
8,954 changes: 4,478 additions & 4,476 deletions generated/mhs.c

Large diffs are not rendered by default.

23 changes: 13 additions & 10 deletions src/MicroHs/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ module MicroHs.Expr(
Assoc(..), Fixity,
getBindsVars,
HasLoc(..),
eForall,
eForall, eForall',
eDummy,
impossible, impossibleShow,
getArrow, getArrows,
Expand Down Expand Up @@ -124,7 +124,7 @@ data Expr
| EViewPat Expr EPat
| ELazy Bool EPat -- True indicates ~p, False indicates !p
-- only in types
| EForall [IdKind] EType
| EForall Bool [IdKind] EType -- True indicates explicit forall in the code
-- only while type checking
| EUVar Int
-- only after type checking
Expand Down Expand Up @@ -357,8 +357,8 @@ instance HasLoc Expr where
getSLoc (ELazy _ e) = getSLoc e
getSLoc (EUVar _) = error "getSLoc EUVar"
getSLoc (ECon c) = getSLoc c
getSLoc (EForall [] e) = getSLoc e
getSLoc (EForall iks _) = getSLoc iks
getSLoc (EForall _ [] e) = getSLoc e
getSLoc (EForall _ iks _) = getSLoc iks

instance forall a . HasLoc a => HasLoc [a] where
getSLoc [] = noSLoc -- XXX shouldn't happen
Expand Down Expand Up @@ -422,7 +422,7 @@ subst s =
EApp f a -> EApp (sub f) (sub a)
ESign e t -> ESign (sub e) t
EUVar _ -> ae
EForall iks t -> EForall iks $ subst [ x | x@(i, _) <- s, not (elem i is) ] t
EForall b iks t -> EForall b iks $ subst [ x | x@(i, _) <- s, not (elem i is) ] t
where is = map idKindIdent iks
ELit _ _ -> ae
_ -> error "subst unimplemented"
Expand Down Expand Up @@ -511,7 +511,7 @@ allVarsExpr' aexpr =
ELazy _ p -> allVarsExpr' p
EUVar _ -> id
ECon c -> (conIdent c :)
EForall iks e -> (map (\ (IdKind i _) -> i) iks ++) . allVarsExpr' e
EForall _ iks e -> (map (\ (IdKind i _) -> i) iks ++) . allVarsExpr' e
where field (EField _ e) = allVarsExpr' e
field (EFieldPun is) = (last is :)
field EFieldWild = impossible
Expand Down Expand Up @@ -713,7 +713,7 @@ ppExprR raw = ppE
ELazy False p -> text "!" <> ppE p
EUVar i -> text ("_a" ++ show i)
ECon c -> ppCon c
EForall iks e -> ppForall iks <+> ppEType e
EForall _ iks e -> ppForall iks <+> ppEType e

ppApp :: [Expr] -> Expr -> Doc
ppApp as (EApp f a) = ppApp (a:as) f
Expand Down Expand Up @@ -814,8 +814,11 @@ getBindsVars :: [EBind] -> [Ident]
getBindsVars = concatMap getBindVars

eForall :: [IdKind] -> EType -> EType
eForall [] t = t
eForall vs t = EForall vs t
eForall = eForall' True

eForall' :: Bool -> [IdKind] -> EType -> EType
eForall' _ [] t = t
eForall' b vs t = EForall b vs t

eDummy :: Expr
eDummy = EVar dummyIdent
Expand Down Expand Up @@ -875,7 +878,7 @@ freeTyVars = foldr (go []) []
| elem tv acc = acc
| isConIdent tv = acc
| otherwise = tv : acc
go bound (EForall tvs ty) acc = go (map idKindIdent tvs ++ bound) ty acc
go bound (EForall _ tvs ty) acc = go (map idKindIdent tvs ++ bound) ty acc
go bound (EApp fun arg) acc = go bound fun (go bound arg acc)
go _bound (EUVar _) acc = acc
go _bound (ECon _) acc = acc
Expand Down
2 changes: 1 addition & 1 deletion src/MicroHs/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ pType :: P EType
pType = do
vs <- pForall
t <- pTypeOp
pure $ if null vs then t else EForall vs t
pure $ if null vs then t else EForall True vs t

pForall :: P [IdKind]
pForall = (forallKW *> esome pIdKind <* pSymbol ".") <|< pure []
Expand Down
61 changes: 31 additions & 30 deletions src/MicroHs/TypeCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ mkTModule impt tds tcs =
[ TypeExport i (tentry i) [] | Type (i, _) _ <- tds ]

-- All type synonym definitions.
ses = [ (qualIdent mn i, EForall vs t) | Type (i, vs) t <- tds ]
ses = [ (qualIdent mn i, EForall True vs t) | Type (i, vs) t <- tds ]

-- All fixity declaration.
fes = [ (qualIdent mn i, fx) | Infix fx is <- tds, i <- is ]
Expand Down Expand Up @@ -609,8 +609,8 @@ primTypes =
tuple n =
let
i = tupleConstr builtinLoc n
in (i, [entry i $ EForall [kk] $ foldr kArrow kv (replicate n kv)])
kImplies = EForall [kk] $ kConstraint `kArrow` (kv `kArrow` kv)
in (i, [entry i $ EForall True [kk] $ foldr kArrow kv (replicate n kv)])
kImplies = EForall True [kk] $ kConstraint `kArrow` (kv `kArrow` kv)
in
[
-- The function arrow et al are bothersome to define in Primitives, so keep them here.
Expand Down Expand Up @@ -640,7 +640,7 @@ primValues =
vks = [IdKind (mkIdent ("a" ++ show i)) kType | i <- enumFromTo 1 n]
ts = map tVarK vks
r = tApps c ts
in (c, [Entry (ECon $ ConData [(c, n)] c []) $ EForall vks $ EForall [] $ foldr tArrow r ts ])
in (c, [Entry (ECon $ ConData [(c, n)] c []) $ EForall True vks $ EForall True [] $ foldr tArrow r ts ])
in map tuple (enumFromTo 2 10)

kArrow :: EKind -> EKind -> EKind
Expand Down Expand Up @@ -686,7 +686,7 @@ expandSyn at = do
syns <- gets synTable
case M.lookup i syns of
Nothing -> return $ eApps t ts
Just (EForall vks tt) ->
Just (EForall _ vks tt) ->
-- if length vks /= length ts then tcError (getSLoc i) $ "bad synonym use"
-- else expandSyn $ subst (zip (map idKindIdent vks) ts) tt
let s = zip (map idKindIdent vks) ts
Expand All @@ -700,7 +700,7 @@ expandSyn at = do
Just _ -> impossible
EUVar _ -> return $ eApps t ts
ESign a _ -> expandSyn a -- Throw away signatures, they don't affect unification
EForall iks tt | null ts -> EForall iks <$> expandSyn tt
EForall b iks tt | null ts -> EForall b iks <$> expandSyn tt
ELit _ (LStr _) -> return t
ELit _ (LInteger _) -> return t
_ -> impossible
Expand All @@ -711,7 +711,7 @@ mapEType fn = rec
where
rec (EApp f a) = EApp (rec f) (rec a)
rec (ESign t k) = ESign (rec t) k
rec (EForall iks t) = EForall iks (rec t)
rec (EForall b iks t) = EForall b iks (rec t)
rec t = fn t

derefUVar :: EType -> T EType
Expand All @@ -731,7 +731,7 @@ derefUVar at =
return t'
EVar _ -> return at
ESign t k -> flip ESign k <$> derefUVar t
EForall iks t -> EForall iks <$> derefUVar t
EForall b iks t -> EForall b iks <$> derefUVar t
ELit _ (LStr _) -> return at
ELit _ (LInteger _) -> return at
_ -> impossible
Expand Down Expand Up @@ -839,7 +839,7 @@ tLookupV i = do
tLookup (msgTCMode tcm) i

tInst :: HasCallStack => Expr -> EType -> T (Expr, EType)
tInst ae (EForall vks t) = do
tInst ae (EForall _ vks t) = do
t' <- tInstForall vks t
tInst ae t'
tInst ae at | Just (ctx, t) <- getImplies at = do
Expand All @@ -864,7 +864,7 @@ tInstForall vks t =
return (subst (zip vs us) t)

tInst' :: EType -> T EType
tInst' (EForall vks t) = tInstForall vks t
tInst' (EForall _ vks t) = tInstForall vks t
tInst' t = return t

extValE :: HasCallStack =>
Expand Down Expand Up @@ -1063,7 +1063,7 @@ addTypeSyn :: EDef -> T ()
addTypeSyn adef =
case adef of
Type (i, vs) t -> do
let t' = EForall vs t
let t' = EForall True vs t
extSyn i t'
mn <- gets moduleName
extSyn (qualIdent mn i) t'
Expand Down Expand Up @@ -1169,7 +1169,7 @@ expandClass impt dcls@(Class ctx (iCls, vks) fds ms) = do
mdflts = [ (i, eqns) | BFcn i eqns <- ms ]
dflttys = [ (i, t) | BDfltSign i t <- ms ]
tCtx = tApps (qualIdent mn iCls) (map (EVar . idKindIdent) vks)
mkDflt (BSign methId t) = [ Sign [iDflt] $ EForall vks $ tCtx `tImplies` ty, def $ lookup methId mdflts ]
mkDflt (BSign methId t) = [ Sign [iDflt] $ EForall True vks $ tCtx `tImplies` ty, def $ lookup methId mdflts ]
where ty = fromMaybe t $ lookup methId dflttys
def Nothing = Fcn iDflt $ simpleEqn noDflt
def (Just eqns) = Fcn iDflt eqns
Expand Down Expand Up @@ -1209,7 +1209,7 @@ defaultSuffix :: String
defaultSuffix = uniqIdentSep ++ "dflt"

splitInst :: EConstraint -> ([IdKind], [EConstraint], EConstraint)
splitInst (EForall iks t) =
splitInst (EForall _ iks t) =
case splitInst t of
(iks', ctx, ct) -> (iks ++ iks', ctx, ct)
splitInst act | Just (ctx, ct) <- getImplies act =
Expand Down Expand Up @@ -1331,7 +1331,7 @@ addValueType adef = do
tret = tApps (qualIdent mn tycon) (map tVarK vks)
addCon (Constr evks ectx c ets) = do
let ts = either id (map snd) ets
cty = EForall vks $ EForall evks $ addConstraints ectx $ foldr (tArrow . snd) tret ts
cty = EForall True vks $ EForall True evks $ addConstraints ectx $ foldr (tArrow . snd) tret ts
fs = either (const []) (map fst) ets
extValETop c cty (ECon $ ConData cti (qualIdent mn c) fs)
mapM_ addCon cs
Expand All @@ -1341,7 +1341,7 @@ addValueType adef = do
t = snd $ head $ either id (map snd) ets
tret = tApps (qualIdent mn tycon) (map tVarK vks)
fs = either (const []) (map fst) ets
extValETop c (EForall vks $ EForall [] $ tArrow t tret) (ECon $ ConNew (qualIdent mn c) fs)
extValETop c (EForall True vks $ EForall True [] $ tArrow t tret) (ECon $ ConNew (qualIdent mn c) fs)
addConFields tycon con
ForImp _ i t -> extValQTop i t
Class ctx (i, vks) fds ms -> addValueClass ctx i vks fds ms
Expand All @@ -1361,9 +1361,9 @@ addValueClass ctx iCls vks fds ms = do
tret = tApps qiCls (map tVarK vks)
cti = [ (qualIdent mn iCon, length targs) ]
iCon = mkClassConstructor iCls
iConTy = EForall vks $ foldr tArrow tret targs
iConTy = EForall True vks $ foldr tArrow tret targs
extValETop iCon iConTy (ECon $ ConData cti (qualIdent mn iCon) [])
let addMethod (BSign i t) = extValETop i (EForall vks $ tApps qiCls (map (EVar . idKindIdent) vks) `tImplies` t) (EVar $ qualIdent mn i)
let addMethod (BSign i t) = extValETop i (EForall True vks $ tApps qiCls (map (EVar . idKindIdent) vks) `tImplies` t) (EVar $ qualIdent mn i)
addMethod _ = impossible
-- traceM ("addValueClass " ++ showEType (ETuple ctx))
mapM_ addMethod meths
Expand Down Expand Up @@ -1399,14 +1399,14 @@ tcDefValue adef =

-- Add implicit forall and type check.
tCheckTypeTImpl :: HasCallStack => EType -> EType -> T EType
tCheckTypeTImpl tchk t@(EForall _ _) = tCheckTypeT tchk t
tCheckTypeTImpl tchk t@(EForall _ _ _) = tCheckTypeT tchk t
tCheckTypeTImpl tchk t = do
bvs <- stKeysLcl <$> gets valueTable -- bound outside
let fvs = freeTyVars [t] -- free variables in t
-- these are free, and need quantification. eDummy indicates missing kind
iks = map (\ i -> IdKind i eDummy) (fvs \\ bvs)
--when (not (null iks)) $ traceM ("tCheckTypeTImpl: " ++ show (t, eForall iks t))
tCheckTypeT tchk (eForall iks t)
tCheckTypeT tchk (eForall' False iks t)

tCheckTypeT :: HasCallStack => EType -> EType -> T EType
tCheckTypeT = tCheck tcTypeT
Expand Down Expand Up @@ -1711,11 +1711,11 @@ tcExprR mt ae =
e' <- instSigma loc e t' mt
checkSigma e' t'
-- Only happens in type&kind checking mode.
EForall vks t ->
EForall b vks t ->
-- assertTCMode (==TCType) $
withVks vks $ \ vks' -> do
tt <- tcExpr mt t
return (EForall vks' tt)
return (EForall b vks' tt)
EUpdate e flds -> do
ises <- concat <$> mapM (dsEField e) flds
me <- dsUpdate unsetField e ises
Expand Down Expand Up @@ -1909,7 +1909,8 @@ tcExprLam mt qs = do

tcEqns :: Bool -> EType -> [Eqn] -> T [Eqn]
--tcEqns _ t eqns | trace ("tcEqns: " ++ showEBind (BFcn dummyIdent eqns) ++ " :: " ++ show t) False = undefined
tcEqns top (EForall iks t) eqns = withExtTyps iks $ tcEqns top t eqns
tcEqns top (EForall expl iks t) eqns | expl = withExtTyps iks $ tcEqns top t eqns
| otherwise = tcEqns top t eqns
tcEqns top t eqns | Just (ctx, t') <- getImplies t = do
let loc = getSLoc eqns
d <- newADictIdent loc
Expand Down Expand Up @@ -2063,9 +2064,9 @@ tcPat mt ae =
-- traceM (show ipt)
case xpt of
-- Sanity check
EForall _ (EForall _ _) -> return ()
EForall _ _ (EForall _ _ _) -> return ()
_ -> impossibleShow i
EForall avs apt <- tInst' xpt
EForall _ avs apt <- tInst' xpt
(sks, spt) <- shallowSkolemise avs apt
(d, p, pt) <-
case getImplies spt of
Expand Down Expand Up @@ -2242,7 +2243,7 @@ dsType at =
EListish (LList [t]) -> tApp (tList (getSLoc at)) (dsType t)
ETuple ts -> tApps (tupleConstr (getSLoc at) (length ts)) (map dsType ts)
ESign t k -> ESign (dsType t) k
EForall iks t -> EForall iks (dsType t)
EForall b iks t -> EForall b iks (dsType t)
ELit _ (LStr _) -> at
ELit _ (LInteger _) -> at
_ -> impossible
Expand Down Expand Up @@ -2289,7 +2290,7 @@ quantify tvs ty = do
zipWithM_ (\ tv n -> setUVar tv (EVar n)) tvs newVars
ty' <- derefUVar ty
putUvarSubst osubst -- reset the setUVar we did above
return (EForall newVarsK ty')
return (EForall False newVarsK ty')

allBinders :: [Ident] -- a,b,..z, a1, b1,... z1, a2, b2,...
allBinders = [ mkIdent [x] | x <- ['a' .. 'z'] ] ++
Expand All @@ -2305,7 +2306,7 @@ skolemise :: HasCallStack =>
Sigma -> T ([TyVar], Rho)
-- Performs deep skolemisation, returning the
-- skolem constants and the skolemised type.
skolemise (EForall tvs ty) = do -- Rule PRPOLY
skolemise (EForall _ tvs ty) = do -- Rule PRPOLY
(sks1, ty') <- shallowSkolemise tvs ty
(sks2, ty'') <- skolemise ty'
return (sks1 ++ sks2, ty'')
Expand Down Expand Up @@ -2333,7 +2334,7 @@ metaTvs tys = foldr go [] tys
| elem tv acc = acc
| otherwise = tv : acc
go (EVar _) acc = acc
go (EForall _ ty) acc = go ty acc
go (EForall _ _ ty) acc = go ty acc
go (EApp fun arg) acc = go fun (go arg acc)
go (ELit _ _) acc = acc
go _ _ = impossible
Expand Down Expand Up @@ -2370,7 +2371,7 @@ checkSigma expr sigma = do
subsCheckRho :: HasCallStack =>
SLoc -> Expr -> Sigma -> Rho -> T Expr
--subsCheckRho _ e1 t1 t2 | trace ("subsCheckRho: " ++ show e1 ++ " :: " ++ show t1 ++ " = " ++ show t2) False = undefined
subsCheckRho loc exp1 sigma1@(EForall _ _) rho2 = do -- Rule SPEC
subsCheckRho loc exp1 sigma1@(EForall _ _ _) rho2 = do -- Rule SPEC
(exp1', rho1) <- tInst exp1 sigma1
subsCheckRho loc exp1' rho1 rho2
subsCheckRho loc exp1 arho1 rho2 | Just _ <- getImplies arho1 = do
Expand Down Expand Up @@ -2733,7 +2734,7 @@ substEUVar _ t@(EVar _) = t
substEUVar _ t@(ELit _ _) = t
substEUVar s (EApp f a) = EApp (substEUVar s f) (substEUVar s a)
substEUVar s t@(EUVar i) = fromMaybe t $ lookup i s
substEUVar s (EForall iks t) = EForall iks (substEUVar s t)
substEUVar s (EForall b iks t) = EForall b iks (substEUVar s t)
substEUVar _ _ = impossible

-- Length of lists match, because of kind correctness.
Expand Down
1 change: 1 addition & 0 deletions tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ test:
$(TMHS) ImpMet && $(EVAL) > ImpMet.out && diff ImpMet.ref ImpMet.out
$(TMHS) MultiIf && $(EVAL) > MultiIf.out && diff MultiIf.ref MultiIf.out
$(TMHS) LameCase && $(EVAL) > LameCase.out && diff LameCase.ref LameCase.out
$(TMHS) NoForall && $(EVAL) > NoForall.out && diff NoForall.ref NoForall.out

errtest:
sh errtester.sh $(MHS) < errmsg.test
Expand Down
11 changes: 11 additions & 0 deletions tests/NoForall.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module NoForall where

-- Without an explicit forall the 'a' is not bound in the body.
f :: a -> ((a,a),(a,a))
f x =
let g :: a -> (a,a)
g a = (a,a)
in g (x,x)

main :: IO ()
main = print (f True)
1 change: 1 addition & 0 deletions tests/NoForall.ref
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
((True,True),(True,True))

0 comments on commit 924f18d

Please sign in to comment.