Skip to content

improve slice type inference with decidability #417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion spidr/src/Model/GaussianProcess.idr
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ fit : (forall n . Tensor [n] F64 -> Optimizer $ Tensor [n] F64)
fit optimizer (MkDataset x y) (MkConjugateGPR {p} mkPrior gpParams noise) = do
let objective : Tensor [S p] F64 -> Tag $ Tensor [] F64
objective params = do
let priorParams = slice [1.to (S p)] params
let priorParams = slice {inBounds = rewrite compareNatDiag (S p) in %search} [1.to (S p)] params
logMarginalLikelihood !(mkPrior priorParams) (slice [at 0] params) (x, squeeze y)

params <- optimizer (concat 0 (expand 0 noise) gpParams) objective
Expand Down
154 changes: 78 additions & 76 deletions spidr/src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
module Tensor

import Control.Monad.Error.Either
import public Data.Either
import public Control.Monad.State
import public Data.List
import public Data.List.Elem
Expand Down Expand Up @@ -311,68 +312,69 @@ squeeze :
Tensor to dtype
squeeze $ MkTensor {shape} x = MkTensor $ Reshape shape to x

||| A `SliceOrIndex d` is a valid slice or index into a dimension of size `d`. See `slice` for
||| A `Subset d` is a valid slice or index into a dimension of size `d`. See `slice` for
||| details.
export
data SliceOrIndex : Nat -> Type where
Slice :
(from, to : Nat) ->
{size : _} ->
{auto 0 fromTo : from + size = to} ->
{auto 0 inDim : LTE to d} ->
SliceOrIndex d
Index : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
DynamicSlice : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
DynamicIndex : Tensor [] U64 -> SliceOrIndex d
data Subset : Type where
All : Subset
Slice : (from, to : Nat) -> Subset
Index : (idx : Nat) -> Subset
DynamicSlice : Tensor [] U64 -> (size : Nat) -> Subset
DynamicIndex : Tensor [] U64 -> Subset

||| Index at `idx`. See `slice` for details.
public export
at : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
at : (idx : Nat) -> Subset
at = Index

namespace Dynamic
||| Index at the specified index. See `slice` for details.
public export
at : Tensor [] U64 -> SliceOrIndex d
at : Tensor [] U64 -> Subset
at = DynamicIndex

||| Slice from `from` (inclusive) to `to` (exclusive). See `slice` for details.
public export
(.to) :
(from, to : Nat) ->
{size : _} ->
{auto 0 fromTo : from + size = to} ->
{auto 0 inDim : LTE to d} ->
SliceOrIndex d
(.to) = Slice
(.to) : (from, to : Nat) -> {auto 0 ordered : from `LTE` to} -> Subset
from.to to = Slice from to

||| Slice `size` elements starting at the specified scalar `U64` index. See `slice` for details.
public export
(.size) : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
(.size) : Tensor [] U64 -> (size : Nat) -> Subset
(.size) = DynamicSlice

||| Slice across all indices along an axis. See `slice` for details.
public export
all : {d : _} -> SliceOrIndex d
all = Slice 0 @{%search} @{reflexive {ty = Nat}} d
all : Subset
all = All

||| A `MultiSlice shape` is a valid multi-dimensional slice into a tensor with shape `shape`.
||| See `slice` for details.
public export
data MultiSlice : Shape -> Type where
Nil : MultiSlice ds
(::) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds)
namespace Subset
public export
data InvalidSubsetError =
||| The number of dimensions requested and found (in that order)
OutOfBounds Nat Nat

| ||| The number of unaccounted-for axes
TooManyAxes Nat

namespace MultiSlice
||| The shape of a tensor produced by slicing with the specified multi-dimensional slice. See
||| `Tensor.slice` for details.
public export
slice : {shape : _} -> MultiSlice shape -> Shape
slice {shape} [] = shape
slice {shape = (_ :: _)} (Slice {size} _ _ :: xs) = size :: slice xs
slice {shape = (_ :: _)} (Index _ :: xs) = slice xs
slice {shape = (_ :: _)} (DynamicSlice _ size :: xs) = size :: slice xs
slice {shape = (_ :: _)} (DynamicIndex _ :: xs) = slice xs
slice : List Subset -> Shape -> Either InvalidSubsetError Shape
slice at@(_ :: _) [] = Left $ TooManyAxes (length at)
slice [] ds = Right ds
slice (All :: xs) (d :: ds) = map (d ::) (slice xs ds)
slice (Slice from to :: xs) (d :: ds) =
ifThenElse (to > d) (Left $ OutOfBounds to d) $ (to `minus` from ::) <$> slice xs ds
slice (Index idx :: xs) (d :: ds) = ifThenElse (idx >= d) (Left $ OutOfBounds idx d) $ slice xs ds
slice (DynamicSlice _ size :: xs) (d :: ds) =
ifThenElse (size > d) (Left $ OutOfBounds size d) $ map (size ::) (slice xs ds)
slice (DynamicIndex _ :: xs) (_ :: ds) = slice xs ds

public export
fromRight : (e : Either l r) -> {auto 0 isRight : IsRight e} -> r
fromRight (Right r) = r
fromRight (Left _) impossible

||| Slice or index `Tensor` axes. Each axis can be sliced or indexed, and this can be done with
||| either static (`Nat`) or dynamic (scalar `U64`) indices.
Expand Down Expand Up @@ -461,48 +463,48 @@ namespace MultiSlice
||| @at The multi-dimensional slices and indices at which to slice the tensor.
export
slice :
forall shape, dtype .
Primitive dtype =>
(at : MultiSlice shape) ->
(at : List Subset) ->
Tensor shape dtype ->
Tensor (slice at) dtype
slice at $ MkTensor x = MkTensor
$ Reshape (mapd size id at) (MultiSlice.slice at)
$ DynamicSlice (dynStarts [] at) (mapd size id at)
$ Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) x

where
mapd : ((Nat -> a) -> {d : Nat} -> SliceOrIndex d -> a) ->
(Nat -> a) ->
{shape : Shape} ->
MultiSlice shape ->
List a
mapd _ dflt {shape} [] = Prelude.map dflt shape
mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs

start : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
start _ (Slice from _) = from
start _ (Index idx) = idx
start f {d} _ = f d

stop : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
stop _ (Slice _ to) = to
stop _ (Index idx) = S idx
stop f {d} _ = f d

size : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
size _ (Slice {size = size'} _ _) = size'
size _ (Index _) = 1
size _ (DynamicSlice _ size') = size'
size _ (DynamicIndex _) = 1

zero : Expr
zero = FromLiteral {shape = []} {dtype = U64} 0

dynStarts : List Expr -> {shape : _} -> MultiSlice shape -> List Expr
dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs
dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds
dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds
dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds
let out = slice at shape in
{auto 0 inBounds : IsRight out} ->
Tensor (fromRight out) dtype
slice at $ MkTensor x =
let ldiff = length shape `minus` length at
at' = at ++ replicate ldiff All
map = Prelude.map -- help type inference
in MkTensor
$ Reshape (zipWith size shape at') (fromRight $ slice at shape)
$ DynamicSlice (map dynStart at ++ replicate ldiff zero) (zipWith size shape at')
$ Slice (map start at) (zipWith stop shape at') (replicate (length shape) 1) x

where

start : Subset -> Nat
start (Slice from _) = from
start (Index idx) = idx
start _ = 0

stop : Nat -> Subset -> Nat
stop _ (Slice _ to) = to
stop _ (Index idx) = S idx
stop d _ = d

size : Nat -> Subset -> Nat
size _ (Slice from to) = to `minus` from
size _ (Index _) = 1
size _ (DynamicSlice _ size') = size'
size _ (DynamicIndex _) = 1
size d All = d

zero : Expr
zero = FromLiteral {shape = []} {dtype = U64} 0

dynStart : Subset -> Expr
dynStart (DynamicSlice (MkTensor i) _) = i
dynStart (DynamicIndex (MkTensor i)) = i
dynStart _ = zero

||| Concatenate two `Tensor`s along the specfied `axis`. For example,
||| `concat 0 (tensor [[1, 2], [3, 4]]) (tensor [[5, 6]])` and
Expand Down
Loading