From a31e39620f39ea6f4c304ec4ed9f1847513d3c03 Mon Sep 17 00:00:00 2001
From: ozan
Date: Sun, 20 Sep 2020 13:07:57 +0900
Subject: [PATCH] test state save-resume on python

---
 python/artan/tests/test_filters.py            | 56 +++++++++++++++++++
 python/artan/tests/test_mixtures.py           |  2 +-
 .../artan/ml/filter/FilterParams.scala        |  2 +-
 3 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/python/artan/tests/test_filters.py b/python/artan/tests/test_filters.py
index 40566e5..bb7ff2b 100644
--- a/python/artan/tests/test_filters.py
+++ b/python/artan/tests/test_filters.py
@@ -180,6 +180,62 @@ def test_ols_equivalence(self):
         expected, _, _, _ = np.linalg.lstsq(features, y, rcond=None)
         np.testing.assert_array_almost_equal(state, expected.reshape(2), decimal=5)
 
+    def test_batch_save_and_resume(self):
+        n = 100
+        ts = np.arange(0, n)
+        zs = np.random.normal(0, 1, n) + ts
+
+        split_point = n//2
+        initial = zs[:split_point]
+        remaining = zs[split_point:]
+
+        filter = LinearKalmanFilter()\
+            .setMeasurementCol("measurement")\
+            .setInitialStateMean(
+                Vectors.dense([0.0, 0.0]))\
+            .setInitialStateCovariance(
+                Matrices.dense(2, 2, [1, 0, 0, 1]))\
+            .setProcessModel(
+                Matrices.dense(2, 2, [1, 0, 1, 1]))\
+            .setProcessNoise(
+                Matrices.dense(2, 2, [0.01, 0.0, 0.0, 0.01]))\
+            .setMeasurementNoise(
+                Matrices.dense(1, 1, [1]))\
+            .setMeasurementModel(
+                Matrices.dense(1, 2, [1, 0]))
+
+        initial_filter = filter.setInitialStateCovariance(
+            Matrices.dense(2, 2, [1000.0, 0.0, 0.0, 1000.0]))
+
+        def create_df(m):
+            return self.spark.createDataFrame(
+                [(Vectors.dense(m[i]), ) for i in range(len(m))],
+                ["measurement"])
+
+        initial_measurements = create_df(initial)
+
+        complete_measurements = create_df(zs)
+
+        initial_state = initial_filter.transform(initial_measurements)\
+            .filter(f"stateIndex == {len(initial)}")\
+            .select("stateKey", "state")
+
+        complete_state = initial_filter.transform(complete_measurements) \
+            .filter(f"stateIndex == {len(zs)}")\
+            .select("stateKey", "state")
+
+        restarted_filter = filter\
+            .setInitialStateDistributionCol("state")
+
+        remaining_measurements = create_df(remaining)\
+            .crossJoin(initial_state)
+
+        restarted_state = restarted_filter.transform(remaining_measurements)\
+            .filter(f"stateIndex == {n - split_point}")\
+            .select("stateKey", "state")
+
+        assert(restarted_state.collect() == complete_state.collect())
+
     def test_multiple_model_adaptive_filter(self):
         n = 100
         a = 0.27
diff --git a/python/artan/tests/test_mixtures.py b/python/artan/tests/test_mixtures.py
index 0d10202..0a95dbe 100644
--- a/python/artan/tests/test_mixtures.py
+++ b/python/artan/tests/test_mixtures.py
@@ -180,7 +180,7 @@ def test_batch_pmm(self):
         assert(mae_weights < 0.1)
         for i, dist in enumerate(mixture_model.distributions):
             mae_rate = _mae(dist.rate, self.rates[i])
-            assert(mae_rate < 2)
+            assert(mae_rate < 4)
 
     def test_persistance(self):
         pmm = PoissonMixture() \
diff --git a/src/main/scala/com/github/ozancicek/artan/ml/filter/FilterParams.scala b/src/main/scala/com/github/ozancicek/artan/ml/filter/FilterParams.scala
index 5dcbc63..843dc17 100644
--- a/src/main/scala/com/github/ozancicek/artan/ml/filter/FilterParams.scala
+++ b/src/main/scala/com/github/ozancicek/artan/ml/filter/FilterParams.scala
@@ -559,7 +559,7 @@ private[artan] trait HasInitialStateDistributionCol extends Params {
    */
   final val initialStateDistributionCol: Param[String] = new Param[String](
     this,
-    "initialStateCol",
+    "initialStateDistributionCol",
     "Column name for initial state distribution. It should be a struct column with mean and covariance fields" +
     "mean field should be vector, and covariance field should be matrix")
 