From 46121eb50b45527843f4952bc1573f6a4af1f740 Mon Sep 17 00:00:00 2001 From: ozan Date: Sat, 25 Jan 2020 17:08:02 +0900 Subject: [PATCH] add rls ols equivalence test --- .gitignore | 3 +++ python/artan/tests/test_filters.py | 34 +++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 77bb9aa..aa4abab 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ *.iws *.log .idea +__pycache__ +*.pyc +.coverage out target/ dist/* diff --git a/python/artan/tests/test_filters.py b/python/artan/tests/test_filters.py index c74aadc..28c49e9 100644 --- a/python/artan/tests/test_filters.py +++ b/python/artan/tests/test_filters.py @@ -23,15 +23,47 @@ class RLSTests(ReusedSparkTestCase): - def test_rls(self): + np.random.seed(0) + + def test_simple_rls(self): df = self.spark.createDataFrame( [(1.0, Vectors.dense(0.0, 5.0)), (0.0, Vectors.dense(1.0, 2.0)), (1.0, Vectors.dense(2.0, 1.0)), (0.0, Vectors.dense(3.0, 3.0)), ], ["label", "features"]) + rls = RecursiveLeastSquaresFilter(2) + model = rls.transform(df).filter("stateIndex=4").collect() state = model[0].state.values + expected = np.array([5.31071176e-09, 1.53846148e-01]) + np.testing.assert_array_almost_equal(state, expected) + + def test_ols_equivalence(self): + # Simple ols problem + # y = a * x + b + r + # Where r ~ N(0, 1) + n = 40 + a = 0.5 + b = 2 + x = np.arange(0, n) + r = np.random.normal(0, 1, n) + y = a * x + b + r + features = x.reshape(n, 1) + features = np.concatenate([features, np.ones_like(features)], axis=1) + + df = self.spark.createDataFrame( + [(float(y[i]), Vectors.dense(features[i])) for i in range(n)], ["label", "features"]) + + # set high regularization matrix factor to get close to OLS solution + rls = RecursiveLeastSquaresFilter(2)\ + .setInitialEstimate(Vectors.dense([1.0, 1.0]))\ + .setRegularizationMatrixFactor(10E6) + + model = rls.transform(df) + state = model.filter("stateIndex = {}".format(n)).collect()[0].state.values + # Check equivalence with least squares solution with numpy + expected, _, _, _ = np.linalg.lstsq(features, y) np.testing.assert_array_almost_equal(state, expected)