Skip to content

Commit

Permalink
Remove reconstructions from Simp pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed May 1, 2024
1 parent 3e0cbe9 commit 6507556
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 135 deletions.
138 changes: 20 additions & 118 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

{-# LANGUAGE UndecidableInstances #-}

module Simplify
( simplifyTopBlock, simplifyTopFunction, ReconstructAtom (..), applyReconTop,
linearizeTopFun, SimplifiedTopLam (..)) where
module Simplify (simplifyTopBlock, simplifyTopFunction, linearizeTopFun) where

import Control.Category ((>>>))
import Control.Monad
Expand All @@ -17,7 +15,6 @@ import Data.Maybe

import Builder
import CheapReduction
import CheckType
import Core
import Err
import Generalize
Expand Down Expand Up @@ -48,11 +45,6 @@ import Util (enumerate)
-- programmer. Those, however, are second-class: they are all
-- toplevel, and get specialized until they are first order.

-- Currently, simplification also discharges `CatchException` by
-- elaborating the expression into a Maybe-style monad. Note: the
-- plan is for `CatchException` to become a user-defined effect, and
-- for simplification to discharge all of them.

-- Simplification also opportunistically does peep-hole optimizations:
-- some constant folding, case-of-known-constructor, projecting known
-- elements from products, etc; but is not guaranteed to find all such
Expand Down Expand Up @@ -225,16 +217,6 @@ withSimplifiedBinders (Nest (bCore:>ty) bsCore) cont = do
withSimplifiedBinders bsCore' \bsSimp xs ->
cont (Nest bSimp bsSimp) (sink x:xs)

-- === Reconstructions ===

data ReconstructAtom (n::S) =
CoerceRecon (Type CoreIR n)
| LamRecon (ReconAbs SimpIR CAtom n)

applyRecon :: (EnvReader m, Fallible1 m) => ReconstructAtom n -> SAtom n -> m n (CAtom n)
applyRecon (CoerceRecon ty) x = liftSimpAtom ty x
applyRecon (LamRecon ab) x = applyReconAbs ab x

-- === Simplification monad ===

class (ScopableBuilder2 SimpIR m, SubstReader AtomSubstVal m) => Simplifier m
Expand Down Expand Up @@ -264,54 +246,19 @@ deriving instance ScopableBuilder SimpIR (SimplifyM i)

-- === Top-level API ===

data SimplifiedTopLam n = SimplifiedTopLam (STopLam n) (ReconstructAtom n)
data SimplifiedBlock n = SimplifiedBlock (SExpr n) (ReconstructAtom n)

simplifyTopBlock
:: (TopBuilder m, Mut n) => TopBlock CoreIR n -> m n (SimplifiedTopLam n)
:: (TopBuilder m, Mut n) => TopBlock CoreIR n -> m n (STopLam n)
simplifyTopBlock (TopLam _ _ (LamExpr Empty body)) = do
SimplifiedBlock block recon <- liftSimplifyM do
{-# SCC "Simplify" #-} buildSimplifiedBlock $ simplifyExpr body
topLam <- asTopLam $ LamExpr Empty block
return $ SimplifiedTopLam topLam recon
block <- liftSimplifyM do
buildSimplifiedBlock $ simplifyExpr body
asTopLam $ LamExpr Empty block
simplifyTopBlock _ = error "not a block (nullary lambda)"

simplifyTopFunction :: (TopBuilder m, Mut n) => CTopLam n -> m n (STopLam n)
simplifyTopFunction (TopLam False _ f) = do
asTopLam =<< liftSimplifyM do
(lam, CoerceReconAbs) <- {-# SCC "Simplify" #-} simplifyLam f
return lam
asTopLam =<< liftSimplifyM (simplifyLam f)
simplifyTopFunction _ = error "shouldn't be in destination-passing style already"

applyReconTop :: (EnvReader m, Fallible1 m) => ReconstructAtom n -> SAtom n -> m n (CAtom n)
applyReconTop = applyRecon

instance GenericE SimplifiedBlock where
type RepE SimplifiedBlock = PairE SExpr ReconstructAtom
fromE (SimplifiedBlock block recon) = PairE block recon
{-# INLINE fromE #-}
toE (PairE block recon) = SimplifiedBlock block recon
{-# INLINE toE #-}

instance SinkableE SimplifiedBlock
instance RenameE SimplifiedBlock
instance HoistableE SimplifiedBlock

instance Pretty (SimplifiedBlock n) where
pretty (SimplifiedBlock block recon) =
pretty block <> hardline <> pretty recon

instance SinkableE SimplifiedTopLam where
sinkingProofE = todoSinkableProof

instance CheckableE SimpIR SimplifiedTopLam where
checkE (SimplifiedTopLam lam recon) =
-- TODO: CheckableE instance for the recon too
SimplifiedTopLam <$> checkE lam <*> renameM recon

instance Pretty (SimplifiedTopLam n) where
pretty (SimplifiedTopLam lam recon) = pretty lam <> hardline <> pretty recon

-- === All the bits of IR ===

simplifyDecls :: Emits o => Nest (Decl CoreIR) i i' -> SimplifyM i' o a -> SimplifyM i o a
Expand Down Expand Up @@ -545,37 +492,21 @@ simplifyAtom = substM
-- Assumes first order (args/results are "data", allowing newtypes), monormophic
simplifyLam
:: LamExpr CoreIR i
-> SimplifyM i o (LamExpr SimpIR o, Abs (Nest (AtomNameBinder SimpIR)) ReconstructAtom o)
-> SimplifyM i o (LamExpr SimpIR o)
simplifyLam (LamExpr bsTop body) = case bsTop of
Nest b bs -> withSimplifiedBinder b \b'@(b'':>_) -> do
(LamExpr bs' body', Abs bsRecon recon) <- simplifyLam $ LamExpr bs body
return (LamExpr (Nest b' bs') body', Abs (Nest b'' bsRecon) recon)
Nest b bs -> withSimplifiedBinder b \b' -> do
LamExpr bs' body' <- simplifyLam $ LamExpr bs body
return $ LamExpr (Nest b' bs') body'
Empty -> do
SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body
return (LamExpr Empty body', Abs Empty recon)
body' <- buildSimplifiedBlock $ simplifyExpr body
return $ LamExpr Empty body'

buildSimplifiedBlock
:: (forall o'. (Emits o', DExt o o') => SimplifyM i o' (CAtom o'))
-> SimplifyM i o (SimplifiedBlock o)
buildSimplifiedBlock cont = do
Abs decls eitherResult <- buildScoped do
ans <- cont
tryAsDataAtom ans >>= \case
Nothing -> return $ LeftE ans
Just (dataResult, _) -> do
ansTy <- return $ getType ans
return $ RightE (dataResult `PairE` ansTy)
case eitherResult of
LeftE ans -> do
(blockAbs, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do
(newResult, reconAbs) <- telescopicCapture (toScopeFrag decls') ans'
return (Abs decls' newResult, LamRecon reconAbs)
block' <- mkBlock blockAbs
return $ SimplifiedBlock block' recon
RightE (ans `PairE` ty) -> do
let ty' = ignoreHoistFailure $ hoist (toScopeFrag decls) ty
block <- mkBlock $ Abs decls ans
return $ SimplifiedBlock block (CoerceRecon ty')
-> SimplifyM i o (SExpr o)
buildSimplifiedBlock cont = buildBlock do
ans <- cont
dropSubst $ toDataAtom ans

simplifyOp :: Emits o => PrimOp CoreIR i -> SimplifyM i o (CAtom o)
simplifyOp op = case op of
Expand Down Expand Up @@ -615,9 +546,6 @@ simplifyGenericOp op = do
liftSimpAtom ty =<< emit op'
{-# INLINE simplifyGenericOp #-}

pattern CoerceReconAbs :: Abs (Nest b) ReconstructAtom n
pattern CoerceReconAbs <- Abs _ (CoerceRecon _)

applyDictMethod :: Emits o => CType o -> CDict o -> Int -> [CAtom o] -> SimplifyM i o (CAtom o)
applyDictMethod resultTy d i methodArgs = case d of
DictCon (InstanceDict _ instanceName instanceArgs) -> dropSubst do
Expand All @@ -642,28 +570,25 @@ applyDictMethod resultTy d i methodArgs = case d of
simplifyHof :: Emits o => CType o -> Hof CoreIR i -> SimplifyM i o (CAtom o)
simplifyHof resultTy = \case
For d (IxType ixTy ixDict) lam -> do
(lam', CoerceReconAbs) <- simplifyLam lam
lam' <- simplifyLam lam
ixTy' <- getRepType ixTy
ixDict' <- simplifyIxDict ixDict
ans <- emitHof $ For d (IxType ixTy' ixDict') lam'
liftSimpAtom resultTy ans
While body -> do
SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyExpr body
body' <- buildSimplifiedBlock $ simplifyExpr body
result <- emitHof $ While body'
liftSimpAtom resultTy result
Linearize lam x -> do
x' <- toDataAtom x
-- XXX: we're ignoring the result type here, which only makes sense if we're
-- dealing with functions on simple types.
(lam', recon) <- simplifyLam lam
CoerceReconAbs <- return recon
lam' <- simplifyLam lam
(result, linFun) <- liftDoubleBuilderToSimplifyM $ linearize lam' x'
PairTy lamResultTy linFunTy <- return resultTy
result' <- liftSimpAtom lamResultTy result
linFun' <- liftSimpFun linFunTy linFun
return $ PairVal result' linFun'
Transpose lam x -> do
(lam', CoerceReconAbs) <- simplifyLam lam
lam' <- simplifyLam lam
x' <- toDataAtom x
result <- transpose lam' x'
liftSimpAtom resultTy result
Expand Down Expand Up @@ -796,26 +721,3 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do
applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emit
let tangentFun = LamExpr (bs >>> residualAndTangentBs) tangentBody
return $ PairE primalFun tangentFun

-- === instances ===

instance GenericE ReconstructAtom where
type RepE ReconstructAtom = EitherE2 (Type CoreIR) (ReconAbs SimpIR CAtom)

fromE = \case
CoerceRecon ty -> Case0 ty
LamRecon ab -> Case1 ab
{-# INLINE fromE #-}
toE = \case
Case0 ty -> CoerceRecon ty
Case1 ab -> LamRecon ab
_ -> error "impossible"
{-# INLINE toE #-}

instance SinkableE ReconstructAtom
instance HoistableE ReconstructAtom
instance RenameE ReconstructAtom

instance Pretty (ReconstructAtom n) where
pretty (CoerceRecon ty) = "Coercion reconstruction: " <> pretty ty
pretty (LamRecon ab) = "Reconstruction abs: " <> pretty ab
36 changes: 19 additions & 17 deletions src/lib/TopLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -473,23 +473,25 @@ whenOpt x act = getConfig <&> optLevel >>= \case
Optimize -> act x

evalBlock :: (Topper m, Mut n) => TopBlock CoreIR n -> m n (CAtom n)
evalBlock (TopLam _ _ (LamExpr Empty (Atom result))) = return result
evalBlock typed = do
SimplifiedTopLam simp recon <- checkPass SimpPass $ simplifyTopBlock typed
opt <- simpOptimizations simp
simpResult <- case opt of
TopLam _ _ (LamExpr Empty (Atom result)) -> return result
_ -> do
dps <- checkPass LowerPass $ dpsPass opt
lOpt <- checkPass OptPass $ loweredOptimizations dps
cc <- getEntryFunCC
impOpt <- checkPass ImpPass $ toImpFunction cc lOpt
llvmOpt <- packageLLVMCallable impOpt
resultVals <- liftIO $ callEntryFun llvmOpt []
TopLam _ destTy _ <- return lOpt
resultTy <- return $ assumeConst $ piTypeWithoutDest destTy
repValAtom =<< repValFromFlatList resultTy resultVals
applyReconTop recon simpResult
evalBlock typed@(TopLam _ _ (LamExpr Empty body)) = case body of
Atom result -> return result
_ -> do
simp <- checkPass SimpPass $ simplifyTopBlock typed
opt <- simpOptimizations simp
simpResult <- case opt of
TopLam _ _ (LamExpr Empty (Atom result)) -> return result
_ -> do
dps <- checkPass LowerPass $ dpsPass opt
lOpt <- checkPass OptPass $ loweredOptimizations dps
cc <- getEntryFunCC
impOpt <- checkPass ImpPass $ toImpFunction cc lOpt
llvmOpt <- packageLLVMCallable impOpt
resultVals <- liftIO $ callEntryFun llvmOpt []
TopLam _ destTy _ <- return lOpt
resultTy <- return $ assumeConst $ piTypeWithoutDest destTy
repValAtom =<< repValFromFlatList resultTy resultVals
liftSimpAtom (getType body) simpResult
evalBlock _ = error "not a top block"
{-# SCC evalBlock #-}

simpOptimizations :: Topper m => STopLam n -> m n (STopLam n)
Expand Down

0 comments on commit 6507556

Please sign in to comment.