diff --git a/dex.cabal b/dex.cabal index dcb537a3c..1b752a1ef 100644 --- a/dex.cabal +++ b/dex.cabal @@ -41,7 +41,6 @@ flag debug library exposed-modules: AbstractSyntax - , Algebra , Builder , CUDA , CheapReduction diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs deleted file mode 100644 index 1175d1523..000000000 --- a/src/lib/Algebra.hs +++ /dev/null @@ -1,247 +0,0 @@ --- Copyright 2020 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE UndecidableInstances #-} - -module Algebra (sumUsingPolys) where - -import Prelude hiding (lookup, sum, pi) -import Control.Monad -import Data.Functor -import Data.Ratio -import Control.Applicative -import Data.Map.Strict hiding (foldl, map, empty, (!)) -import Data.Text.Prettyprint.Doc -import Data.List (intersperse) -import Data.Tuple (swap) - -import Builder -import Core -import CheapReduction -import Err -import IRVariants -import MTL1 -import Name -import Subst -import QueryType -import PPrint -import Types.Core -import Types.Imp -import Types.Primitives -import Util (Tree (..)) - -type PolyName = EitherE (AtomName SimpIR) ImpName -type PolyBinder = AtomNameBinder SimpIR - -type Constant = Rational - --- Set of variables, each with its power. -data Monomial n = Monomial { fromMonomial :: Map (PolyName n) Int } - deriving (Show, Eq, Ord) - --- Set of monomials, each multiplied by a constant. -newtype Polynomial (n::S) = - Polynomial { fromPolynomial :: Map (Monomial n) Constant } - deriving (Show, Eq, Ord) - --- This is the main entrypoint. Doing polynomial math sometimes lets --- us compute sums in closed form. This tries to compute --- `\sum_{i=0}^(lim-1) body`. `i`, `lim`, and `body` should all have type `Nat`. -sumUsingPolys :: Emits n - => SAtom n -> Abs (Binder SimpIR) (Expr SimpIR) n -> BuilderM SimpIR n (SAtom n) -sumUsingPolys lim (Abs i body) = do - sumAbs <- refreshAbs (Abs i body) \(i':>_) body' -> do - exprAsPoly body' >>= \case - Just poly' -> return $ Abs i' poly' - Nothing -> throwInternal $ - "Algebraic simplification failed to model index computations:\n" - ++ "Trying to sum from 0 to " ++ pprint lim ++ " - 1, \\" - ++ pprint i' ++ "." ++ pprint body' - limName <- emitToVar (Atom lim) - emitPolynomial $ sum (LeftE (atomVarName limName)) sumAbs - -mul :: Polynomial n-> Polynomial n -> Polynomial n -mul (Polynomial x) (Polynomial y) = - poly [ (cx * cy, mulMono mx my) - | (mx, cx) <- toList x, (my, cy) <- toList y] - -mulMono :: Monomial n -> Monomial n -> Monomial n -mulMono (Monomial mx) (Monomial my) = Monomial $ unionWith (+) mx my - -add :: Polynomial n -> Polynomial n -> Polynomial n -add x y = Polynomial $ unionWith (+) (fromPolynomial x) (fromPolynomial y) - -sub :: Polynomial n -> Polynomial n -> Polynomial n -sub x y = add x (Polynomial $ negate <$> fromPolynomial y) - -sumPolys :: [Polynomial n] -> Polynomial n -sumPolys ps = Polynomial $ unionsWith (+) $ map fromPolynomial ps - -mulConst :: Constant -> Polynomial n -> Polynomial n -mulConst c (Polynomial p) = Polynomial $ (*c) <$> p - --- evaluates `\sum_{i=0}^(lim-1) p` -sum :: PolyName n -> Abs PolyBinder Polynomial n -> Polynomial n -sum lim (Abs i p) = sumPolys polys - where polys = (toList $ fromPolynomial p) <&> \(m, c) -> - mulConst c $ sumMono lim (Abs i m) - -sumMono :: PolyName n -> Abs PolyBinder Monomial n -> Polynomial n -sumMono lim (Abs b (Monomial m)) = case lookup (LeftE $ binderName b) m of - -- TODO: Implement the formula for arbitrary order polynomials - Nothing -> poly [ ( 1, mulMono c $ mono [(lim, 1)])] - Just 0 -> error "Each variable appearing in a monomial should have a positive power" - -- Summing exclusive of `lim`: Sum_{i=1}^{n-1} i = (n-1)n/2 = 1/2 n^2 - 1/2 n - Just 1 -> poly [ ( 1/2, mulMono c $ mono [(lim, 2)]) - , (-1/2, mulMono c $ mono [(lim, 1)])] - -- Summing exclusive of `lim`: Sum_{i=1}^{n-1} i^2 = (n-1)n(2n-1)/6 = 1/3 n^3 - 1/2 n^2 + 1/6 n - Just 2 -> poly [ ( 1/3, mulMono c $ mono [(lim, 3)]) - , (-1/2, mulMono c $ mono [(lim, 2)]) - , ( 1/6, mulMono c $ mono [(lim, 1)])] - (Just n) -> error $ "Triangular arrays of order " ++ show n ++ " not implemented yet!" - where - c = ignoreHoistFailure $ hoist b $ -- failure impossible - Monomial $ delete (LeftE $ binderName b) m - --- === Constructors and singletons === - -poly :: [(Constant, Monomial n)] -> Polynomial n -poly monos = Polynomial $ fromListWith (+) $ fmap swap monos - -mono :: [(PolyName n, Int)] -> Monomial n -mono vars = Monomial $ fromListWith (error "Duplicate entries for variable") vars - --- === Type classes and helpers === - -showMono :: Monomial n -> String -showMono (Monomial m) = - concat $ intersperse " " $ fmap (\(n, p) -> docAsStr $ pretty n <> "^" <> pretty p) $ toList m - -_showPoly :: Polynomial n -> String -_showPoly (Polynomial p) = - concat $ intersperse " + " $ fmap (\(m, c) -> show c ++ " " ++ showMono m) $ toList p - --- === core expressions as polynomials === - -data PolySubstVal (c::C) (n::S) where - PolySubstVal :: Maybe (Polynomial n) -> PolySubstVal (AtomNameC SimpIR) n - PolyRename :: Name c n -> PolySubstVal c n - -instance SinkableV PolySubstVal -instance SinkableE (PolySubstVal c) where sinkingProofE = undefined -instance FromName PolySubstVal where fromName = PolyRename - -type BlockTraverserM i o a = SubstReaderT PolySubstVal (MaybeT1 (BuilderM SimpIR)) i o a - -exprAsPoly :: (EnvExtender m, EnvReader m) => SExpr n -> m n (Maybe (Polynomial n)) -exprAsPoly expr = liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ exprAsPolyRec expr - -atomAsPoly :: SAtom i -> BlockTraverserM i o (Polynomial o) -atomAsPoly = \case - Stuck _ (Var v) -> atomVarAsPoly v - Stuck _ (RepValAtom (RepVal _ (Leaf (IVar v' _)))) -> impNameAsPoly v' - IdxRepVal i -> return $ poly [((fromIntegral i) % 1, mono [])] - _ -> empty - -impNameAsPoly :: ImpName i -> BlockTraverserM i o (Polynomial o) -impNameAsPoly v = getSubst <&> (!v) >>= \case - PolyRename v' -> return $ poly [(1, mono [(RightE v', 1)])] - -atomVarAsPoly :: AtomVar SimpIR i -> BlockTraverserM i o (Polynomial o) -atomVarAsPoly v = getSubst <&> (! atomVarName v) >>= \case - PolySubstVal Nothing -> empty - PolySubstVal (Just cp) -> return cp - PolyRename v' -> do - v'' <- toAtomVar v' - case getType v'' of - IdxRepTy -> return $ poly [(1, mono [(LeftE v', 1)])] - _ -> empty - -exprAsPolyRec :: Expr SimpIR i -> BlockTraverserM i o (Polynomial o) -exprAsPolyRec e = case e of - Block _ block -> blockAsPoly block - Atom a -> atomAsPoly a - PrimOp (BinOp op x y) -> case op of - IAdd -> add <$> atomAsPoly x <*> atomAsPoly y - IMul -> mul <$> atomAsPoly x <*> atomAsPoly y - -- XXX: we rely on the wrapping behavior of subtraction on unsigned ints - -- so that the distributive law holds, `a * (b - c) == (a * b) - (a * c)` - ISub -> sub <$> atomAsPoly x <*> atomAsPoly y - -- This is to handle `idiv` generated by `emitPolynomial` - IDiv -> case y of - IdxRepVal n -> mulConst (1 / fromIntegral n) <$> atomAsPoly x - _ -> empty - _ -> empty - _ -> empty - -blockAsPoly :: SBlock i -> BlockTraverserM i o (Polynomial o) -blockAsPoly (Abs decls result) = case decls of - Empty -> exprAsPolyRec result - Nest (Let b (DeclBinding _ expr)) restDecls -> do - p <- optional (exprAsPolyRec expr) - extendSubst (b@>PolySubstVal p) $ blockAsPoly $ Abs restDecls result - --- === polynomials to Core expressions === - --- We have to be extra careful here, because we're evaluating a polynomial --- that we know is guaranteed to return an integral number, but it has rational --- coefficients. This is why we have to find the least common multiples and do the --- accumulation over numbers multiplied by that LCM. We essentially do fixed point --- fractional math here. -emitPolynomial :: Emits n => Polynomial n -> BuilderM SimpIR n (SAtom n) -emitPolynomial (Polynomial p) = do - let constLCM = asAtom $ foldl lcm 1 $ fmap (denominator . snd) $ toList p - monoAtoms <- flip traverse (toList p) $ \(m, c) -> do - lcmFactor <- constLCM `idiv` (asAtom $ denominator c) - constFactor <- imul (asAtom $ numerator c) lcmFactor - imul constFactor =<< emitMonomial m - total <- foldM iadd (IdxRepVal 0) monoAtoms - total `idiv` constLCM - where - -- TODO: Check for overflows. We might also want to bail out if the LCM is too large, - -- because it might be causing overflows due to all arithmetic being shifted. - asAtom = IdxRepVal . fromInteger - -emitMonomial :: Emits n => Monomial n -> BuilderM SimpIR n (SAtom n) -emitMonomial (Monomial m) = do - varAtoms <- forM (toList m) \(v, e) -> case v of - LeftE v' -> do - v'' <- toAtom <$> toAtomVar v' - ipow v'' e - RightE v' -> do - atom <- mkStuck $ RepValAtom $ RepVal IdxRepTy (Leaf (IVar v' IIdxRepTy)) - ipow atom e - foldM imul (IdxRepVal 1) varAtoms - -ipow :: Emits n => SAtom n -> Int -> BuilderM SimpIR n (SAtom n) -ipow x i = foldM imul (IdxRepVal 1) (replicate i x) - -idiv :: Emits n => SAtom n -> SAtom n -> BuilderM SimpIR n (SAtom n) -idiv = undefined - --- === instances === - -instance GenericE Monomial where - type RepE Monomial = ListE (PairE PolyName (LiftE Int)) - fromE (Monomial m) = ListE $ toList m <&> \(v, n) -> PairE v (LiftE n) - {-# INLINE fromE #-} - toE (ListE pairs) = Monomial $ fromList $ pairs <&> \(PairE v (LiftE n)) -> (v, n) - {-# INLINE toE #-} - -instance SinkableE Monomial -instance HoistableE Monomial -instance AlphaEqE Monomial - -instance GenericE Polynomial where - type RepE Polynomial = ListE (PairE Monomial (LiftE Constant)) - fromE (Polynomial m) = ListE $ toList m <&> \(x, n) -> PairE x (LiftE n) - {-# INLINE fromE #-} - toE (ListE pairs) = Polynomial $ fromList $ pairs <&> \(PairE x (LiftE n)) -> (x, n) - {-# INLINE toE #-} - -instance SinkableE Polynomial -instance HoistableE Polynomial -instance AlphaEqE Polynomial diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 2e42b3a9f..a4740d4c9 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -13,7 +13,7 @@ module Imp , repValFromFlatList, addImpTracing -- These are just for the benefit of serialization/printing. otherwise we wouldn't need them , BufferType (..), IdxNest, IndexStructure, IExprInterpretation (..), typeToTree - , computeOffset, getIExprInterpretation + , getIExprInterpretation , isSingletonType, singletonTypeVal ) where @@ -28,7 +28,6 @@ import Control.Monad.Writer.Strict import Control.Monad.State.Strict hiding (State) import qualified Control.Monad.State.Strict as MTL -import Algebra import Builder import CheapReduction import CheckType (CheckableE (..)) @@ -854,16 +853,7 @@ buildGarbageVal ty = -- === Operations on dests === indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n) -indexDest (Dest (TyCon (TabPi tabTy)) tree) i = do - eltTy <- instantiate tabTy [i] - ord <- ordinalImp (tabIxType tabTy) i - leafTys <- typeToTree $ toType tabTy - Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do - BufferType ixStruct _ <- return $ getRefBufferType leafTy - offset <- computeOffsetImp ixStruct ord - impOffset ptr offset -indexDest _ _ = error "expected a reference to a table" -{-# INLINE indexDest #-} +indexDest (Dest (TyCon (TabPi tabTy)) tree) i = undefined projectDest :: Int -> Dest n -> Dest n projectDest i (Dest (TyCon (ProdType tys)) (Branch ds)) = @@ -876,20 +866,7 @@ type SBuilderM = BuilderM SimpIR computeElemCountImp :: Emits n => IndexStructure SimpIR n -> SubstImpM i n (IExpr n) computeElemCountImp Singleton = return $ IIdxRepVal 1 -computeElemCountImp idxs = do - result <- liftBuilderImp do - idxs' <- sinkM idxs - computeElemCount idxs' - fromScalarAtom result - -computeOffsetImp - :: Emits n => IndexStructure SimpIR n -> IExpr n -> SubstImpM i n (IExpr n) -computeOffsetImp idxs ixOrd = do - let ixOrd' = toScalarAtom ixOrd - result <- liftBuilderImp do - PairE idxs' ixOrd'' <- sinkM $ PairE idxs ixOrd' - computeOffset idxs' ixOrd'' - fromScalarAtom result +computeElemCountImp _ = undefined computeElemCount :: Emits n => IndexStructure SimpIR n -> SBuilderM n (Atom SimpIR n) computeElemCount (EmptyAbs Empty) = @@ -897,31 +874,6 @@ computeElemCount (EmptyAbs Empty) = -- in the case that we don't have any indices. The more general path will -- still compute `1`, but it might emit decls along the way. return $ IdxRepVal 1 -computeElemCount idxNest' = do - let (idxList, idxNest) = indexStructureSplit idxNest' - sizes <- forM idxList indexSetSize - listSize <- foldM imul (IdxRepVal 1) sizes - nestSize <- elemCountPoly idxNest - imul listSize nestSize - -elemCountPoly :: Emits n => IndexStructure SimpIR n -> SBuilderM n (Atom SimpIR n) -elemCountPoly (Abs bs UnitE) = case bs of - Empty -> return $ IdxRepVal 1 - Nest b@(PairB (LiftB d) (_:>t)) rest -> do - curSize <- indexSetSize $ IxType t d - restSizes <- computeSizeGivenOrdinal b $ EmptyAbs rest - sumUsingPolysImp curSize restSizes - -computeSizeGivenOrdinal - :: EnvReader m - => IxBinder SimpIR n l -> IndexStructure SimpIR l - -> m n (Abs SBinder SExpr n) -computeSizeGivenOrdinal (PairB (LiftB d) (b:>t)) idxStruct = liftBuilder do - withFreshBinder noHint IdxRepTy \bOrdinal -> - Abs bOrdinal <$> buildBlock do - i <- unsafeFromOrdinal (sink $ IxType t d) $ toAtom $ sink $ binderVar bOrdinal - idxStruct' <- applySubst (b@>SubstVal i) idxStruct - elemCountPoly $ sink idxStruct' -- Split the index structure into a prefix of non-dependent index types -- and a trailing nest of indices that can contain inter-dependencies. @@ -933,26 +885,6 @@ indexStructureSplit s@(Abs (Nest (PairB (LiftB d) b) rest) UnitE) = HoistSuccess rest' -> (IxType (binderType b) d:ans1, ans2) where (ans1, ans2) = indexStructureSplit rest' -computeOffset :: forall n. Emits n - => IndexStructure SimpIR n -> SAtom n -> SBuilderM n (SAtom n) -computeOffset (EmptyAbs (Nest _ Empty)) i = return i -- optimization -computeOffset (EmptyAbs (Nest b idxs)) idxOrdinal = do - case hoist b (EmptyAbs idxs) of - HoistFailure _ -> do - rhsElemCounts <- computeSizeGivenOrdinal b (EmptyAbs idxs) - sumUsingPolysImp idxOrdinal rhsElemCounts - HoistSuccess idxs' -> do - stride <- computeElemCount idxs' - idxOrdinal `imul` stride -computeOffset _ _ = error "Expected a nonempty nest of idx binders" - -sumUsingPolysImp - :: Emits n => SAtom n - -> Abs SBinder SExpr n -> BuilderM SimpIR n (SAtom n) -sumUsingPolysImp lim (Abs i body) = do - ab <- hoistDecls i body - sumUsingPolys lim ab - hoistDecls :: ( Builder SimpIR m, EnvReader m, Emits n , BindsNames b, BindsEnv b, RenameB b, SinkableB b)