Commit 7c210a6

Multi-sample online least-squares (#74)

* added support for online linear regression
* added test
* Addressed reviewer's concern
* Refactor update method, add documentation
* Expand linear regression unit test
* Allow multi-sample updates from the update method
* Update tests for linear regression

Co-authored-by: Kenneth Odoh <[email protected]>
1 parent ccf6e35 commit 7c210a6

File tree

2 files changed: 55 additions, 23 deletions


numpy_ml/linear_models/lm.py

Lines changed: 43 additions & 18 deletions

@@ -32,16 +32,17 @@ def __init__(self, fit_intercept=True):
 
         self._is_fit = False
 
-    def update(self, x, y):
+    def update(self, X, y):
         r"""
-        Incrementally update the least-squares coefficients on a new example
-        via recursive least-squares (RLS) [1]_ .
+        Incrementally update the least-squares coefficients for a set of new
+        examples.
 
         Notes
         -----
-        The RLS algorithm [2]_ is used to efficiently update the regression
-        parameters as new examples become available. For a new example
-        :math:`(\mathbf{x}_{t+1}, \mathbf{y}_{t+1})`, the parameter updates are
+        The recursive least-squares algorithm [1]_ [2]_ is used to efficiently
+        update the regression parameters as new examples become available. For
+        a single new example :math:`(\mathbf{x}_{t+1}, \mathbf{y}_{t+1})`, the
+        parameter updates are
 
         .. math::
 
@@ -55,33 +56,41 @@ def update(self, x, y):
         :math:`\mathbf{X}_{1:t}` and :math:`\mathbf{Y}_{1:t}` are the set of
         examples observed from timestep 1 to *t*.
 
-        To perform the above update efficiently, the RLS algorithm makes use of
-        the Sherman-Morrison formula [3]_ to avoid re-inverting the covariance
-        matrix on each new update.
+        In the single-example case, the RLS algorithm uses the Sherman-Morrison
+        formula [3]_ to avoid re-inverting the covariance matrix on each new
+        update. In the multi-example case (i.e., where :math:`\mathbf{X}_{t+1}`
+        and :math:`\mathbf{y}_{t+1}` are matrices of `N` examples each), we use
+        the generalized Woodbury matrix identity [4]_ to update the inverse
+        covariance. This comes at a performance cost, but is still more
+        performant than doing multiple single-example updates if *N* is large.
 
         References
         ----------
         .. [1] Gauss, C. F. (1821) _Theoria combinationis observationum
           erroribus minimis obnoxiae_, Werke, 4. Gottinge
         .. [2] https://en.wikipedia.org/wiki/Recursive_least_squares_filter
         .. [3] https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula
+        .. [4] https://en.wikipedia.org/wiki/Woodbury_matrix_identity
 
         Parameters
         ----------
-        x : :py:class:`ndarray <numpy.ndarray>` of shape `(1, M)`
-            A single example of rank `M`
-        y : :py:class:`ndarray <numpy.ndarray>` of shape `(1, K)`
-            A `K`-dimensional target vector for the current example
+        X : :py:class:`ndarray <numpy.ndarray>` of shape `(N, M)`
+            A dataset consisting of `N` examples, each of dimension `M`
+        y : :py:class:`ndarray <numpy.ndarray>` of shape `(N, K)`
+            The targets for each of the `N` examples in `X`, where each target
+            has dimension `K`
         """
         if not self._is_fit:
             raise RuntimeError("You must call the `fit` method before calling `update`")
 
-        x, y = np.atleast_2d(x), np.atleast_2d(y)
-        beta, S_inv = self.beta, self.sigma_inv
+        X, y = np.atleast_2d(X), np.atleast_2d(y)
 
-        X1, Y1 = x.shape[0], y.shape[0]
-        err_str = f"First dimension of x and y must be 1, but got {X1} and {Y1}"
-        assert X1 == Y1 == 1, err_str
+        X1, Y1 = X.shape[0], y.shape[0]
+        self._update1D(X, y) if X1 == Y1 == 1 else self._update2D(X, y)
+
+    def _update1D(self, x, y):
+        """Sherman-Morrison update for a single example"""
+        beta, S_inv = self.beta, self.sigma_inv
 
         # convert x to a design vector if we're fitting an intercept
         if self.fit_intercept:
@@ -93,6 +102,22 @@ def update(self, x, y):
         # update the model coefficients
         beta += S_inv @ x.T @ (y - x @ beta)
 
+    def _update2D(self, X, y):
+        """Woodbury update for multiple examples"""
+        beta, S_inv = self.beta, self.sigma_inv
+
+        # convert X to a design matrix if we're fitting an intercept
+        if self.fit_intercept:
+            X = np.c_[np.ones(X.shape[0]), X]
+
+        I = np.eye(X.shape[0])
+
+        # update the inverse of the covariance matrix via Woodbury identity
+        S_inv -= S_inv @ X.T @ np.linalg.pinv(I + X @ S_inv @ X.T) @ X @ S_inv
+
+        # update the model coefficients
+        beta += S_inv @ X.T @ (y - X @ beta)
+
     def fit(self, X, y):
         """
         Fit the regression coefficients via maximum likelihood.
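
The docstring's `.. math::` block falls in the gap between the first two hunks, so the exact equations are not visible here. For orientation only, the textbook rank-one RLS update that the single-example path applies, consistent with the `beta += S_inv @ x.T @ (y - x @ beta)` line shown above and with :math:`\Sigma_t = \mathbf{X}_{1:t}^\top \mathbf{X}_{1:t}`, is the standard Sherman-Morrison form (not a quotation of the elided docstring):

% Standard rank-one Sherman-Morrison / RLS update (textbook form; the
% commit's own .. math:: block is elided between the hunks above).
\Sigma_{t+1}^{-1} = \Sigma_t^{-1}
  - \frac{\Sigma_t^{-1} \mathbf{x}_{t+1}^\top \mathbf{x}_{t+1} \Sigma_t^{-1}}
         {1 + \mathbf{x}_{t+1} \Sigma_t^{-1} \mathbf{x}_{t+1}^\top},
\qquad
\beta_{t+1} = \beta_t + \Sigma_{t+1}^{-1} \mathbf{x}_{t+1}^\top
  \left( \mathbf{y}_{t+1} - \mathbf{x}_{t+1} \beta_t \right)

where :math:`\mathbf{x}_{t+1}` is a `(1, M)` row vector, so the denominator is a scalar and no matrix inversion is needed.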
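
The `_update2D` hunk is fully visible, so the identity it relies on can be checked directly. Below is a minimal standalone sketch (not part of the commit; the shapes, seed, and variable names `X0`/`X1` are illustrative assumptions) verifying that the Woodbury-style inverse-covariance update reproduces a full batch refit:

# Standalone sketch (not from the commit): verify that the Woodbury-style
# inverse-covariance update used in `_update2D` reproduces a full batch
# least-squares solve. Names mirror the diff; shapes are illustrative.
import numpy as np

rng = np.random.RandomState(0)
M, K, N0, N1 = 3, 2, 20, 7  # feature dim, target dim, initial/update batch sizes

X0, Y0 = rng.randn(N0, M), rng.randn(N0, K)  # initial training batch
X1, Y1 = rng.randn(N1, M), rng.randn(N1, K)  # later update batch

# initial fit: S_inv = (X0^T X0)^{-1}, beta = S_inv X0^T Y0
S_inv = np.linalg.pinv(X0.T @ X0)
beta = S_inv @ X0.T @ Y0

# Woodbury update for the N1 new examples, exactly as in the diff
I = np.eye(N1)
S_inv -= S_inv @ X1.T @ np.linalg.pinv(I + X1 @ S_inv @ X1.T) @ X1 @ S_inv
beta += S_inv @ X1.T @ (Y1 - X1 @ beta)

# refit from scratch on all the data and compare
X_all, Y_all = np.vstack([X0, X1]), np.vstack([Y0, Y1])
beta_batch = np.linalg.pinv(X_all.T @ X_all) @ X_all.T @ Y_all
np.testing.assert_allclose(beta, beta_batch, atol=1e-8)
print("Woodbury update matches batch refit")

Note that the matrix passed to `pinv` is `N x N`, so its size tracks the update batch rather than the feature dimension, which is the performance cost the new docstring mentions.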

numpy_ml/tests/test_linear_regression.py

Lines changed: 12 additions & 5 deletions

@@ -1,3 +1,4 @@
+# flake8: noqa
 import numpy as np
 
 from sklearn.linear_model import LinearRegression as LinearRegressionGold
@@ -12,7 +13,7 @@ def test_linear_regression(N=10):
 
     i = 1
     while i < N + 1:
-        train_samples = np.random.randint(1, 30)
+        train_samples = np.random.randint(2, 30)
         update_samples = np.random.randint(1, 30)
         n_samples = train_samples + update_samples
 
@@ -37,15 +38,21 @@ def test_linear_regression(N=10):
         lr = LinearRegression(fit_intercept=fit_intercept)
         lr.fit(X_train, y_train)
 
+        do_single_sample_update = np.random.choice([True, False])
+
         # ...then update our model on the examples (X_update, y_update)
-        for x_new, y_new in zip(X_update, y_update):
-            lr.update(x_new, y_new)
+        if do_single_sample_update:
+            for x_new, y_new in zip(X_update, y_update):
+                lr.update(x_new, y_new)
+        else:
+            lr.update(X_update, y_update)
 
         # check that model predictions match
-        np.testing.assert_almost_equal(lr.predict(X), lr_gold.predict(X))
+        np.testing.assert_almost_equal(lr.predict(X), lr_gold.predict(X), decimal=5)
 
         # check that model coefficients match
         beta = lr.beta.T[:, 1:] if fit_intercept else lr.beta.T
-        np.testing.assert_almost_equal(beta, lr_gold.coef_)
+        np.testing.assert_almost_equal(beta, lr_gold.coef_, decimal=6)
+
         print("\tPASSED")
         i += 1
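
The test now randomly exercises both dispatch paths, which can also be mixed freely in user code. A hedged usage sketch of the new API (assuming `LinearRegression` is importable from `numpy_ml.linear_models`, which the test file implies but the hunks above do not show; the data is invented):

# Hedged usage sketch (not from the commit). Assumes `LinearRegression`
# is importable from numpy_ml.linear_models; the data is invented.
import numpy as np
from numpy_ml.linear_models import LinearRegression

rng = np.random.RandomState(42)
X, y = rng.randn(50, 4), rng.randn(50, 1)

lr = LinearRegression(fit_intercept=True)
lr.fit(X[:30], y[:30])         # initial batch fit
lr.update(X[30:31], y[30:31])  # one example: Sherman-Morrison path
lr.update(X[31:], y[31:])      # 19 examples at once: Woodbury path
preds = lr.predict(X)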
