From cf06eb54aac8f685f9698a627ad43e7ed56420aa Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 10 May 2024 13:56:56 -0400 Subject: [PATCH] Hoist `Hof` out of `PrimOp` and remove `Lam` case from generic op. This takes advantage of things becoming more first-order to reduce boilerplate. --- src/lib/CheapReduction.hs | 6 +- src/lib/CheckType.hs | 8 +- src/lib/Core.hs | 2 +- src/lib/Imp.hs | 2 +- src/lib/Inference.hs | 6 +- src/lib/Inline.hs | 4 +- src/lib/Linearize.hs | 2 +- src/lib/OccAnalysis.hs | 2 +- src/lib/Optimize.hs | 4 +- src/lib/QueryTypePure.hs | 4 +- src/lib/Simplify.hs | 4 +- src/lib/Transpose.hs | 2 +- src/lib/Types/Core.hs | 181 ++++++++++++++++++-------------------- src/lib/Vectorize.hs | 2 +- 14 files changed, 112 insertions(+), 117 deletions(-) diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 1c32a6ef7..40c3182b5 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -340,7 +340,7 @@ visitAlt (Abs b body) = do traverseOpTerm :: (GenericOp e, Visitor m r i o, OpConst e r ~ OpConst e r) => e r i -> m (e r o) -traverseOpTerm e = traverseOp e visitGeneric visitGeneric visitGeneric +traverseOpTerm e = traverseOp e visitGeneric visitGeneric visitTypeDefault :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) @@ -397,6 +397,7 @@ instance IRRep r => VisitGeneric (Expr r) r where ApplyMethod et m i xs -> ApplyMethod <$> visitGeneric et <*> visitGeneric m <*> pure i <*> mapM visitGeneric xs Project t i x -> Project <$> visitGeneric t <*> pure i <*> visitGeneric x Unwrap t x -> Unwrap <$> visitGeneric t <*> visitGeneric x + Hof op -> Hof <$> visitGeneric op instance IRRep r => VisitGeneric (PrimOp r) r where visitGeneric = \case @@ -405,8 +406,7 @@ instance IRRep r => VisitGeneric (PrimOp r) r where MemOp op -> MemOp <$> visitGeneric op VectorOp op -> VectorOp <$> visitGeneric op MiscOp op -> MiscOp <$> visitGeneric op - Hof op -> Hof <$> visitGeneric op - RefOp r op -> RefOp <$> visitGeneric r <*> traverseOp op visitGeneric visitGeneric visitGeneric + RefOp r op -> RefOp <$> visitGeneric r <*> traverseOp op visitGeneric visitGeneric instance IRRep r => VisitGeneric (TypedHof r) r where visitGeneric (TypedHof eff hof) = TypedHof <$> visitGeneric eff <*> visitGeneric hof diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index c4e3db2fd..e1845785d 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -266,6 +266,10 @@ instance IRRep r => CheckableE r (Expr r) where resultTy'' <- snd <$> unwrapNewtypeType con checkTypesEq resultTy' resultTy'' return $ Unwrap resultTy' x' + Hof (TypedHof effTy hof) -> do + effTy' <- checkE effTy + hof' <- checkHof effTy' hof + return $ Hof (TypedHof effTy' hof') instance CheckableE CoreIR TyConParams where checkE (TyConParams expls params) = TyConParams expls <$> mapM checkE params @@ -441,10 +445,6 @@ instance CheckableE CoreIR NewtypeTyCon where instance IRRep r => CheckableE r (PrimOp r) where checkE = \case - Hof (TypedHof effTy hof) -> do - effTy' <- checkE effTy - hof' <- checkHof effTy' hof - return $ Hof (TypedHof effTy' hof') VectorOp vOp -> VectorOp <$> checkE vOp BinOp binop x y -> do x' <- checkE x diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 7d7b4dc55..bfa6db8e0 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -401,7 +401,7 @@ liftLamExpr (TopLam d ty (LamExpr bs body)) f = liftM (TopLam d ty) $ liftEnvRea fromNaryForExpr :: IRRep r => Int -> Expr r n -> Maybe (Int, LamExpr r n) fromNaryForExpr maxDepth | maxDepth <= 0 = error "expected non-negative number of args" fromNaryForExpr maxDepth = \case - PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr b body)))) -> + Hof (TypedHof _ (For _ _ (UnaryLamExpr b body))) -> extend <|> (Just $ (1, LamExpr (Nest b Empty) body)) where extend = do diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index a4740d4c9..e47de1171 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -317,6 +317,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of TabApp _ _ _ -> error "Unexpected `TabApp` in Imp pass." TabCon _ _ -> error "Unexpected `TabCon` in Imp pass." Project _ i x -> reduceProj i =<< substM x + Hof hof -> toImpTypedHof hof toImpRefOp :: Emits o => SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o) @@ -336,7 +337,6 @@ toImpRefOp refDest' m = do toImpOp :: forall i o . Emits o => PrimOp SimpIR i -> SubstImpM i o (SAtom o) toImpOp op = case op of - Hof hof -> toImpTypedHof hof RefOp refDest eff -> toImpRefOp refDest eff BinOp binOp x y -> returnIExprVal =<< emitInstr =<< (IBinOp binOp <$> fsa x <*> fsa y) UnOp unOp x -> returnIExprVal =<< emitInstr =<< (IUnOp unOp <$> fsa x) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 387566383..ada21aa66 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -1067,7 +1067,11 @@ matchPrimApp = \case Just x' <- return $ toMaybeType x return $ Left x' _ -> return $ Right x - return $ fromJust $ toOp $ GenericOpRep op tyArgs dataArgs [] + let tyArgs' = case tyArgs of + [] -> Nothing + [t] -> Just t + _ -> error "Expected at most one type arg" + return $ fromJust $ toOp $ GenericOpRep op tyArgs' dataArgs pattern ExplicitCoreLam :: Nest CBinder n l -> CExpr l -> CAtom n pattern ExplicitCoreLam bs body <- Con (Lam (CoreLamExpr _ (LamExpr bs body))) diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index 7a22882c5..12185afe6 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -168,7 +168,7 @@ inlineDeclsSubst = \case -- since their main purpose is to force inlining in the simplifier, and if -- one just stuck like this it has become equivalent to a `for` anyway. ixDepthExpr :: Expr SimpIR n -> Int - ixDepthExpr (PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr _ body))))) = 1 + ixDepthExpr body + ixDepthExpr (Hof (TypedHof _ (For _ _ (UnaryLamExpr _ body)))) = 1 + ixDepthExpr body ixDepthExpr _ = 0 -- Should we decide to inline this binding wherever it appears, before we even @@ -316,7 +316,7 @@ reconstruct ctx e = case ctx of reconstructTabApp :: Emits o => Context SExpr e o -> SExpr o -> SAtom i -> InlineM i o (e o) reconstructTabApp ctx expr i = case expr of - PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr b body)))) -> do + Hof (TypedHof _ (For _ _ (UnaryLamExpr b body))) -> do -- See NoteReconstructTabAppDecisions AtomVar i' _ <- inline (EmitToNameCtx Stop) i dropSubst $ extendSubst (b@>Rename i') do diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 6d937affd..999bdc314 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -404,10 +404,10 @@ linearizeExpr expr = case expr of Project _ i x -> do x' <- linearizeAtom x emitBoth x' \x'' -> mkProject i x'' + Hof (TypedHof _ e) -> linearizeHof e linearizeOp :: Emits o => PrimOp SimpIR i -> LinM i o SAtom SAtom linearizeOp op = case op of - Hof (TypedHof _ e) -> linearizeHof e RefOp ref m -> do ref' <- linearizeAtom ref case m of diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index 7d5f8079f..f2fb2b348 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -376,7 +376,7 @@ instance HasOCC SExpr where ty' <- occTy ty countFreeVarsAsOccurrences effs return $ Case scrut' alts' (EffTy effs ty') - PrimOp (Hof op) -> PrimOp . Hof <$> occ a op + Hof op -> Hof <$> occ a op PrimOp (RefOp ref op) -> do ref' <- occ a ref PrimOp . RefOp ref' <$> occ a op diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index fd2a6410b..5d8db74d3 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -68,7 +68,7 @@ instance ExprVisitorEmits (ULM i o) SimpIR i o where -- constant-foldable after inlining don't count towards it. ulExpr :: Emits o => SExpr i -> ULM i o (SAtom o) ulExpr expr = case expr of - PrimOp (Hof (TypedHof _ (For Fwd ixTy body))) -> + Hof (TypedHof _ (For Fwd ixTy body)) -> case ixTypeDict ixTy of DictCon (IxRawFin (IdxRepVal n)) -> do (body', bodyCost) <- withLocalAccounting $ visitLamEmits body @@ -133,7 +133,7 @@ hoistLoopInvariant lam = liftLamExpr lam hoistLoopInvariantExpr licmExpr :: Emits o => SExpr i -> LICMM i o (SAtom o) licmExpr = \case - PrimOp (Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body)))) -> undefined + Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body))) -> undefined -- ix' <- substM ix -- Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do -- Abs decls ans <- buildScoped $ visitExprEmits body diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 3538114b8..43cc9b7ea 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -140,6 +140,7 @@ instance IRRep r => HasType r (Expr r) where ApplyMethod (EffTy _ t) _ _ _ -> t Project t _ _ -> t Unwrap t _ -> t + Hof (TypedHof (EffTy _ ty) _) -> ty instance IRRep r => HasType r (RepVal r) where getType (RepVal ty _) = ty @@ -148,7 +149,6 @@ instance IRRep r => HasType r (PrimOp r) where getType primOp = case primOp of BinOp op x _ -> TyCon $ BaseType $ typeBinOp op $ getTypeBaseType x UnOp op x -> TyCon $ BaseType $ typeUnOp op $ getTypeBaseType x - Hof (TypedHof (EffTy _ ty) _) -> ty MemOp op -> getType op MiscOp op -> getType op VectorOp op -> getType op @@ -258,6 +258,7 @@ instance IRRep r => HasEffects (Expr r) r where PrimOp primOp -> getEffects primOp Project _ _ _ -> Pure Unwrap _ _ -> Pure + Hof (TypedHof (EffTy eff _) _) -> eff instance IRRep r => HasEffects (DeclBinding r) r where getEffects (DeclBinding _ expr) = getEffects expr @@ -291,5 +292,4 @@ instance IRRep r => HasEffects (PrimOp r) r where MPut _ -> Effectful IndexRef _ _ -> Pure ProjRef _ _ -> Pure - Hof (TypedHof (EffTy eff _) _) -> eff {-# INLINE getEffects #-} diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 909bfcda1..f45d7b487 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -222,6 +222,7 @@ simplifyExpr = \case simplifyTabApp f' x' Atom x -> simplifyAtom x PrimOp op -> simplifyOp op + Hof (TypedHof (EffTy _ ty) hof) -> simplifyHof hof ApplyMethod (EffTy _ ty) dict i xs -> do xs' <- mapM simplifyAtom xs SimpCCon (WithSubst s (DictConAtom d)) <- simplifyAtom dict @@ -408,7 +409,6 @@ simplifyLam (LamExpr bsTop body) = case bsTop of simplifyOp :: Emits o => PrimOp CoreIR i -> SimplifyM i o (SimpVal o) simplifyOp op = case op of - Hof (TypedHof (EffTy _ ty) hof) -> simplifyHof hof MemOp op' -> simplifyGenericOp op' VectorOp op' -> simplifyGenericOp op' RefOp ref eff -> do @@ -433,7 +433,7 @@ simplifyGenericOp => op CoreIR i -> SimplifyM i o (SimpVal o) simplifyGenericOp op = do - op' <- traverseOp op getRepType toDataAtom (error "shouldn't have lambda left") + op' <- traverseOp op getRepType toDataAtom SimpAtom <$> emit op' {-# INLINE simplifyGenericOp #-} diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 22d6cdb36..b1be05136 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -176,6 +176,7 @@ transposeExpr expr ct = case expr of forM_ (enumerate es) \(ordinalIdx, e) -> do i <- unsafeFromOrdinal idxTy (IdxRepVal $ fromIntegral ordinalIdx) tabApp ct i >>= transposeAtom e + Hof (TypedHof _ hof) -> transposeHof hof ct TabApp _ _ _ -> error "should have been handled by reference projection" Project _ _ _ -> error "should have been handled by reference projection" @@ -193,7 +194,6 @@ transposeOp op ct = case op of void $ emitEff $ MPut zero IndexRef _ _ -> notImplemented ProjRef _ _ -> notImplemented - Hof (TypedHof _ hof) -> transposeHof hof ct MiscOp miscOp -> transposeMiscOp miscOp ct UnOp FNeg x -> transposeAtom x =<< (emitLin $ UnOp FNeg ct) UnOp _ _ -> notLinear diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 8217e4bfa..04afee621 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -110,6 +110,7 @@ data Expr r n where Atom :: Atom r n -> Expr r n TabCon :: Type r n -> [Atom r n] -> Expr r n PrimOp :: PrimOp r n -> Expr r n + Hof :: TypedHof r n -> Expr r n Project :: Type r n -> Int -> Atom r n -> Expr r n App :: EffTy CoreIR n -> CAtom n -> [CAtom n] -> Expr CoreIR n Unwrap :: CType n -> CAtom n -> Expr CoreIR n @@ -259,14 +260,14 @@ class GenericOp (e::IR->E) where toOp :: GenericOpRep (OpConst e r) r n -> Maybe (e r n) data GenericOpRep (const :: *) (r::IR) (n::S) = - GenericOpRep const [Type r n] [Atom r n] [LamExpr r n] + GenericOpRep const (Maybe (Type r n)) [Atom r n] -- name, optional result type, args deriving (Show, Generic) instance GenericE (GenericOpRep const r) where - type RepE (GenericOpRep const r) = LiftE const `PairE` ListE (Type r) `PairE` ListE (Atom r) `PairE` ListE (LamExpr r) - fromE (GenericOpRep c ts xs lams) = LiftE c `PairE` ListE ts `PairE` ListE xs `PairE` ListE lams + type RepE (GenericOpRep const r) = LiftE const `PairE` MaybeE (Type r) `PairE` ListE (Atom r) + fromE (GenericOpRep c ts xs) = LiftE c `PairE` toMaybeE ts `PairE` ListE xs {-# INLINE fromE #-} - toE (LiftE c `PairE` ListE ts `PairE` ListE xs `PairE` ListE lams) = GenericOpRep c ts xs lams + toE (LiftE c `PairE` ts `PairE` ListE xs) = GenericOpRep c (fromMaybeE ts) xs {-# INLINE toE #-} instance IRRep r => SinkableE (GenericOpRep const r) where @@ -274,8 +275,8 @@ instance IRRep r => HoistableE (GenericOpRep const r) where instance (Eq const, IRRep r) => AlphaEqE (GenericOpRep const r) instance (Hashable const, IRRep r) => AlphaHashableE (GenericOpRep const r) instance IRRep r => RenameE (GenericOpRep const r) where - renameE env (GenericOpRep c ts xs ys) = - GenericOpRep c (map (renameE env) ts) (map (renameE env) xs) (map (renameE env) ys) + renameE env (GenericOpRep c ts xs) = + GenericOpRep c (fmap (renameE env) ts) (map (renameE env) xs) fromEGenericOpRep :: GenericOp e => e r n -> GenericOpRep (OpConst e r) r n fromEGenericOpRep = fromOp @@ -288,14 +289,12 @@ traverseOp => e r i -> (Type r i -> m (Type r' o)) -> (Atom r i -> m (Atom r' o)) - -> (LamExpr r i -> m (LamExpr r' o)) -> m (e r' o) -traverseOp op fType fAtom fLam = do - let GenericOpRep c tys atoms lams = fromOp op +traverseOp op fType fAtom = do + let GenericOpRep c tys atoms = fromOp op tys' <- mapM fType tys atoms' <- mapM fAtom atoms - lams' <- mapM fLam lams - return $ fromJust $ toOp $ GenericOpRep c tys' atoms' lams' + return $ fromJust $ toOp $ GenericOpRep c tys' atoms' -- === Various ops === @@ -305,7 +304,6 @@ data PrimOp (r::IR) (n::S) where MemOp :: MemOp r n -> PrimOp r n VectorOp :: VectorOp r n -> PrimOp r n MiscOp :: MiscOp r n -> PrimOp r n - Hof :: TypedHof r n -> PrimOp r n RefOp :: Atom r n -> RefOp r n -> PrimOp r n deriving instance IRRep r => Show (PrimOp r n) @@ -568,7 +566,7 @@ instance ToExpr (PrimOp r) r where toExpr = PrimOp instance ToExpr (MiscOp r) r where toExpr = PrimOp . MiscOp instance ToExpr (MemOp r) r where toExpr = PrimOp . MemOp instance ToExpr (VectorOp r) r where toExpr = PrimOp . VectorOp -instance ToExpr (TypedHof r) r where toExpr = PrimOp . Hof +instance ToExpr (TypedHof r) r where toExpr = Hof -- === Pattern synonyms === @@ -850,16 +848,16 @@ instance IRRep r => AlphaHashableE (Hof r) instance GenericOp RefOp where type OpConst RefOp r = P.RefOp fromOp = \case - MGet -> GenericOpRep P.MGet [] [] [] - MPut x -> GenericOpRep P.MPut [] [x] [] - IndexRef t x -> GenericOpRep P.IndexRef [t] [x] [] - ProjRef t p -> GenericOpRep (P.ProjRef p) [t] [] [] + MGet -> GenericOpRep P.MGet Nothing [] + MPut x -> GenericOpRep P.MPut Nothing [x] + IndexRef t x -> GenericOpRep P.IndexRef (Just t) [x] + ProjRef t p -> GenericOpRep (P.ProjRef p) (Just t) [] {-# INLINE fromOp #-} toOp = \case - GenericOpRep P.MGet [] [] [] -> Just $ MGet - GenericOpRep P.MPut [] [x] [] -> Just $ MPut x - GenericOpRep P.IndexRef [t] [x] [] -> Just $ IndexRef t x - GenericOpRep (P.ProjRef p) [t] [] [] -> Just $ ProjRef t p + GenericOpRep P.MGet Nothing [] -> Just $ MGet + GenericOpRep P.MPut Nothing [x] -> Just $ MPut x + GenericOpRep P.IndexRef (Just t) [x] -> Just $ IndexRef t x + GenericOpRep (P.ProjRef p) (Just t) [] -> Just $ ProjRef t p _ -> Nothing {-# INLINE toOp #-} @@ -991,12 +989,13 @@ instance IRRep r => GenericE (Expr r) where {- TopApp -} (WhenSimp r (EffTy r `PairE` TopFunName `PairE` ListE (Atom r))) {- Block -} (EffTy r `PairE` Block r) ) - ( EitherE5 + ( EitherE6 {- TabCon -} (Type r `PairE` ListE (Atom r)) {- PrimOp -} (PrimOp r) {- ApplyMethod -} (WhenCore r (EffTy r `PairE` Atom r `PairE` LiftE Int `PairE` ListE (Atom r))) {- Project -} (Type r `PairE` LiftE Int `PairE` Atom r) - {- Unwrap -} (WhenCore r (CType `PairE` CAtom))) + {- Unwrap -} (WhenCore r (CType `PairE` CAtom)) + {- Hof -} (TypedHof r)) fromE = \case App et f xs -> Case0 $ Case0 (WhenIRE (et `PairE` f `PairE` ListE xs)) TabApp t f x -> Case0 $ Case1 (t `PairE` f `PairE` x) @@ -1009,6 +1008,7 @@ instance IRRep r => GenericE (Expr r) where ApplyMethod et d i xs -> Case1 $ Case2 (WhenIRE (et `PairE` d `PairE` LiftE i `PairE` ListE xs)) Project ty i x -> Case1 $ Case3 (ty `PairE` LiftE i `PairE` x) Unwrap t x -> Case1 $ Case4 (WhenIRE (t `PairE` x)) + Hof hof -> Case1 $ Case5 hof {-# INLINE fromE #-} toE = \case Case0 case0 -> case case0 of @@ -1025,6 +1025,7 @@ instance IRRep r => GenericE (Expr r) where Case2 (WhenIRE (et `PairE` d `PairE` LiftE i `PairE` ListE xs)) -> ApplyMethod et d i xs Case3 (ty `PairE` LiftE i `PairE` x) -> Project ty i x Case4 (WhenIRE (t `PairE` x)) -> Unwrap t x + Case5 hof -> Hof hof _ -> error "impossible" _ -> error "impossible" {-# INLINE toE #-} @@ -1036,39 +1037,29 @@ instance IRRep r => AlphaHashableE (Expr r) instance IRRep r => RenameE (Expr r) instance IRRep r => GenericE (PrimOp r) where - type RepE (PrimOp r) = EitherE2 - ( EitherE5 - {- UnOp -} (LiftE P.UnOp `PairE` Atom r) - {- BinOp -} (LiftE P.BinOp `PairE` Atom r `PairE` Atom r) - {- MemOp -} (MemOp r) + type RepE (PrimOp r) = EitherE6 + {- UnOp -} (LiftE P.UnOp `PairE` Atom r) + {- BinOp -} (LiftE P.BinOp `PairE` Atom r `PairE` Atom r) + {- MemOp -} (MemOp r) {- VectorOp -} (VectorOp r) {- MiscOp -} (MiscOp r) - ) (EitherE2 - {- Hof -} (TypedHof r) - {- RefOp -} (Atom r `PairE` RefOp r) - ) + {- RefOp -} (Atom r `PairE` RefOp r) fromE = \case - UnOp op x -> Case0 $ Case0 $ LiftE op `PairE` x - BinOp op x y -> Case0 $ Case1 $ LiftE op `PairE` x `PairE` y - MemOp op -> Case0 $ Case2 op - VectorOp op -> Case0 $ Case3 op - MiscOp op -> Case0 $ Case4 op - Hof op -> Case1 $ Case0 op - RefOp r op -> Case1 $ Case1 $ r `PairE` op + UnOp op x -> Case0 $ LiftE op `PairE` x + BinOp op x y -> Case1 $ LiftE op `PairE` x `PairE` y + MemOp op -> Case2 op + VectorOp op -> Case3 op + MiscOp op -> Case4 op + RefOp r op -> Case5 $ r `PairE` op {-# INLINE fromE #-} toE = \case - Case0 rep -> case rep of - Case0 (LiftE op `PairE` x ) -> UnOp op x - Case1 (LiftE op `PairE` x `PairE` y) -> BinOp op x y - Case2 op -> MemOp op - Case3 op -> VectorOp op - Case4 op -> MiscOp op - _ -> error "impossible" - Case1 rep -> case rep of - Case0 op -> Hof op - Case1 (r `PairE` op) -> RefOp r op - _ -> error "impossible" + Case0 (LiftE op `PairE` x ) -> UnOp op x + Case1 (LiftE op `PairE` x `PairE` y) -> BinOp op x y + Case2 op -> MemOp op + Case3 op -> VectorOp op + Case4 op -> MiscOp op + Case5 (r `PairE` op) -> RefOp r op _ -> error "impossible" {-# INLINE toE #-} @@ -1081,17 +1072,17 @@ instance IRRep r => RenameE (PrimOp r) instance GenericOp VectorOp where type OpConst VectorOp r = P.VectorOp fromOp = \case - VectorBroadcast x t -> GenericOpRep P.VectorBroadcast [t] [x] [] - VectorIota t -> GenericOpRep P.VectorIota [t] [] [] - VectorIdx x y t -> GenericOpRep P.VectorIdx [t] [x, y] [] - VectorSubref x y t -> GenericOpRep P.VectorSubref [t] [x, y] [] + VectorBroadcast x t -> GenericOpRep P.VectorBroadcast (Just t) [x] + VectorIota t -> GenericOpRep P.VectorIota (Just t) [] + VectorIdx x y t -> GenericOpRep P.VectorIdx (Just t) [x, y] + VectorSubref x y t -> GenericOpRep P.VectorSubref (Just t) [x, y] {-# INLINE fromOp #-} toOp = \case - GenericOpRep P.VectorBroadcast [t] [x] [] -> Just $ VectorBroadcast x t - GenericOpRep P.VectorIota [t] [] [] -> Just $ VectorIota t - GenericOpRep P.VectorIdx [t] [x, y] [] -> Just $ VectorIdx x y t - GenericOpRep P.VectorSubref [t] [x, y] [] -> Just $ VectorSubref x y t + GenericOpRep P.VectorBroadcast (Just t) [x] -> Just $ VectorBroadcast x t + GenericOpRep P.VectorIota (Just t) [] -> Just $ VectorIota t + GenericOpRep P.VectorIdx (Just t) [x, y] -> Just $ VectorIdx x y t + GenericOpRep P.VectorSubref (Just t) [x, y] -> Just $ VectorSubref x y t _ -> Nothing {-# INLINE toOp #-} @@ -1108,18 +1099,18 @@ instance IRRep r => RenameE (VectorOp r) instance GenericOp MemOp where type OpConst MemOp r = P.MemOp fromOp = \case - IOAlloc x -> GenericOpRep P.IOAlloc [] [x] [] - IOFree x -> GenericOpRep P.IOFree [] [x] [] - PtrOffset x y -> GenericOpRep P.PtrOffset [] [x, y] [] - PtrLoad x -> GenericOpRep P.PtrLoad [] [x] [] - PtrStore x y -> GenericOpRep P.PtrStore [] [x, y] [] + IOAlloc x -> GenericOpRep P.IOAlloc Nothing [x] + IOFree x -> GenericOpRep P.IOFree Nothing [x] + PtrOffset x y -> GenericOpRep P.PtrOffset Nothing [x, y] + PtrLoad x -> GenericOpRep P.PtrLoad Nothing [x] + PtrStore x y -> GenericOpRep P.PtrStore Nothing [x, y] {-# INLINE fromOp #-} toOp = \case - GenericOpRep P.IOAlloc [] [x] [] -> Just $ IOAlloc x - GenericOpRep P.IOFree [] [x] [] -> Just $ IOFree x - GenericOpRep P.PtrOffset [] [x, y] [] -> Just $ PtrOffset x y - GenericOpRep P.PtrLoad [] [x] [] -> Just $ PtrLoad x - GenericOpRep P.PtrStore [] [x, y] [] -> Just $ PtrStore x y + GenericOpRep P.IOAlloc Nothing [x] -> Just $ IOAlloc x + GenericOpRep P.IOFree Nothing [x] -> Just $ IOFree x + GenericOpRep P.PtrOffset Nothing [x, y] -> Just $ PtrOffset x y + GenericOpRep P.PtrLoad Nothing [x] -> Just $ PtrLoad x + GenericOpRep P.PtrStore Nothing [x, y] -> Just $ PtrStore x y _ -> Nothing {-# INLINE toOp #-} @@ -1136,32 +1127,32 @@ instance IRRep r => RenameE (MemOp r) instance GenericOp MiscOp where type OpConst MiscOp r = P.MiscOp fromOp = \case - Select p x y -> GenericOpRep P.Select [] [p,x,y] [] - CastOp t x -> GenericOpRep P.CastOp [t] [x] [] - BitcastOp t x -> GenericOpRep P.BitcastOp [t] [x] [] - UnsafeCoerce t x -> GenericOpRep P.UnsafeCoerce [t] [x] [] - GarbageVal t -> GenericOpRep P.GarbageVal [t] [] [] - NewRef t -> GenericOpRep P.NewRef [t] [] [] - ThrowError t -> GenericOpRep P.ThrowError [t] [] [] - SumTag x -> GenericOpRep P.SumTag [] [x] [] - ToEnum t x -> GenericOpRep P.ToEnum [t] [x] [] - OutputStream -> GenericOpRep P.OutputStream [] [] [] - ShowAny x -> GenericOpRep P.ShowAny [] [x] [] - ShowScalar x -> GenericOpRep P.ShowScalar [] [x] [] + Select p x y -> GenericOpRep P.Select Nothing [p,x,y] + CastOp t x -> GenericOpRep P.CastOp (Just t) [x] + BitcastOp t x -> GenericOpRep P.BitcastOp (Just t) [x] + UnsafeCoerce t x -> GenericOpRep P.UnsafeCoerce (Just t) [x] + GarbageVal t -> GenericOpRep P.GarbageVal (Just t) [] + NewRef t -> GenericOpRep P.NewRef (Just t) [] + ThrowError t -> GenericOpRep P.ThrowError (Just t) [] + SumTag x -> GenericOpRep P.SumTag Nothing [x] + ToEnum t x -> GenericOpRep P.ToEnum (Just t) [x] + OutputStream -> GenericOpRep P.OutputStream Nothing [] + ShowAny x -> GenericOpRep P.ShowAny Nothing [x] + ShowScalar x -> GenericOpRep P.ShowScalar Nothing [x] {-# INLINE fromOp #-} toOp = \case - GenericOpRep P.Select [] [p,x,y] [] -> Just $ Select p x y - GenericOpRep P.CastOp [t] [x] [] -> Just $ CastOp t x - GenericOpRep P.BitcastOp [t] [x] [] -> Just $ BitcastOp t x - GenericOpRep P.UnsafeCoerce [t] [x] [] -> Just $ UnsafeCoerce t x - GenericOpRep P.GarbageVal [t] [] [] -> Just $ GarbageVal t - GenericOpRep P.NewRef [t] [] [] -> Just $ NewRef t - GenericOpRep P.ThrowError [t] [] [] -> Just $ ThrowError t - GenericOpRep P.SumTag [] [x] [] -> Just $ SumTag x - GenericOpRep P.ToEnum [t] [x] [] -> Just $ ToEnum t x - GenericOpRep P.OutputStream [] [] [] -> Just $ OutputStream - GenericOpRep P.ShowAny [] [x] [] -> Just $ ShowAny x - GenericOpRep P.ShowScalar [] [x] [] -> Just $ ShowScalar x + GenericOpRep P.Select Nothing [p,x,y] -> Just $ Select p x y + GenericOpRep P.CastOp (Just t) [x] -> Just $ CastOp t x + GenericOpRep P.BitcastOp (Just t) [x] -> Just $ BitcastOp t x + GenericOpRep P.UnsafeCoerce (Just t) [x] -> Just $ UnsafeCoerce t x + GenericOpRep P.GarbageVal (Just t) [] -> Just $ GarbageVal t + GenericOpRep P.NewRef (Just t) [] -> Just $ NewRef t + GenericOpRep P.ThrowError (Just t) [] -> Just $ ThrowError t + GenericOpRep P.SumTag Nothing [x] -> Just $ SumTag x + GenericOpRep P.ToEnum (Just t) [x] -> Just $ ToEnum t x + GenericOpRep P.OutputStream Nothing [] -> Just $ OutputStream + GenericOpRep P.ShowAny Nothing [x] -> Just $ ShowAny x + GenericOpRep P.ShowScalar Nothing [x] -> Just $ ShowScalar x _ -> Nothing {-# INLINE toOp #-} @@ -1712,7 +1703,6 @@ instance IRRep r => PrettyPrec (PrimOp r n) where prettyPrec = \case MemOp op -> prettyPrec op VectorOp op -> prettyPrec op - Hof (TypedHof _ hof) -> prettyPrec hof RefOp ref eff -> atPrec LowestPrec case eff of MGet -> "get" <+> pApp ref MPut x -> pApp ref <+> ":=" <+> pApp x @@ -1742,8 +1732,8 @@ instance IRRep r => PrettyPrec (VectorOp r n) where prettyOpGeneric :: (IRRep r, GenericOp op, Show (OpConst op r)) => op r n -> DocPrec ann prettyOpGeneric op = case fromEGenericOpRep op of - GenericOpRep op' [] [] [] -> atPrec ArgPrec (pretty $ show op') - GenericOpRep op' ts xs lams -> atPrec AppPrec $ pAppArg (pretty (show op')) xs <+> pretty ts <+> pretty lams + GenericOpRep op' Nothing [] -> atPrec ArgPrec (pretty $ show op') + GenericOpRep op' ts xs -> atPrec AppPrec $ pAppArg (pretty (show op')) xs <+> pretty ts instance Pretty IxMethod where pretty method = pretty $ show method @@ -1786,6 +1776,7 @@ instance IRRep r => PrettyPrec (Expr r n) where ApplyMethod _ d i xs -> atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs Project _ i x -> atPrec AppPrec $ "Project" <+> p i <+> p x Unwrap _ x -> atPrec AppPrec $ "Unwrap" <+> p x + Hof (TypedHof _ hof) -> prettyPrec hof where p :: Pretty a => a -> Doc ann p = pretty diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index fa7d04e55..69ade54fe 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -379,7 +379,7 @@ instance Visitor (CalcWidthM i o) SimpIR i o where instance ExprVisitorNoEmits (CalcWidthM i o) SimpIR i o where visitExprNoEmits expr = case expr of - PrimOp (Hof _) -> fallback + Hof _ -> fallback PrimOp (RefOp _ _) -> fallback PrimOp _ -> do expr' <- renameM expr