diff --git a/spidr/src/Model/GaussianProcess.idr b/spidr/src/Model/GaussianProcess.idr index 92149e087..d2de6e1d4 100644 --- a/spidr/src/Model/GaussianProcess.idr +++ b/spidr/src/Model/GaussianProcess.idr @@ -113,7 +113,7 @@ fit : (forall n . Tensor [n] F64 -> Optimizer $ Tensor [n] F64) fit optimizer (MkDataset x y) (MkConjugateGPR {p} mkPrior gpParams noise) = do let objective : Tensor [S p] F64 -> Tag $ Tensor [] F64 objective params = do - let priorParams = slice [1.to (S p)] params + let priorParams = slice {inBounds = rewrite compareNatDiag (S p) in %search} [1.to (S p)] params logMarginalLikelihood !(mkPrior priorParams) (slice [at 0] params) (x, squeeze y) params <- optimizer (concat 0 (expand 0 noise) gpParams) objective diff --git a/spidr/src/Tensor.idr b/spidr/src/Tensor.idr index e662ebc98..0e6e13a53 100644 --- a/spidr/src/Tensor.idr +++ b/spidr/src/Tensor.idr @@ -23,6 +23,7 @@ limitations under the License. module Tensor import Control.Monad.Error.Either +import public Data.Either import public Control.Monad.State import public Data.List import public Data.List.Elem @@ -311,68 +312,69 @@ squeeze : Tensor to dtype squeeze $ MkTensor {shape} x = MkTensor $ Reshape shape to x -||| A `SliceOrIndex d` is a valid slice or index into a dimension of size `d`. See `slice` for +||| A `Subset d` is a valid slice or index into a dimension of size `d`. See `slice` for ||| details. export -data SliceOrIndex : Nat -> Type where - Slice : - (from, to : Nat) -> - {size : _} -> - {auto 0 fromTo : from + size = to} -> - {auto 0 inDim : LTE to d} -> - SliceOrIndex d - Index : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d - DynamicSlice : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d - DynamicIndex : Tensor [] U64 -> SliceOrIndex d +data Subset : Type where + All : Subset + Slice : (from, to : Nat) -> Subset + Index : (idx : Nat) -> Subset + DynamicSlice : Tensor [] U64 -> (size : Nat) -> Subset + DynamicIndex : Tensor [] U64 -> Subset ||| Index at `idx`. See `slice` for details. public export -at : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d +at : (idx : Nat) -> Subset at = Index namespace Dynamic ||| Index at the specified index. See `slice` for details. public export - at : Tensor [] U64 -> SliceOrIndex d + at : Tensor [] U64 -> Subset at = DynamicIndex ||| Slice from `from` (inclusive) to `to` (exclusive). See `slice` for details. public export -(.to) : - (from, to : Nat) -> - {size : _} -> - {auto 0 fromTo : from + size = to} -> - {auto 0 inDim : LTE to d} -> - SliceOrIndex d -(.to) = Slice +(.to) : (from, to : Nat) -> {auto 0 ordered : from `LTE` to} -> Subset +from.to to = Slice from to ||| Slice `size` elements starting at the specified scalar `U64` index. See `slice` for details. public export -(.size) : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d +(.size) : Tensor [] U64 -> (size : Nat) -> Subset (.size) = DynamicSlice ||| Slice across all indices along an axis. See `slice` for details. public export -all : {d : _} -> SliceOrIndex d -all = Slice 0 @{%search} @{reflexive {ty = Nat}} d +all : Subset +all = All -||| A `MultiSlice shape` is a valid multi-dimensional slice into a tensor with shape `shape`. -||| See `slice` for details. -public export -data MultiSlice : Shape -> Type where - Nil : MultiSlice ds - (::) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds) +namespace Subset + public export + data InvalidSubsetError = + ||| The number of dimensions requested and found (in that order) + OutOfBounds Nat Nat + + | ||| The number of unaccounted-for axes + TooManyAxes Nat -namespace MultiSlice ||| The shape of a tensor produced by slicing with the specified multi-dimensional slice. See ||| `Tensor.slice` for details. public export - slice : {shape : _} -> MultiSlice shape -> Shape - slice {shape} [] = shape - slice {shape = (_ :: _)} (Slice {size} _ _ :: xs) = size :: slice xs - slice {shape = (_ :: _)} (Index _ :: xs) = slice xs - slice {shape = (_ :: _)} (DynamicSlice _ size :: xs) = size :: slice xs - slice {shape = (_ :: _)} (DynamicIndex _ :: xs) = slice xs + slice : List Subset -> Shape -> Either InvalidSubsetError Shape + slice at@(_ :: _) [] = Left $ TooManyAxes (length at) + slice [] ds = Right ds + slice (All :: xs) (d :: ds) = map (d ::) (slice xs ds) + slice (Slice from to :: xs) (d :: ds) = + ifThenElse (to > d) (Left $ OutOfBounds to d) $ (to `minus` from ::) <$> slice xs ds + slice (Index idx :: xs) (d :: ds) = ifThenElse (idx >= d) (Left $ OutOfBounds idx d) $ slice xs ds + slice (DynamicSlice _ size :: xs) (d :: ds) = + ifThenElse (size > d) (Left $ OutOfBounds size d) $ map (size ::) (slice xs ds) + slice (DynamicIndex _ :: xs) (_ :: ds) = slice xs ds + +public export +fromRight : (e : Either l r) -> {auto 0 isRight : IsRight e} -> r +fromRight (Right r) = r +fromRight (Left _) impossible ||| Slice or index `Tensor` axes. Each axis can be sliced or indexed, and this can be done with ||| either static (`Nat`) or dynamic (scalar `U64`) indices. @@ -461,48 +463,48 @@ namespace MultiSlice ||| @at The multi-dimensional slices and indices at which to slice the tensor. export slice : + forall shape, dtype . Primitive dtype => - (at : MultiSlice shape) -> + (at : List Subset) -> Tensor shape dtype -> - Tensor (slice at) dtype -slice at $ MkTensor x = MkTensor - $ Reshape (mapd size id at) (MultiSlice.slice at) - $ DynamicSlice (dynStarts [] at) (mapd size id at) - $ Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) x - - where - mapd : ((Nat -> a) -> {d : Nat} -> SliceOrIndex d -> a) -> - (Nat -> a) -> - {shape : Shape} -> - MultiSlice shape -> - List a - mapd _ dflt {shape} [] = Prelude.map dflt shape - mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs - - start : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat - start _ (Slice from _) = from - start _ (Index idx) = idx - start f {d} _ = f d - - stop : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat - stop _ (Slice _ to) = to - stop _ (Index idx) = S idx - stop f {d} _ = f d - - size : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat - size _ (Slice {size = size'} _ _) = size' - size _ (Index _) = 1 - size _ (DynamicSlice _ size') = size' - size _ (DynamicIndex _) = 1 - - zero : Expr - zero = FromLiteral {shape = []} {dtype = U64} 0 - - dynStarts : List Expr -> {shape : _} -> MultiSlice shape -> List Expr - dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs - dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds - dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds - dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds + let out = slice at shape in + {auto 0 inBounds : IsRight out} -> + Tensor (fromRight out) dtype +slice at $ MkTensor x = + let ldiff = length shape `minus` length at + at' = at ++ replicate ldiff All + map = Prelude.map -- help type inference + in MkTensor + $ Reshape (zipWith size shape at') (fromRight $ slice at shape) + $ DynamicSlice (map dynStart at ++ replicate ldiff zero) (zipWith size shape at') + $ Slice (map start at) (zipWith stop shape at') (replicate (length shape) 1) x + + where + + start : Subset -> Nat + start (Slice from _) = from + start (Index idx) = idx + start _ = 0 + + stop : Nat -> Subset -> Nat + stop _ (Slice _ to) = to + stop _ (Index idx) = S idx + stop d _ = d + + size : Nat -> Subset -> Nat + size _ (Slice from to) = to `minus` from + size _ (Index _) = 1 + size _ (DynamicSlice _ size') = size' + size _ (DynamicIndex _) = 1 + size d All = d + + zero : Expr + zero = FromLiteral {shape = []} {dtype = U64} 0 + + dynStart : Subset -> Expr + dynStart (DynamicSlice (MkTensor i) _) = i + dynStart (DynamicIndex (MkTensor i)) = i + dynStart _ = zero ||| Concatenate two `Tensor`s along the specfied `axis`. For example, ||| `concat 0 (tensor [[1, 2], [3, 4]]) (tensor [[5, 6]])` and