Skip to content

Commit

Permalink
add rls ols equivalence test
Browse files Browse the repository at this point in the history
  • Loading branch information
ozancicek committed Jan 25, 2020
1 parent 499e9ac commit 46121eb
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
*.iws
*.log
.idea
__pycache__
*.pyc
.coverage
out
target/
dist/*
Expand Down
34 changes: 33 additions & 1 deletion python/artan/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,47 @@

class RLSTests(ReusedSparkTestCase):

def test_rls(self):
np.random.seed(0)

def test_simple_rls(self):
df = self.spark.createDataFrame(
[(1.0, Vectors.dense(0.0, 5.0)),
(0.0, Vectors.dense(1.0, 2.0)),
(1.0, Vectors.dense(2.0, 1.0)),
(0.0, Vectors.dense(3.0, 3.0)), ], ["label", "features"])

rls = RecursiveLeastSquaresFilter(2)

model = rls.transform(df).filter("stateIndex=4").collect()
state = model[0].state.values

expected = np.array([5.31071176e-09, 1.53846148e-01])
np.testing.assert_array_almost_equal(state, expected)

def test_ols_equivalence(self):
# Simple ols problem
# y = a * x + b + r
# Where r ~ N(0, 1)
n = 40
a = 0.5
b = 2
x = np.arange(0, n)
r = np.random.normal(0, 1, n)
y = a * x + b + r
features = x.reshape(n, 1)
features = np.concatenate([features, np.ones_like(features)], axis=1)

df = self.spark.createDataFrame(
[(float(y[i]), Vectors.dense(features[i])) for i in range(n)], ["label", "features"])

# set high regularization matrix factor to get close to OLS solution
rls = RecursiveLeastSquaresFilter(2)\
.setInitialEstimate(Vectors.dense([1.0, 1.0]))\
.setRegularizationMatrixFactor(10E6)

model = rls.transform(df)
state = model.filter("stateIndex = {}".format(n)).collect()[0].state.values

# Check equivalence with least squares solution with numpy
expected, _, _, _ = np.linalg.lstsq(features, y)
np.testing.assert_array_almost_equal(state, expected)

0 comments on commit 46121eb

Please sign in to comment.