Skip to content

Commit

Permalink
Start getting examples working again
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Jan 6, 2024
1 parent feba99f commit eb351ab
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 46 deletions.
30 changes: 15 additions & 15 deletions lib/parser.dx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def from_ordinal_exc(n|Ix, i:Nat) -> {Except} n =

# TODO: allow this to happen in-place
# TODO: if it takes too long to make that possible, start with a bounded version
def push(ref:Ref h (List a), x:a) -> {State h} () given (h, a|Data) =
def push(ref:Ref h (List a), x:a) -> {State h} () given (h:Heap, a|Data) =
l = get ref
ref := l <> AsList _ [x]

Expand All @@ -28,11 +28,11 @@ struct ParserHandle(h:Heap) =
enum Parser(a:Type) =
MkParser((given(h:Heap), ParserHandle h )-> {Except, State h} a)

def parse(handle:ParserHandle h, parser:Parser a) -> {Except, State h} a given (a, h) =
def parse(handle:ParserHandle h, parser:Parser a) -> {Except, State h} a given (a:Type, h:Heap) =
MkParser f = parser
f handle

def run_parser_partial(s:String, parser:Parser a) -> Maybe a given (a) =
def run_parser_partial(s:String, parser:Parser a) -> Maybe a given (a:Type) =
MkParser f = parser
with_state 0 \pos.
catch $ \.
Expand All @@ -49,17 +49,17 @@ def p_char(c:Char) -> Parser () = MkParser \h.
def p_eof() ->> Parser () = MkParser \h.
assert $ get h.offset >= list_length h.input

def (<|>)(p1:Parser a, p2:Parser a) -> Parser a given (a) = MkParser \h.
def (<|>)(p1:Parser a, p2:Parser a) -> Parser a given (a:Type) = MkParser \h.
curPos = get h.offset
case catch \. parse h p1 of
Nothing ->
assert $ curPos == get h.offset
parse h p2
Just ans -> ans

def returnP(x:a) -> Parser a given (a) = MkParser \_. x
def returnP(x:a) -> Parser a given (a:Type) = MkParser \_. x

def run_parser(s:String, parser:Parser a) -> Maybe a given (a) =
def run_parser(s:String, parser:Parser a) -> Maybe a given (a:Type) =
run_parser_partial s $ MkParser \h.
ans = parse h parser
_ = parse h p_eof
Expand All @@ -79,7 +79,7 @@ def parse_anything_but(c:Char) -> Parser Char =
h.offset := i + 1
c'

def try(parser:Parser a) -> Parser a given (a) = MkParser \h.
def try(parser:Parser a) -> Parser a given (a:Type) = MkParser \h.
savedPos = get h.offset
ans = catch \. parse h parser
case ans of
Expand All @@ -93,14 +93,14 @@ def try(parser:Parser a) -> Parser a given (a) = MkParser \h.
def parse_digit() ->> Parser Int = try $ MkParser \h.
c = parse h $ parse_any
i = w8_to_i c - 48
assert $ 0 <= i && i < 10
assert $ (0::Int) <= i && i < 10
i

def optional(p:Parser a) -> Parser (Maybe a) given (a) =
def optional(p:Parser a) -> Parser (Maybe a) given (a:Type) =
(MkParser \h. Just (parse h p)) <|> returnP Nothing

def parse_many(parser:Parser a) -> Parser (List a) given (a|Data) = MkParser \h.
yield_state (AsList _ []) \results.
yield_state mempty \results.
iter \_.
maybeVal = parse h $ optional parser
case maybeVal of
Expand All @@ -118,7 +118,7 @@ def parse_some(parser:Parser a) -> Parser (List a) given (a|Data) =
def parse_unsigned_int() ->> Parser Int = MkParser \h.
AsList(_, digits) = parse h $ parse_many parse_digit
yield_state 0 \ref.
for i. ref := 10 * get ref + digits[i]
each digits \digit. ref := 10 * get ref + digit

def parse_int() ->> Parser Int = MkParser \h.
negSign = parse h $ optional $ p_char '-'
Expand All @@ -127,18 +127,18 @@ def parse_int() ->> Parser Int = MkParser \h.
Nothing -> x
Just () -> (-1) * x

def bracketed(l:Parser (), r:Parser (), body:Parser a) -> Parser a given (a) =
def bracketed(l:Parser (), r:Parser (), body:Parser a) -> Parser a given (a:Type) =
MkParser \h.
_ = parse h l
ans = parse h body
_ = parse h r
ans

def parens(parser:Parser a) -> Parser a given (a) =
def parens(parser:Parser a) -> Parser a given (a:Type) =
bracketed (p_char '(') (p_char ')') parser

def split(space:Char, s:String) -> List String =
def trailing_spaces(space:Parser (), body:Parser a) -> Parser a given (a) =
def trailing_spaces(space:Parser (), body:Parser a) -> Parser a given (a:Type) =
MkParser \h.
ans = parse h body
_ = parse h $ parse_many space
Expand All @@ -148,4 +148,4 @@ def split(space:Char, s:String) -> List String =
parse h $ parse_many (trailing_spaces (p_char space) (parse_some (parse_anything_but space)))
case run_parser s split_parser of
Just l -> l
Nothing -> AsList _ []
Nothing -> mempty
54 changes: 28 additions & 26 deletions lib/stats.dx
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ actual value being represented (the raw probability) can be computed with
use `ls_sum`, which applies the standard max-sweep log-sum-exp trick directly,
rather than relying on monoid reduction using `add`.

struct LogSpace(a) =
struct LogSpace(a:Type) =
log : a

def Exp(x:a) -> LogSpace a given (a) = LogSpace x
def Exp(x:a) -> LogSpace a given (a:Type) = LogSpace x

instance Mul(LogSpace f) given (f|Add)
def (*)(x, y) = Exp $ x.log + y.log
Expand All @@ -38,14 +38,14 @@ instance Arbitrary(LogSpace Float)
def is_infinite(x:f) -> Bool given (f|Fractional|Sub|Mul|Ord) =
# Note: According to this function, nan is not infinite.
# Todo: Make polymorphic versions of these and put them in the prelude.
infinity = divide one zero
neg_infinity = zero - infinity
infinity : f = divide one zero
neg_infinity : f = zero - infinity
x == infinity || x == neg_infinity

def log_add_exp(la:f, lb:f) -> f
given (f|Floating|Add|Sub|Mul|Fractional|Ord) =
infinity = (divide one zero)
neg_infinity = zero - infinity
infinity : f = divide(one, zero)
neg_infinity : f = zero - infinity
if la == infinity && lb == infinity
then infinity
else if la == neg_infinity && lb == neg_infinity
Expand All @@ -66,27 +66,29 @@ def ln(x:LogSpace f) -> f given (f|Floating) = x.log

def log_sum_exp(xs:n=>f) -> f
given(n|Ix, f|Fractional|Sub|Floating|Mul|Ord) =
m_raw = maximum xs
m = case is_infinite m_raw of
m_raw = maximum(xs)
m = case is_infinite(m_raw) of
False -> m_raw
True -> zero
m + (log $ sum for i. exp (xs[i] - m))
m + (log $ sum $ each xs \x. exp(x) - m)

def ls_sum(x:n=>LogSpace f) -> LogSpace f
given (n|Ix, f|Fractional|Sub|Floating|Mul|Ord) =
lx = map ln x
lx = each(x, ln)
Exp $ log_sum_exp lx

'## Probability distributions
Simulation and evaluation of [probability distributions](https://en.wikipedia.org/wiki/Probability_distribution). Probability distributions which can be sampled should implement `Random`, and those which can be evaluated should implement the `Dist` interface. Distributions over an ordered space (such as typical *univariate* distributions) should also implement `OrderedDist`.

interface Random(d, a)
# TODO: use an associated type for the `a` parameter
interface Random(d:Type, a:Type)
draw : (d, Key) -> a # function for random draws

interface Dist(d, a, f)
# TODO: use an associated type for the `a` parameter
interface Dist(d:Type, a:Type, f:Type)
density : (d, a) -> LogSpace f # either the density function or mass function

interface OrderedDist(d, a, f, given () (Dist d a f))
interface OrderedDist(d:Type, a:Type, f:Type, given () (Dist d a f))
cumulative : (d, a) -> LogSpace f # the cumulative distribution function (CDF)
survivor : (d, a) -> LogSpace f # the survivor function (complement of CDF)
quantile : (d, f) -> a # the quantile function (inverse CDF)
Expand Down Expand Up @@ -131,7 +133,7 @@ struct Binomial =

instance Random(Binomial, Nat)
def draw(d, k) =
sum $ map b_to_n (rand_vec d.trials (\k_. draw(Bernoulli(d.prob), k_)) k)
sum $ each (rand_vec d.trials (\k_. draw(a=Bool, Bernoulli(d.prob), k_)) k) b_to_n

instance Dist(Binomial, Nat, Float)
def density(d, x) =
Expand All @@ -157,13 +159,13 @@ instance Dist(Binomial, Nat, Float)
instance OrderedDist(Binomial, Nat, Float)
def cumulative(d, x) =
xp1:Nat = x + 1
ls_sum $ for i:(Fin xp1). density d (ordinal i)
ls_sum $ for i:(Fin xp1). density(f=Float, d, ordinal i)
def survivor(d, x) =
tmx = d.trials -| x
ls_sum $ for i:(Fin tmx). density d (x + 1 + ordinal i)
ls_sum $ for i:(Fin tmx). density(f=Float, d, x + 1 + ordinal i)
def quantile(d, q) =
tp1:Nat = d.trials + 1
lpdf = for i:(Fin tp1). ln $ density d (ordinal i)
lpdf = for i:(Fin tp1). ln $ density(f=Float, d, ordinal i)
cdf = cdf_for_categorical lpdf
mi = search_sorted cdf q
ordinal $ from_just $ left_fence mi
Expand Down Expand Up @@ -257,13 +259,13 @@ struct Poisson =
instance Random(Poisson, Nat)
def draw(d, k) =
exp_neg_rate = exp (-d.rate)
[current_k, next_k] = split_key k
[current_k, next_k] = split_key(n=2, k)
yield_state 0 \ans.
with_state (rand current_k) \p. with_state next_k \k'.
while \.
if get p > exp_neg_rate
then
[ck, nk] = split_key (get k')
[ck, nk] = split_key(n=2, get k')
p := (get p) * rand ck
ans := (get ans) + 1
k' := nk
Expand All @@ -283,16 +285,16 @@ instance Dist(Poisson, Nat, Float)
instance OrderedDist(Poisson, Nat, Float)
def cumulative(d, x) =
xp1:Nat = x + 1
ls_sum $ for i:(Fin xp1). density d (ordinal i)
ls_sum $ for i:(Fin xp1). density(f=Float, d, ordinal i)
def survivor(d, x) =
xp1:Nat = x + 1
cdf = ls_sum $ for i:(Fin xp1). density d (ordinal i)
cdf = ls_sum $ for i:(Fin xp1). density(f=Float, d, ordinal i)
Exp $ log1p (-ls_to_f cdf)
def quantile(d, q) =
yield_state (0::Nat) \x.
with_state 0.0 \cdf.
while \.
cdf := (get cdf) + ls_to_f (density d (get x))
cdf := (get cdf) + ls_to_f (density(f=Float, d ,get x))
if (get cdf) > q
then
False
Expand Down Expand Up @@ -340,7 +342,7 @@ Some data summary functions. Note that `mean` is provided by the prelude.

def mean_and_variance(xs:n=>a) -> (a, a) given (n|Ix, a|VSpace|Mul) =
m = mean xs
ss = sum for i. sq (xs[i] - m)
ss = sum $ each xs \x. sq(x - m)
v = ss / (n_to_f (size n) - 1)
(m, v)

Expand All @@ -353,7 +355,7 @@ def std_dev(xs:n=>a) -> a given (n|Ix, a|VSpace|Mul|Floating) =
def covariance(xs:n=>a, ys:n=>a) -> a given (n|Ix, a|VSpace|Mul) =
xm = mean xs
ym = mean ys
ss = sum for i. (xs[i] - xm) * (ys[i] - ym)
ss = sum for i:n. (xs[i] - xm) * (ys[i] - ym)
ss / (n_to_f (size n) - 1)

def correlation(xs:n=>a, ys:n=>a) -> a
Expand All @@ -364,8 +366,8 @@ def correlation(xs:n=>a, ys:n=>a) -> a

def mean_and_variance_matrix(xs:n=>d=>a) -> (d=>a, d=>d=>a)
given (n|Ix, d|Ix, a|Mul|VSpace) =
xsMean:d=>a = (for i. sum for j. xs[j,i]) / n_to_f (size n)
xsCov:d=>d=>a = (for i i'. sum for j.
xsMean:d=>a = (for i:d. sum for j:n. xs[j,i]) / n_to_f (size n)
xsCov:d=>d=>a = (for i:d i':d. sum for j:n.
(xs[j,i'] - xsMean[i']) *
(xs[j,i] - xsMean[i] ) ) / (n_to_f (size n) - 1)
(xsMean, xsCov)
Expand Down
9 changes: 7 additions & 2 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,13 @@ example-names := \
regression brownian_motion particle-swarm-optimizer \
ode-integrator mcmc ctc raytrace particle-filter \
fluidsim \
sgd psd kernelregression nn \
sgd nn \
quaternions manifold-gradients schrodinger tutorial \
latex linear-maps dither mcts md bfgs
latex dither mcts md bfgs

# TODO: re-enable
# fft vega-plotting
# examples depending on linalg: linear-maps, psd, kernelregression

# Only test levenshtein-distance on Linux, because MacOS ships with a
# different (apparently _very_ different) word list.
Expand Down Expand Up @@ -423,6 +425,9 @@ lib-files = $(filter-out lib/prelude.dx,$(wildcard lib/*.dx))
pages-lib-files = $(patsubst %.dx,pages/%.html,$(lib-files))
static-files = $(static-names:%=pages/static/%)

serve-docs:
cd pages && python3 -m http.server

docs: $(static-files) pages-prelude $(pages-doc-files) $(pages-example-files) $(pages-lib-files) $(slow-pages) pages/index.md

pages/static/%: static/%
Expand Down
6 changes: 3 additions & 3 deletions static/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type HighlightType = "HighlightGroup" | "HighlightLeaf" | "HighlightError" | "H
| "HighlightBinder" | "HighlightOcc"
type Highlight = [SrcId, HighlightType]

type HsMaybe<T> = {tag:"Just"; contents:T} | {tag: "Nothing"}
type HsMaybe<T> = T | null
type HsOverwrite<T> = {tag:"OverwriteWith"; contents:T} | {tag: "NoChange"}

type FocusMap = Map<LexemeId, SrcId>
Expand Down Expand Up @@ -249,8 +249,8 @@ function extendCellOutput(cell: Cell, outputs:HsRenderedOutput[]) {
break
case "RenderedError":
const [maybeSrcId, errMsg] = output.contents
if (maybeSrcId.tag == "Just") {
const node : TreeNode = cell.treeMap.get(maybeSrcId.contents) ?? oops()
if (maybeSrcId !== null) {
const node : TreeNode = cell.treeMap.get(maybeSrcId) ?? oops()
highlightTreeNode(false, node, "HighlightError")}
addErrResult(cell, errMsg)
break
Expand Down

0 comments on commit eb351ab

Please sign in to comment.