Skip to content

Commit

Permalink
Remove roles (they can now be inferred from the type/kind)
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed May 7, 2024
1 parent 4045489 commit 7d21536
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 94 deletions.
30 changes: 8 additions & 22 deletions src/lib/Generalize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ import Subst
import Types.Primitives
import Types.Top

type RolePiBinder = WithAttrB RoleExpl CBinder
type RolePiBinders = Nest RolePiBinder

generalizeIxDict :: EnvReader m => CDict n -> m n (Generalized CoreIR CDict n)
generalizeIxDict dict = liftGeneralizerM do
dict' <- sinkM dict
Expand Down Expand Up @@ -127,8 +124,8 @@ traverseTyParams (StuckTy _ _) _ = error "shouldn't have StuckTy left"
traverseTyParams (TyCon ty) f = liftM TyCon $ getDistinct >>= \Distinct -> case ty of
DictTy dictTy -> DictTy <$> case dictTy of
DictType sn name params -> do
Abs paramRoles UnitE <- getClassRoleBinders name
params' <- traverseRoleBinders f paramRoles params
ClassDef _ _ _ _ _ bs _ _ <- lookupClassDef name
params' <- traverseRoleBinders f bs params
return $ DictType sn name params'
IxDictType t -> IxDictType <$> f' TypeParam TyKind t
TabPi (TabPiType d (b:>iTy) resultTy) -> do
Expand All @@ -147,8 +144,8 @@ traverseTyParams (TyCon ty) f = liftM TyCon $ getDistinct >>= \Distinct -> case
Nat -> return Nat
Fin n -> Fin <$> f DataParam NatTy n
UserADTType sn def (TyConParams infs params) -> do
Abs roleBinders UnitE <- getDataDefRoleBinders def
params' <- traverseRoleBinders f roleBinders params
TyConDef _ _ bs _ <- lookupTyCon def
params' <- traverseRoleBinders f bs params
return $ UserADTType sn def $ TyConParams infs params'
_ -> error $ "Not implemented: " ++ pprint ty
where
Expand All @@ -159,34 +156,23 @@ traverseTyParams (TyCon ty) f = liftM TyCon $ getDistinct >>= \Distinct -> case
traverseRoleBinders
:: forall m n n'. EnvReader m
=> (forall l . DExt n l => ParamRole -> Type CoreIR l -> Atom CoreIR l -> m l (Atom CoreIR l))
-> RolePiBinders n n' -> [Atom CoreIR n] -> m n [Atom CoreIR n]
-> Nest CBinder n n' -> [Atom CoreIR n] -> m n [Atom CoreIR n]
traverseRoleBinders f allBinders allParams =
runSubstReaderT idSubst $ go allBinders allParams
where
go :: forall i i'. RolePiBinders i i' -> [Atom CoreIR n]
go :: forall i i'. Nest CBinder i i' -> [Atom CoreIR n]
-> SubstReaderT AtomSubstVal m i n [Atom CoreIR n]
go Empty [] = return []
go (Nest (WithAttrB (role, _) b) bs) (param:params) = do
go (Nest b bs) (param:params) = do
ty' <- substM $ binderType b
role <- inferRoleFromType ty'
Distinct <- getDistinct
param' <- liftSubstReaderT $ f role ty' param
params'' <- extendSubst (b@>SubstVal param') $ go bs params
return $ param' : params''
go _ _ = error "zip error"
{-# INLINE traverseRoleBinders #-}

getDataDefRoleBinders :: EnvReader m => TyConName n -> m n (Abs RolePiBinders UnitE n)
getDataDefRoleBinders def = do
TyConDef _ attrs bs _ <- lookupTyCon def
return $ Abs (zipAttrs attrs bs) UnitE
{-# INLINE getDataDefRoleBinders #-}

getClassRoleBinders :: EnvReader m => ClassName n -> m n (Abs RolePiBinders UnitE n)
getClassRoleBinders def = do
ClassDef _ _ _ _ roleExpls bs _ _ <- lookupClassDef def
return $ Abs (zipAttrs roleExpls bs) UnitE
{-# INLINE getClassRoleBinders #-}

-- === instances ===

instance GenericB GeneralizationEmission where
Expand Down
77 changes: 28 additions & 49 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

module Inference
( inferTopUDecl, checkTopUType, inferTopUExpr , generalizeDict, asTopBlock
, UDeclInferenceResult (..), asFFIFunType) where
, UDeclInferenceResult (..), asFFIFunType, ParamRole (..), inferRoleFromType) where

import Prelude hiding ((.), id)
import Control.Category
Expand Down Expand Up @@ -93,7 +93,7 @@ inferTopUDecl (UInterface paramBs methodTys className methodNames) result = do
UDeclResultDone <$> applyRename subst result
inferTopUDecl (UInstance className bs params methods maybeName expl) result = do
let (InternalName _ _ className') = className
def <- liftInfererM $ withRoleUBinders bs \(ZipB roleExpls bs') -> do
def <- liftInfererM $ withUBinders bs \(ZipB roleExpls bs') -> do
ClassDef _ _ _ _ _ paramBinders _ _ <- lookupClassDef (sink className')
params' <- checkInstanceParams paramBinders params
body <- checkInstanceBody (sink className') params' methods
Expand Down Expand Up @@ -136,11 +136,11 @@ asTopBlock block = do
return (TopLam False (PiType Empty ty) (LamExpr Empty block), ty)

getInstanceType :: EnvReader m => InstanceDef n -> m n (CorePiType n)
getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM do
getInstanceType (InstanceDef className expls bs params _) = liftEnvReaderM do
refreshAbs (Abs bs (ListE params)) \bs' (ListE params') -> do
className' <- sinkM className
dTy <- toType <$> dictType className' params'
return $ CorePiType ImplicitApp (snd <$> roleExpls) bs' dTy
return $ CorePiType ImplicitApp expls bs' dTy

-- === Inferer monad ===

Expand Down Expand Up @@ -806,8 +806,7 @@ applySigmaAtom (SigmaUVar _ _ f) args = case f of
f'' <- inlineTypeAliases f'
emit =<< mkApp f'' args
UTyConVar f' -> do
TyConDef sn roleExpls _ _ <- lookupTyCon f'
let expls = snd <$> roleExpls
TyConDef sn expls _ _ <- lookupTyCon f'
return $ toAtom $ UserADTType sn f' (TyConParams expls args)
UDataConVar v -> do
(tyCon, i) <- lookupDataCon v
Expand Down Expand Up @@ -1141,7 +1140,7 @@ instanceFun instanceName appExpl = do
liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do
args <- mapM toAtomVar $ nestToNames bs'
result <- toAtom <$> mkInstanceDict (sink instanceName) (toAtom <$> args)
let piTy = CorePiType appExpl (snd<$>expls) bs' (getType result)
let piTy = CorePiType appExpl expls bs' (getType result)
return $ toAtom $ CoreLamExpr piTy (LamExpr bs' $ Atom result)

checkMaybeAnnExpr :: Emits o => Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o)
Expand All @@ -1153,25 +1152,24 @@ checkMaybeAnnExpr ty expr = confuseGHC >>= \_ -> case ty of

inferTyConDef :: UDataDef i -> InfererM i o (TyConDef o)
inferTyConDef (UDataDef tyConName paramBs dataCons) = do
withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do
withUBinders paramBs \(ZipB expls paramBs') -> do
dataCons' <- ADTCons <$> mapM inferDataCon dataCons
return (TyConDef tyConName roleExpls paramBs' dataCons')
return (TyConDef tyConName expls paramBs' dataCons')

inferStructDef :: UStructDef i -> InfererM i o (TyConDef o)
inferStructDef (UStructDef tyConName paramBs fields _) = do
withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do
withUBinders paramBs \(ZipB expls paramBs') -> do
let (fieldNames, fieldTys) = unzip fields
tys <- mapM checkUType fieldTys
let dataConDefs = StructFields $ zip (withoutSrc <$> fieldNames) tys
return $ TyConDef tyConName roleExpls paramBs' dataConDefs
return $ TyConDef tyConName expls paramBs' dataConDefs

inferDotMethod
:: TyConName o
-> Abs (Nest UAnnBinder) (Abs UAtomBinder ULamExpr) i
-> InfererM i o (CoreLamExpr o)
inferDotMethod tc (Abs uparamBs (Abs selfB lam)) = do
TyConDef sn roleExpls paramBs _ <- lookupTyCon tc
let expls = snd <$> roleExpls
TyConDef sn expls paramBs _ <- lookupTyCon tc
withFreshBindersInf expls (Abs paramBs UnitE) \paramBs' UnitE -> do
let paramVs = bindersVars paramBs'
extendRenamer (uparamBs @@> (atomVarName <$> paramVs)) do
Expand Down Expand Up @@ -1223,16 +1221,16 @@ inferClassDef
:: SourceName -> [SourceName] -> Nest UAnnBinder i i' -> [UType i']
-> InfererM i o (ClassDef o)
inferClassDef className methodNames paramBs methodTys = do
withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do
withUBinders paramBs \(ZipB expls paramBs') -> do
let paramNames = catMaybes $ nestToListFlip paramBs \(UAnnBinder expl b _ _) ->
case expl of Inferred _ (Synth _) -> Nothing
_ -> Just $ Just $ getSourceName b
methodTys' <- forM methodTys \m -> do
checkUType m >>= \case
TyCon (Pi t) -> return t
t -> return $ CorePiType ImplicitApp [] Empty t
PairB paramBs'' superclassBs <- partitionBinders rootSrcId (zipAttrs roleExpls paramBs') $
\b@(WithAttrB (_, expl) b') -> case expl of
PairB paramBs'' superclassBs <- partitionBinders rootSrcId (zipAttrs expls paramBs') $
\b@(WithAttrB expl b') -> case expl of
Explicit -> return $ LeftB b
-- TODO: Add a proper SrcId here. We'll need to plumb it through from the original UBinders
Inferred _ Unify -> throw rootSrcId InterfacesNoImplicitParams
Expand Down Expand Up @@ -1275,29 +1273,6 @@ inferUBinders (Nest (UAnnBinder expl (WithSrcB sid b) ann cs) bs) cont = do
Abs bs' e <- inferUBinders bs \vs -> cont (sink (binderName b') : vs)
return $ Abs (Nest (WithAttrB expl b') bs') e

withRoleUBinders :: Nest UAnnBinder i i' -> InfererCPSB2 (Nest (WithRoleExpl CBinder)) i i' o a
withRoleUBinders bs cont = do
withUBinders bs \(ZipB expls bs') -> do
let tys = getType <$> bindersVars bs'
roleExpls <- forM (zip tys expls) \(ty, expl) -> do
role <- inferRole ty expl
return (role, expl)
cont (zipAttrs roleExpls bs')
where
inferRole :: CType o -> Explicitness -> InfererM i o ParamRole
inferRole ty = \case
Inferred _ (Synth _) -> return DictParam
_ -> case ty of
TyKind -> return TypeParam
_ -> isData ty >>= \case
True -> return DataParam
-- TODO(dougalm): the `False` branch should throw an error but that's
-- currently too conservative. e.g. `data RangeFrom q:Type i:q = ...`
-- fails because `q` isn't data. We should be able to fix it once we
-- have a `Data a` class (see issue #680).
False -> return DataParam
{-# INLINE inferRole #-}

inferAnn :: SrcId -> UAnn i -> [UConstraint i] -> InfererM i o (CType o)
inferAnn binderSrcId ann cs = case ann of
UAnn ty -> checkUType ty
Expand Down Expand Up @@ -1474,8 +1449,7 @@ checkCasePat (WithSrcB sid pat) scrutineeTy cont = case pat of

inferParams :: Emits o => SrcId -> CType o -> TyConName o -> InfererM i o (TyConParams o)
inferParams sid ty dataDefName = do
TyConDef sourceName roleExpls paramBs _ <- lookupTyCon dataDefName
let paramExpls = snd <$> roleExpls
TyConDef sourceName paramExpls paramBs _ <- lookupTyCon dataDefName
let inferenceExpls = paramExpls <&> \case
Explicit -> Inferred Nothing Unify
expl -> expl
Expand Down Expand Up @@ -1880,8 +1854,8 @@ generalizeDict ty dict = do
generalizeDictRec :: CType n -> CDict n -> InfererM i n (CDict n)
generalizeDictRec targetTy (DictCon dict) = case dict of
InstanceDict _ instanceName args -> do
InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName
liftSolverM $ generalizeInstanceArgs roleExpls bs args \args' -> do
InstanceDef _ _ bs _ _ <- lookupInstanceDef instanceName
liftSolverM $ generalizeInstanceArgs bs args \args' -> do
d <- mkInstanceDict (sink instanceName) args'
-- We use rootSrcId here because we only call this after type inference so
-- precise source info isn't needed.
Expand All @@ -1893,17 +1867,23 @@ generalizeDictRec targetTy (DictCon dict) = case dict of
IxRawFin _ -> error "not a simplified dict"
generalizeDictRec _ _ = error "not a simplified dict"

data ParamRole = TypeParam | DictParam | DataParam deriving (Show, Generic, Eq)

inferRoleFromType :: EnvReader m => CType n -> m n ParamRole
inferRoleFromType = undefined -- TODO

generalizeInstanceArgs
:: Zonkable e => [RoleExpl] -> Nest CBinder o any -> [CAtom o]
:: Zonkable e => Nest CBinder o any -> [CAtom o]
-> (forall o'. DExt o o' => [CAtom o'] -> SolverM i o' (e o'))
-> SolverM i o (e o)
generalizeInstanceArgs [] Empty [] cont = withDistinct $ cont []
generalizeInstanceArgs ((role,_):expls) (Nest (b:>ty) bs) (arg:args) cont = do
generalizeInstanceArgs Empty [] cont = withDistinct $ cont []
generalizeInstanceArgs (Nest (b:>ty) bs) (arg:args) cont = do
role <- inferRoleFromType ty
generalizeInstanceArg role ty arg \arg' -> do
Abs bs' UnitE <- applySubst (b@>SubstVal arg') (Abs bs UnitE)
generalizeInstanceArgs expls bs' (sink <$> args) \args' ->
generalizeInstanceArgs bs' (sink <$> args) \args' ->
cont $ sink arg' : args'
generalizeInstanceArgs _ _ _ _ = error "zip error"
generalizeInstanceArgs _ _ _ = error "zip error"

generalizeInstanceArg
:: Zonkable e => ParamRole -> CType o -> CAtom o
Expand Down Expand Up @@ -2115,7 +2095,6 @@ instance SinkableE Givens where
-- === Inference-specific builder patterns ===

type WithExpl = WithAttrB Explicitness
type WithRoleExpl = WithAttrB RoleExpl

buildBlockInfWithRecon
:: HasNamesE e
Expand Down
11 changes: 5 additions & 6 deletions src/lib/QueryType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ getUVarType = \case
UDataConVar v -> getDataConNameType v
UPunVar v -> getStructDataConType v
UClassVar v -> do
ClassDef _ _ _ _ roleExpls bs _ _ <- lookupClassDef v
return $ toType $ CorePiType ExplicitApp (map snd roleExpls) bs TyKind
ClassDef _ _ _ _ expls bs _ _ <- lookupClassDef v
return $ toType $ CorePiType ExplicitApp expls bs TyKind
UMethodVar v -> getMethodNameType v

getMethodNameType :: EnvReader m => MethodName n -> m n (CType n)
Expand Down Expand Up @@ -180,7 +180,7 @@ getTyConNameType v = do
TyConDef _ expls bs _ <- lookupTyCon v
case bs of
Empty -> return TyKind
_ -> return $ toType $ CorePiType ExplicitApp (snd <$> expls) bs TyKind
_ -> return $ toType $ CorePiType ExplicitApp expls bs TyKind

getDataConNameType :: EnvReader m => DataConName n -> m n (Type CoreIR n)
getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do
Expand Down Expand Up @@ -212,8 +212,7 @@ buildDataConType
=> TyConDef n
-> (forall l. DExt n l => [Explicitness] -> Nest CBinder n l -> [CAtomName l] -> TyConParams l -> m l a)
-> m n a
buildDataConType (TyConDef _ roleExpls bs _) cont = do
let expls = snd <$> roleExpls
buildDataConType (TyConDef _ expls bs _) cont = do
expls' <- forM expls \case
Explicit -> return $ Inferred Nothing Unify
expl -> return $ expl
Expand All @@ -225,7 +224,7 @@ buildDataConType (TyConDef _ roleExpls bs _) cont = do
makeTyConParams :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n)
makeTyConParams tc params = do
TyConDef _ expls _ _ <- lookupTyCon tc
return $ TyConParams (map snd expls) params
return $ TyConParams expls params

dictType :: EnvReader m => ClassName n -> [CAtom n] -> m n (DictType n)
dictType className params = do
Expand Down
Loading

0 comments on commit 7d21536

Please sign in to comment.