
Commit af8dbd1

fix: spellings to pass spell check
Includes renaming some variables to match updated contributor guidelines.
1 parent f6cc615 commit af8dbd1

18 files changed: +80 −80 lines changed

CODE_OF_CONDUCT.md

+1 −1

```diff
@@ -35,7 +35,7 @@ Project maintainers have the right and responsibility to remove, edit, or reject
 Enforcement
 -----------
 
-Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible
+Instances of abusive, harassing, or otherwise unacceptable behaviour may be reported to the community leaders responsible
 for enforcement at [[email protected]](mailto:[email protected]). All complaints will be reviewed and investigated promptly and fairly, and will
 result in a response that is deemed necessary and appropriate to the circumstances. The community leaders responsible
 for enforcement are obligated to maintain confidentiality with regard to the reporter of an incident.
```

README.md

+2 −2

````diff
@@ -11,7 +11,7 @@ Coreax is a library for **coreset algorithms**, written in [Jax](https://jax.rea
 A coreset algorithm takes a $n \times d$ data set and reduces it to $m \ll n$ points whilst attempting to preserve the statistical properties of the full data set. Some algorithms return the $m$ points with weights, such that importance can be attributed to each point. These are often chosen from the simplex, i.e. such that they are non-negative and sum to 1.
 
 ## Quick example
-Here are $n=10,000$ points drawn from six $2$-D Gaussians. The coreset size, which we set, is $m=100$. Run `examples/weighted_herding.py` to replicate.
+Here are $n=10,000$ points drawn from six $2$-D Gaussian distributions. The coreset size, which we set, is $m=100$. Run `examples/weighted_herding.py` to replicate.
 
 ![](examples/data/coreset_seq/coreset_seq.gif)
 ![](examples/data/random_seq/random_seq.gif)
@@ -82,7 +82,7 @@ coreset = kernel_herding_refine_block(X, m, k):
 ```
 
 ## Stein kernel herding
-We have implemented a version of kernel herding that uses a **Stein kernel**, which targets [kernelised Stein discrepancy (KSD)](https://arxiv.org/abs/1602.03253) rather than MMD. This can often give better integration error in practice, but it can be slower than using a simpler kernel targeting MMD. To use Stein kernel herding, we have to define a continuous approximation to the discerete measure, e.g. using a KDE, or estimate the score function $\nabla \log f_X(\mathbf{x})$ of a continuous PDF from a finite set of samples. In this example, we use a Stein kernel with an inverse multi-quadric base kernel; computing the score function explicitly (score matching coming soon). Again, there are block versions for fitting within GPU memory constraints.
+We have implemented a version of kernel herding that uses a **Stein kernel**, which targets [kernelised Stein discrepancy (KSD)](https://arxiv.org/abs/1602.03253) rather than MMD. This can often give better integration error in practice, but it can be slower than using a simpler kernel targeting MMD. To use Stein kernel herding, we have to define a continuous approximation to the discrete measure, e.g. using a KDE, or estimate the score function $\nabla \log f_X(\mathbf{x})$ of a continuous PDF from a finite set of samples. In this example, we use a Stein kernel with an inverse multi-quadric base kernel; computing the score function explicitly (score matching coming soon). Again, there are block versions for fitting within GPU memory constraints.
 ```python
 from coreax.kernel import stein_kernel_pc_imq_element, rbf_grad_log_f_x
 from coreax.kernel_herding import stein_kernel_herding_block
````
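For context on what these imports construct: the README paragraph describes applying the Langevin Stein operator to both arguments of an inverse multi-quadric base kernel. A self-contained sketch of that standard construction follows; it illustrates the technique only, and coreax's own `stein_kernel_pc_imq_element` (preconditioned) may differ in scaling and signature.

```python
import jax
import jax.numpy as jnp

def imq(x, y, bandwidth=1.0):
    # Inverse multi-quadric base kernel.
    return 1.0 / jnp.sqrt(1.0 + jnp.sum((x - y) ** 2) / bandwidth**2)

def stein_kernel_element(x, y, score, bandwidth=1.0):
    # Standard Langevin Stein kernel built from a base kernel k and a
    # score function s: k_0(x, y) = div_x div_y k + s(x).grad_y k
    #                   + s(y).grad_x k + s(x).s(y) k.
    k = lambda a, b: imq(a, b, bandwidth)
    grad_x = jax.grad(k, argnums=0)
    grad_y = jax.grad(k, argnums=1)
    div_term = jnp.trace(jax.jacfwd(grad_y, argnums=0)(x, y))
    return (
        div_term
        + score(x) @ grad_y(x, y)
        + score(y) @ grad_x(x, y)
        + score(x) @ score(y) * k(x, y)
    )

# Example: the score of a standard Gaussian is simply -x.
x, y = jnp.array([0.5, -0.2]), jnp.array([1.0, 0.3])
val = stein_kernel_element(x, y, score=lambda x: -x)
```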

coreax/approximation.py

+1 −1

```diff
@@ -283,7 +283,7 @@ def anchor_body(
     Execute main loop of the ANNchor construction.
 
     :param idx: Loop counter
-    :param features: Loop updateables
+    :param features: Loop variables to be updated
     :param data: Original :math:`n \times d` dataset
     :param kernel_function: Vectorised kernel function on pairs `(X,x)`:
         :math:`k: \mathbb{R}^{n \times d} \times \mathbb{R}^d \rightarrow \mathbb{R}^n`
```

coreax/kernel.py

+3 −3

```diff
@@ -25,7 +25,7 @@
 from coreax.util import (
     KernelFunction,
     KernelFunctionWithGrads,
-    pdiff,
+    pairwise_diff,
     sq_dist,
     sq_dist_pairwise,
 )
@@ -115,7 +115,7 @@ def grad_rbf_y(
     else:
         gram_matrix = jnp.asarray(gram_matrix)
 
-    distances = pdiff(x_array, y_array)
+    distances = pairwise_diff(x_array, y_array)
 
     return distances * gram_matrix[:, :, None] / bandwidth**2
 
@@ -164,7 +164,7 @@ def grad_pc_imq_y(
         gram_matrix = pc_imq(x_array, y_array, bandwidth)
     else:
         gram_matrix = jnp.asarray(gram_matrix)
-    mq_array = pdiff(x_array, y_array)
+    mq_array = pairwise_diff(x_array, y_array)
 
     return gram_matrix[:, :, None] ** 3 * mq_array / scaling
 
```
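As an aside on what the `grad_rbf_y` hunk computes: for an RBF kernel with bandwidth $\nu$, $k(\mathbf{x}, \mathbf{y}) = \exp(-\|\mathbf{x} - \mathbf{y}\|^2 / (2\nu^2))$, the gradient with respect to $\mathbf{y}$ is

$$\nabla_{\mathbf{y}} k(\mathbf{x}, \mathbf{y}) = \frac{\mathbf{x} - \mathbf{y}}{\nu^2} \, k(\mathbf{x}, \mathbf{y}),$$

which matches the `distances * gram_matrix[:, :, None] / bandwidth**2` expression above, assuming `pairwise_diff` returns $\mathbf{x}_i - \mathbf{y}_j$ for each pair and the kernel follows this common convention.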

coreax/kernel_herding.py

+11 −11

```diff
@@ -45,7 +45,7 @@ def greedy_body(
     Execute main loop of greedy kernel herding.
 
     :param i: Loop counter
-    :param val: Loop updatables
+    :param val: Loop variables to be updated
     :param X: Original :math:`n \times d` dataset
     :param k_vec: Vectorised kernel function on pairs `(X,x)`:
         :math:`k: \mathbb{R}^{n \times d} \times \mathbb{R}^d \rightarrow \mathbb{R}^n`
@@ -84,7 +84,7 @@ def stein_greedy_body(
     Execute the main loop of greedy Stein herding.
 
     :param i: Loop counter
-    :param val: Loop updatables
+    :param val: Loop variables to be updated
     :param X: Original :math:`n \times d` dataset
     :param k_vec: Vectorised kernel function on pairs ``(X,x,Y,y)``:
         :math:`k: \mathbb{R}^{n \times d} \times \mathbb{R}^d \times`
@@ -147,7 +147,7 @@ def kernel_herding_block(
     S = jnp.zeros(n_core, dtype=jnp.int32)
     K = jnp.zeros((n_core, n))
 
-    # Greedly select coreset points
+    # Greedily select coreset points
     body = partial(greedy_body, X=X, k_vec=k_vec, K_mean=K_mean, unique=unique)
     S, K, _ = lax.fori_loop(0, n_core, body, (S, K, K_t))
     Kbar = K.mean(axis=1)
@@ -208,7 +208,7 @@ def stein_kernel_herding_block(
     S = jnp.zeros(n_core, dtype=jnp.int32)
     K = jnp.zeros((n_core, n))
 
-    # Greedly select coreset points
+    # Greedily select coreset points
     body = partial(
         stein_greedy_body,
         X=X,
@@ -226,7 +226,7 @@ def stein_kernel_herding_block(
 
 
 @jit
-def fw_linesearch(arg_x_t: int, K: ArrayLike, Ek: ArrayLike) -> Array:
+def fw_line_search(arg_x_t: int, K: ArrayLike, Ek: ArrayLike) -> Array:
     r"""
     Execute Frank-Wolfe line search.
 
@@ -254,7 +254,7 @@ def herding_body(
     Execute body of default herding.
 
     :param i: Loop counter
-    :param val: Loop updatables
+    :param val: Loop variables to be updated
     :return: Coreset indices, objective, Gram matrix mean and Gram matrix
     """
     S, objective, Kbar, K = val
@@ -277,7 +277,7 @@ def greedy_herding_body(
     Execute body of Stein thinning.
 
     :param i: Loop counter
-    :param val: Loop updatables
+    :param val: Loop variables to be updated
     :return: Coreset indices, objective, Gram matrix mean and Gram matrix
     """
     S, objective, Kbar, K = val
@@ -300,7 +300,7 @@ def fw_herding_body(
     Execute body of Frank-Wolfe herding.
 
     :param i: Loop counter
-    :param val: Loop updatables
+    :param val: Loop variables to be updated
     :return: Coreset indices, objective, Gram matrix mean and Gram matrix
     """
     S, objective, Kbar, K = val
@@ -310,7 +310,7 @@ def fw_herding_body(
     K = jnp.asarray(K)
     j = objective.argmax()
     S = S.at[i].set(j)
-    rho = fw_linesearch(S[i], K, Kbar)
+    rho = fw_line_search(S[i], K, Kbar)
     objective = objective * (1 - rho) + (Kbar - K[S[i]]) * rho
     return S, objective, Kbar, K
 
```
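On the renamed `fw_line_search`: for a quadratic objective the exact Frank-Wolfe step has a closed form. A minimal standalone sketch, assuming the weighted-MMD objective $f(\mathbf{w}) = \tfrac{1}{2}\mathbf{w}^\top K \mathbf{w} - \mathbf{w}^\top \bar{\mathbf{z}}$ (up to constants); this illustrates the general technique and is not necessarily what coreax implements:

```python
import jax.numpy as jnp

def exact_fw_step(w, j, K, z_bar):
    """Exact line search along the Frank-Wolfe direction d = e_j - w
    for f(w) = 0.5 * w @ K @ w - w @ z_bar, clipped to [0, 1]."""
    d = (-w).at[j].add(1.0)       # d = e_j - w
    g = K @ w - z_bar             # gradient of f at w
    rho = -(g @ d) / (d @ K @ d)  # unconstrained minimiser along d
    return jnp.clip(rho, 0.0, 1.0)
```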

```diff
@@ -408,8 +408,8 @@ def scalable_herding(
     else:
         # build a kdtree
         kdtree = KDTree(X, leaf_size=size)
-        _, nindices, nodes, _ = kdtree.get_arrays()
-        new_indices = [jnp.array(nindices[nd[0] : nd[1]]) for nd in nodes if nd[2]]
+        _, node_indices, nodes, _ = kdtree.get_arrays()
+        new_indices = [jnp.array(node_indices[nd[0] : nd[1]]) for nd in nodes if nd[2]]
         split_data = [X[n] for n in new_indices]
         # k = len(split_data)
         # print(n, k, n // k)
```
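For readers puzzling over `nd[0]`, `nd[1]`, `nd[2]`: with scikit-learn's `KDTree`, `get_arrays()` returns a node record array whose entries hold `(idx_start, idx_end, is_leaf, radius)`, so the comprehension collects the index slice of every leaf node. A standalone sketch of that partitioning, assuming the `KDTree` here is `sklearn.neighbors.KDTree`:

```python
import numpy as np
from sklearn.neighbors import KDTree

X = np.random.default_rng(0).normal(size=(10_000, 2))
kdtree = KDTree(X, leaf_size=500)
_, node_indices, nodes, _ = kdtree.get_arrays()

# Each record holds (idx_start, idx_end, is_leaf, radius); the leaf
# slices of node_indices partition the dataset into disjoint blocks.
leaves = [node_indices[nd[0] : nd[1]] for nd in nodes if nd[2]]
assert sum(len(leaf) for leaf in leaves) == len(X)
```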

coreax/refine.py

+7 −7

```diff
@@ -88,7 +88,7 @@ def refine_body(
     Execute main loop of the refine method, :math:`S \rightarrow x`.
 
     :param i: Loop counter
-    :param S: Loop updatables
+    :param S: Loop variables to be updated
     :param x: Original :math:`n \times d` dataset
     :param K_mean: Mean vector over rows for the Gram matrix, a :math:`1 \times n` array
     :param K_diag: Gram matrix diagonal, a :math:`1 \times n` array
@@ -209,7 +209,7 @@ def refine_rand_body(
     Execute main loop of the random refine method.
 
     :param i: Loop counter
-    :param val: Loop updatables
+    :param val: Loop variables to be updated
     :param x: Original :math:`n \times d` dataset
     :param n_cand: Number of candidates for comparison
     :param K_mean: Mean vector over rows for the Gram matrix, a :math:`1 \times n` array
@@ -228,7 +228,7 @@ def refine_rand_body(
     cand = random.randint(subkey, (n_cand,), 0, len(x))
     # cand = random.choice(subkey, len(x), (n_cand,), replace=False)
     comps = comparison_cand(S[i], cand, S, x, K_mean, K_diag, k_pairwise, k_vec)
-    S = lax.cond(jnp.any(comps > 0), change, nochange, i, S, cand, comps)
+    S = lax.cond(jnp.any(comps > 0), change, no_change, i, S, cand, comps)
 
     return key, S
 
@@ -296,7 +296,7 @@ def change(i: int, S: ArrayLike, cand: ArrayLike, comps: ArrayLike) -> Array:
 
 
 @jit
-def nochange(i: int, S: ArrayLike, cand: ArrayLike, comps: ArrayLike) -> Array:
+def no_change(i: int, S: ArrayLike, cand: ArrayLike, comps: ArrayLike) -> Array:
     r"""
     Leave ``S`` unchanged.
 
@@ -373,7 +373,7 @@ def refine_rev_body(
     Execute main loop of the refine method, :math:`x \rightarrow S`.
 
     :param i: Loop counter
-    :param S: Loop updatables
+    :param S: Loop variables to be updated
     :param x: Original :math:`n \times d` dataset
     :param K_mean: Mean vector over rows for the Gram matrix, a :math:`1 \times n` array
     :param K_diag: Gram matrix diagonal, a :math:`1 \times n` array
@@ -384,7 +384,7 @@ def refine_rev_body(
     :returns: Updated loop variables ``S``
     """
     comps = comparison_rev(i, S, x, K_mean, K_diag, k_pairwise, k_vec)
-    S = lax.cond(jnp.any(comps > 0), change_rev, nochange_rev, i, S, comps)
+    S = lax.cond(jnp.any(comps > 0), change_rev, no_change_rev, i, S, comps)
 
     return S
 
@@ -447,7 +447,7 @@ def change_rev(i: int, S: ArrayLike, comps: ArrayLike) -> Array:
 
 
 @jit
-def nochange_rev(i: int, S: ArrayLike, comps: ArrayLike) -> Array:
+def no_change_rev(i: int, S: ArrayLike, comps: ArrayLike) -> Array:
     r"""
     Leave ``S`` unchanged.
 
```
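The `change`/`no_change` pairing exists because `jax.lax.cond` traces both branches and requires them to share a signature; the identity branch is what makes the conditional JIT-compatible. A minimal self-contained illustration (the toy update below is illustrative, not coreax's refine logic):

```python
import jax.numpy as jnp
from jax import lax

def change(i, S):
    return S.at[i].set(-1)

def no_change(i, S):
    # Same signature as `change`, but leaves S untouched.
    return S

S = jnp.arange(5)
# Both branches are traced; the predicate picks one at run time.
S = lax.cond(jnp.any(S > 3), change, no_change, 2, S)
```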

coreax/score_matching.py

+5 −5

```diff
@@ -159,7 +159,7 @@ def noise_conditional_loop_body(
     Sum objective function with noise perturbations.
 
     Inputs are perturbed by Gaussian random noise to improve performance of score
-    matching. See :cite:p:`improvedsgm` for details.
+    matching. See :cite:p:`improved_sgm` for details.
 
     :param i: Loop index
     :param obj: Running objective, i.e. the current partial sum
@@ -226,7 +226,7 @@ def loss(params):
 
 def sliced_score_matching(
     X: ArrayLike,
-    rgenerator: Callable,
+    rand_generator: Callable,
     noise_conditioning: bool = True,
     use_analytic: bool = False,
     M: int = 1,
@@ -246,7 +246,7 @@ def sliced_score_matching(
     the score function. Alternative network architectures can be considered.
 
     :param X: The :math:`n \times d` data vectors
-    :param rgenerator: Distribution sampler (key, shape, dtype) :math:`\rightarrow`
+    :param rand_generator: Distribution sampler (key, shape, dtype) :math:`\rightarrow`
         :class:`~jax.Array`, e.g. distributions in :class:`~jax.random`
     :param noise_conditioning: Use the noise conditioning version of score matching,
         defaults to True
@@ -255,7 +255,7 @@ def sliced_score_matching(
     :param M: The number of random vectors to use per data vector, defaults to 1
     :param lr: Optimiser learning rate, defaults to 1e-3
     :param epochs: Epochs for training, defaults to 10
-    :param batch_size: Size of minibatch, defaults to 64
+    :param batch_size: Size of mini-batch, defaults to 64
     :param hidden_dim: The ScoreNetwork hidden dimension, defaults to 128
     :param optimiser: The optax optimiser to use, defaults to :func:`~optax.adamw`
     :param L: Number of noise models to use in noise conditional score matching,
@@ -280,7 +280,7 @@ def sliced_score_matching(
 
     # random vector setup
     k1, k2 = random.split(random.PRNGKey(0))
-    V = rgenerator(k1, (n, M, d), dtype=float)
+    V = rand_generator(k1, (n, M, d), dtype=float)
 
     # training setup
     state = create_train_state(sn, k2, lr, d, optimiser)
```
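Given the renamed parameter's documented convention, any `jax.random` sampler with a `(key, shape, dtype)` signature fits. A usage sketch using only arguments shown in the docstring above; the data `X` is illustrative, and the return value (assumed here to be a trained score approximation) is an assumption, not a confirmed part of the API:

```python
import numpy as np
from jax import random
from coreax.score_matching import sliced_score_matching

X = np.random.default_rng(0).normal(size=(500, 2))  # illustrative data

# random.rademacher and random.normal both follow the documented
# (key, shape, dtype) convention for rand_generator.
learned_score = sliced_score_matching(
    X, random.rademacher, noise_conditioning=True, M=1, epochs=10
)
```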

coreax/util.py

+1 −1

```diff
@@ -97,7 +97,7 @@ def diff(x: ArrayLike, y: ArrayLike) -> Array:
 
 
 @jit
-def pdiff(x_array: ArrayLike, y_array: ArrayLike) -> Array:
+def pairwise_diff(x_array: ArrayLike, y_array: ArrayLike) -> Array:
     r"""
     Calculate efficient pairwise difference between two arrays of vectors.
 
```
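To make the rename concrete, a minimal re-implementation sketch of what a pairwise difference helper computes, consistent with its use in `grad_rbf_y` above (an `(n, m, d)` array of $\mathbf{x}_i - \mathbf{y}_j$); coreax's actual implementation may differ:

```python
import jax.numpy as jnp
from jax import jit, vmap

@jit
def pairwise_diff_sketch(x_array, y_array):
    # Outer vmap over rows of x, inner over rows of y: output (n, m, d).
    return vmap(lambda x: vmap(lambda y: x - y)(y_array))(x_array)

x = jnp.arange(6.0).reshape(3, 2)
y = jnp.arange(4.0).reshape(2, 2)
assert pairwise_diff_sketch(x, y).shape == (3, 2, 2)
```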

coreax/weights.py

+1 −1

```diff
@@ -30,7 +30,7 @@ def calculate_BQ_weights(
     Calculate weights from Sequential Bayesian Quadrature (SBQ).
 
     References for this technique can be found in
-    :cite:p:`huszar2016optimallyweighted`. These are equivalent to the unconstrained
+    :cite:p:`huszar2016optimally`. These are equivalent to the unconstrained
     weighted maximum mean discrepancy (MMD) optimum.
 
     :param x: The original :math:`n \times d` data
```
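The "unconstrained weighted MMD optimum" the docstring mentions has a closed form: with coreset Gram matrix $K_{mm}$ and vector $\mathbf{z}$ of mean kernel evaluations between each coreset point and the full data, the optimal weights solve $K_{mm}\mathbf{w} = \mathbf{z}$. A minimal sketch of that fact, not coreax's `calculate_BQ_weights` itself:

```python
import jax.numpy as jnp

def bq_weights_sketch(K_mm, z):
    # Unconstrained minimiser of the weighted MMD: w = K_mm^{-1} z.
    return jnp.linalg.solve(K_mm, z)
```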

documentation/source/references.bib

+2 −2

```diff
@@ -7,7 +7,7 @@ @misc{chatalic2022nystrom
   primaryClass={{stat.ML}},
 }
 
-@inproceedings{improvedsgm,
+@inproceedings{improved_sgm,
   title={Improved techniques for training score-based generative models},
   author={Song, Yang and Ermon, Stefano},
   booktitle={{Advances in Neural Information Processing Systems}},
@@ -25,7 +25,7 @@ @inproceedings{ssm
   organization={PMLR}
 }
 
-@misc{huszar2016optimallyweighted,
+@misc{huszar2016optimally,
   title={{Optimally-Weighted Herding is Bayesian Quadrature}},
   author={Huszar, Ferenc and Duvenaud, David},
   year={2016},
```

examples/david.py

+5 −5

```diff
@@ -103,7 +103,7 @@ def main(
 
     print("Choosing random subset...")
     # choose a random subset of C points from the original image
-    rpoints = np.random.choice(n, C, replace=False)
+    rand_points = np.random.choice(n, C, replace=False)
 
     # define a reference kernel to use for comparisons of MMD. We'll use an RBF
     def k(x, y):
@@ -113,7 +113,7 @@ def k(x, y):
     m = mmd_block(X, X[coreset], k, max_size=1000)
 
     # compute the MMD between X and the random sample
-    rm = mmd_block(X, X[rpoints], k, max_size=1000).item()
+    rm = mmd_block(X, X[rand_points], k, max_size=1000).item()
 
     # print the MMDs
     print("Random MMD")
@@ -148,9 +148,9 @@ def k(x, y):
     # plot the image of randomly sampled points
     plt.subplot(1, 3, 3)
     plt.scatter(
-        X[rpoints, 1],
-        -X[rpoints, 0],
-        c=X[rpoints, 2],
+        X[rand_points, 1],
+        -X[rand_points, 0],
+        c=X[rand_points, 2],
         s=1.0,
         cmap="gray",
         marker="h",
```

examples/pounce.py

+2 −2

```diff
@@ -90,9 +90,9 @@ def k(x, y):
     m = mmd_block(X, X[coreset], k, max_size=1000)
 
     # get a random sample of points to compare against
-    rsample = np.random.choice(N, size=C, replace=False)
+    rand_sample = np.random.choice(N, size=C, replace=False)
     # compute the MMD between X and the random sample
-    rm = mmd_block(X, X[rsample], k, max_size=1000).item()
+    rm = mmd_block(X, X[rand_sample], k, max_size=1000).item()
 
     # print the MMDs
     print(f"Random MMD: {rm}")
```

examples/pounce_sm.py

+3 −3

```diff
@@ -32,7 +32,7 @@ def main(directory: Path = Path("../examples/data/pounce")) -> tuple[float, floa
     Run the 'pounce' example for video sampling with score matching.
 
     Take a video of a pouncing cat, apply PCA and then generate a coreset using
-    score matching, in which we train a neural network to approximate the score functon
+    score matching, in which we train a neural network to approximate the score function
     of the underlying distribution. Compare the result from this to a coreset generated
     via uniform random sampling. Coreset quality is measured using maximum mean
     discrepancy (MMD).
@@ -90,9 +90,9 @@ def k(x, y):
     m = mmd_block(X, X[coreset], k, max_size=1000)
 
     # get a random sample of points to compare against
-    rsample = np.random.choice(N, size=C, replace=False)
+    rand_sample = np.random.choice(N, size=C, replace=False)
     # compute the MMD between X and the random sample
-    rm = mmd_block(X, X[rsample], k, max_size=1000).item()
+    rm = mmd_block(X, X[rand_sample], k, max_size=1000).item()
 
     # print the MMDs
     print(f"Random MMD: {rm}")
```

examples/weighted_herding.py

+3 −3

```diff
@@ -81,7 +81,7 @@ def k(x, y):
     )
 
     # get a random sample of points to compare against
-    rsample = np.random.choice(N, size=C, replace=False)
+    rand_sample = np.random.choice(N, size=C, replace=False)
 
     # the weighted bool turns the coreset weights on or off. If on, a quadratic program
     # is invoked to solve the weights' vector. This buys some increase in integration
@@ -103,7 +103,7 @@ def k(x, y):
     m = m.item()
 
     # compute the MMD between X and the random sample
-    rm = mmd_block(X, X[rsample], k, max_size=1000).item()
+    rm = mmd_block(X, X[rand_sample], k, max_size=1000).item()
 
     # nudge the weights to avoid negative entries for plotting
     if weights.min() < 0:
@@ -117,7 +117,7 @@ def k(x, y):
         plt.show()
 
     plt.scatter(X[:, 0], X[:, 1], s=2.0, alpha=0.1)
-    plt.scatter(X[rsample, 0], X[rsample, 1], s=10, color="red")
+    plt.scatter(X[rand_sample, 0], X[rand_sample, 1], s=10, color="red")
     plt.title("Random, m=%d, MMD=%.6f" % (C, rm))
     plt.axis("off")
 
```