
Commit 744081f

implement grid search
1 parent 23d2813 commit 744081f

11 files changed: +287 -10 lines changed

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+/*
+Copyright 2022 Joel Berkeley
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+
+#include "../xla_builder.h"
+
+extern "C" {
+  XlaOp* ArgMax(XlaOp& input, int output_type, int axis) {
+    auto& input_ = reinterpret_cast<xla::XlaOp&>(input);
+    xla::XlaOp res = xla::ArgMax(input_, (xla::PrimitiveType) output_type, axis);
+    return reinterpret_cast<XlaOp*>(new xla::XlaOp(res));
+  }
+
+  XlaOp* ArgMin(XlaOp& input, int output_type, int axis) {
+    auto& input_ = reinterpret_cast<xla::XlaOp&>(input);
+    xla::XlaOp res = xla::ArgMin(input_, (xla::PrimitiveType) output_type, axis);
+    return reinterpret_cast<XlaOp*>(new xla::XlaOp(res));
+  }
+}
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+/*
+Copyright 2022 Joel Berkeley
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "../xla_builder.h"
+
+extern "C" {
+  XlaOp* ArgMax(XlaOp& input, int output_type, int axis);
+  XlaOp* ArgMin(XlaOp& input, int output_type, int axis);
+}
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+{--
+Copyright 2022 Joel Berkeley
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+--}
+module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Arithmetic
+
+import System.FFI
+
+import Compiler.Xla.Prim.Util
+
+export
+%foreign (libxla "ArgMax")
+prim__argMax : GCAnyPtr -> Int -> Int -> PrimIO AnyPtr
+
+export
+%foreign (libxla "ArgMin")
+prim__argMin : GCAnyPtr -> Int -> Int -> PrimIO AnyPtr
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+{--
+Copyright 2022 Joel Berkeley
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+--}
+module Compiler.Xla.TensorFlow.Compiler.Xla.Client.Lib.Arithmetic
+
+import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.Lib.Arithmetic
+import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder
+import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
+
+export
+argMax : (HasIO io, Primitive dtype) => XlaOp -> Nat -> io XlaOp
+argMax (MkXlaOp input) axis = do
+  opPtr <- primIO $ prim__argMax input (xlaIdentifier {dtype}) (cast axis)
+  opPtr <- onCollectAny opPtr XlaOp.delete
+  pure (MkXlaOp opPtr)
+
+export
+argMin : (HasIO io, Primitive dtype) => XlaOp -> Nat -> io XlaOp
+argMin (MkXlaOp input) axis = do
+  opPtr <- primIO $ prim__argMin input (xlaIdentifier {dtype}) (cast axis)
+  opPtr <- onCollectAny opPtr XlaOp.delete
+  pure (MkXlaOp opPtr)

src/Model/GaussianProcess.idr

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ export
 ||| Fit the Gaussian process and noise to the specified data.
 export
 fit : ConjugateGPRegression features
-  -> (forall n . Tensor [n] F64 -> Optimizer $ Tensor [n] F64)
+  -> (forall n . Tensor [n] F64 -> Optimizer [n])
   -> Dataset features [1]
   -> ConjugateGPRegression features
 fit (MkConjugateGPR {p} mkPrior gpParams noise) optimizer (MkDataset x y) =

src/Optimize.idr

Lines changed: 49 additions & 9 deletions
@@ -16,6 +16,7 @@ limitations under the License.
 ||| This module contains definitions of function optimizers.
 module Optimize
 
+import Literal
 import Tensor
 
 ||| An `Optimizer` finds the value, in a `Tensor`-valued feature space, which (approximately)
@@ -28,22 +29,61 @@ import Tensor
 |||
 ||| @domain The type of the domain over which to find the optimizer.
 public export 0
-Optimizer : {default id 0 m : Type -> Type} -> (0 domain : Type) -> Type
-Optimizer a = (a -> m $ Tensor [] F64) -> m a
+Optimizer : Shape -> Type
+Optimizer domain = (Tensor domain F64 -> Tensor [] F64) -> Tensor domain F64
 
-||| Construct an `Optimizer` that implements grid search over a scalar feature space. Grid search
-||| approximates the optimum by evaluating the objective over a finite, evenly-spaced grid.
-|||
-||| **NOTE** This function is not yet implemented.
+-- naively, i'd like there to be just one optimizer, but is that possible, and practical?
+public export 0
+BatchOptimizer : Shape -> Type
+BatchOptimizer domain = ({n : _} -> Tensor (n :: domain) F64 -> Tensor [n] F64) -> Tensor domain F64
+
+||| Grid search over a scalar feature space. Grid search approximates the optimum by evaluating the
+||| objective over a finite, evenly-spaced grid.
 |||
 ||| @density The density of the grid.
 ||| @lower The lower (inclusive) bound of the grid.
 ||| @upper The upper (exclusive) bound of the grid.
 export
-gridSearch : (density : Tensor [d] U32) ->
+gridSearch : {d : _} ->
+             (density : Vect d Nat) ->
             (lower : Tensor [d] F64) ->
             (upper : Tensor [d] F64) ->
-            Optimizer (Tensor [d] F64)
+            BatchOptimizer [d]
+gridSearch {d=Z} _ _ _ _ = fromLiteral []
+gridSearch {d=S k} density lower upper f =
+  let densityAll : Nat
+      densityAll = product density
+
+      prodDims : Tensor [S k] U64 := fromLiteral $ cast $ scanr (*) 1 (tail density)
+      idxs = fromLiteral {shape=[densityAll]} $ cast $ Vect.range densityAll
+      densityTensor = broadcast $ fromLiteral {shape=[S k]} {dtype=U64} (cast density)
+      grid = broadcast {to=[densityAll, S k]} (expand 1 idxs)
+               `Tensor.div` broadcast {from=[S k]} (cast prodDims)
+               `Tensor.mod` densityTensor
+      gridRelative : Tensor [densityAll, S k] F64 = cast grid / cast densityTensor
+      points = with Tensor.(+)
+        broadcast lower + broadcast {to=[densityAll, S k]} (upper - lower) * gridRelative
+      idx = argmin 0 (f points)
+  in index 0 idx points
+
+||| If `xs` is a vector of exclusive upper bounds for a number of dimensions, this produces a list
+||| of all positions in the (higher-dimensional) grid. For example, `grid [3]` is `[[0], [1], [2]]`,
+||| and `grid [2, 3]` is
+||| ```
+||| [
+|||     [0, 0]
+|||   , [0, 1]
+|||   , [0, 2]
+|||   , [1, 0]
+|||   , [1, 1]
+|||   , [1, 2]
+||| ]
+||| ```
+export
+grid : (xs : Vect (S n) Nat) -> Vect (product xs) (Vect (S n) Nat)
+grid xs =
+  let prodDims = scanr (*) 1 (tail xs)
+  in map (\e => zipWith (\x, p => e `div` p `mod` x) xs prodDims) (range (product xs))
 
 ||| The limited-memory BFGS (L-BFGS) optimization tactic, see
 |||
@@ -58,4 +98,4 @@ gridSearch : (density : Tensor [d] U32) ->
 |||
 ||| @initialPoints The points from which to start optimization.
 export
-lbfgs : (initialPoints : Tensor [n] F64) -> Optimizer (Tensor [n] F64)
+lbfgs : (initialPoints : Tensor [n] F64) -> Optimizer [n]
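
A worked example of the stride arithmetic shared by `gridSearch` and `grid` (an illustration under the definitions above, not part of the commit):

-- For xs = [2, 3], scanr (*) 1 (tail xs) gives the stride products [3, 1].
-- The flat index e = 4 then decodes as
--   zipWith (\x, p => e `div` p `mod` x) [2, 3] [3, 1]
--     = [4 `div` 3 `mod` 2, 4 `div` 1 `mod` 3]
--     = [1, 1]
-- which is the fifth position in the docstring example for grid [2, 3].
-- gridSearch performs the same decoding on whole Tensors via Tensor.div and Tensor.mod.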

src/Tensor.idr

Lines changed: 32 additions & 0 deletions
@@ -369,6 +369,16 @@ slice at (MkTensor expr) =
     size _ (DynamicSlice _ size') = size'
     size _ (DynamicIndex _) = 1
 
+namespace Dynamic
+  export
+  index :
+    Primitive dtype =>
+    (axis : Nat) ->
+    {auto 0 axisInBounds : InBounds axis shape} ->
+    Tensor [] U64 ->
+    Tensor shape dtype ->
+    Tensor (deleteAt axis shape) dtype
+
 ||| Concatenate two `Tensor`s along the specified `axis`. For example,
 ||| `concat 0 (fromLiteral [[1, 2], [3, 4]]) (fromLiteral [[5, 6]])` and
 ||| `concat 1 (fromLiteral [[3], [6]]) (fromLiteral [[4, 5], [7, 8]])` are both
@@ -1161,6 +1171,28 @@ namespace Monoid
   Monoid (Tensor shape dtype) using Semigroup.Max where
     neutral = fill (- 1.0 / 0.0)
 
+export
+argmin :
+  (Primitive outType, Primitive.Num dtype) =>
+  (axis : Nat) ->
+  {auto 0 ok : InBounds axis shape} ->
+  Tensor shape dtype ->
+  Tensor [] outType
+
+export
+argmax :
+  (Primitive outType, Primitive.Num dtype) =>
+  (axis : Nat) ->
+  {auto 0 ok : InBounds axis shape} ->
+  Tensor shape dtype ->
+  Tensor [] outType
+
+export
+div : Primitive.Integral dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
+
+export
+mod : Primitive.Integral dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
+
 ---------------------------- other ----------------------------------
 
 ||| Cholesky decomposition. Computes the lower triangular matrix `L` from the symmetric, positive
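
A usage sketch for the new declarations (values assumed for illustration; not from the diff):

-- With x = fromLiteral [3.0, 1.0, 2.0] : Tensor [3] F64,
-- argmin 0 x is a Tensor [] U64 holding the index 1, and
-- index 0 (argmin 0 x) x recovers the minimum, 1.0.
-- This argmin-then-index pairing is how gridSearch selects the best grid point.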

src/Util.idr

Lines changed: 7 additions & 0 deletions
@@ -49,6 +49,13 @@ namespace Vect
     let lengthOK = lengthCorrect xs
     in rewrite sym lengthOK in zip (range (length xs)) (rewrite lengthOK in xs)
 
+  ||| Like `foldr`, but returns a vector of all intermediate accumulated states. The first
+  ||| state appears last in the result, and the last state appears first.
+  public export
+  scanr : (elem -> res -> res) -> res -> Vect len elem -> Vect (S len) res
+  scanr _ q [] = [q]
+  scanr f q (x :: xs) = let qs'@(q' :: _) = scanr f q xs in f x q' :: qs'
+
 namespace List
   ||| All numbers from `0` to `n - 1` inclusive, in increasing order.
   |||
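
For concreteness, `scanr` evaluated by hand (a sketch, not part of the diff):

-- scanr (*) 1 [2, 4, 3]
--   = 2 * 12 :: scanr (*) 1 [4, 3]
--   = [24, 12, 3, 1]
-- The initial state 1 appears last, as documented. gridSearch relies on this
-- via scanr (*) 1 (tail density) to compute stride products.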

test.ipkg

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ modules =
     Unit.Util.TestHashable,
     Unit.TestDistribution,
     Unit.TestLiteral,
+    Unit.TestOptimize,
    Unit.TestTensor,
     Unit.TestUtil,

test/Main.idr

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@ import Utils.TestComparison
 import Unit.Model.TestKernel
 import Unit.Util.TestHashable
 import Unit.TestDistribution
+import Unit.TestOptimize
 import Unit.TestTensor
 import Unit.TestLiteral
 import Unit.TestUtil
@@ -37,6 +38,7 @@ main = test [
   , Unit.TestUtil.group
   , Unit.TestLiteral.group
   , Unit.TestTensor.group
+  , Unit.TestOptimize.group
   , Unit.TestDistribution.group
   , Unit.Model.TestKernel.group
   ]

test/Unit/TestOptimize.idr

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+{--
+Copyright 2022 Joel Berkeley
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+--}
+module Unit.TestOptimize
+
+import Literal
+import Optimize
+import Tensor
+
+import Utils.Cases
+import Utils.Comparison
+
+gridSearch : Property
+gridSearch = fixedProperty $ do
+  let lower = fromLiteral [-1.0, -1.0, -1.0]
+      upper = fromLiteral [1.0, 1.0, 1.0]
+
+      f : {n : _} -> Tensor [n, 3] F64 -> Tensor [n] F64
+      f x = reduce @{Sum} 1 (x ^ fill 2.0)
+
+  gridSearch [100, 100, 100] lower upper f ===# fromLiteral [0.0, 0.0, 0.0]
+
+grid' : Property
+grid' = fixedProperty $ do
+  grid [0] === []
+  grid [0, 2] === []
+  grid [1] === [[0]]
+  grid [3] === [[0], [1], [2]]
+  grid [2, 3] === [
+      [0, 0],
+      [0, 1],
+      [0, 2],
+      [1, 0],
+      [1, 1],
+      [1, 2]
+    ]
+  grid [2, 4, 3] === [
+      [0, 0, 0],
+      [0, 0, 1],
+      [0, 0, 2],
+      [0, 1, 0],
+      [0, 1, 1],
+      [0, 1, 2],
+      [0, 2, 0],
+      [0, 2, 1],
+      [0, 2, 2],
+      [0, 3, 0],
+      [0, 3, 1],
+      [0, 3, 2],
+      [1, 0, 0],
+      [1, 0, 1],
+      [1, 0, 2],
+      [1, 1, 0],
+      [1, 1, 1],
+      [1, 1, 2],
+      [1, 2, 0],
+      [1, 2, 1],
+      [1, 2, 2],
+      [1, 3, 0],
+      [1, 3, 1],
+      [1, 3, 2]
+    ]
+
+export covering
+group : Group
+group = MkGroup "Optimize" $ [
+    ("grid search", gridSearch), ("grid", grid')
+  ]
