Skip to content

Commit

Permalink
test state save-resume on python
Browse files Browse the repository at this point in the history
  • Loading branch information
ozancicek committed Sep 20, 2020
1 parent 0eab203 commit a31e396
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 2 deletions.
56 changes: 56 additions & 0 deletions python/artan/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,62 @@ def test_ols_equivalence(self):
expected, _, _, _ = np.linalg.lstsq(features, y, rcond=None)
np.testing.assert_array_almost_equal(state, expected.reshape(2), decimal=5)

def test_batch_save_and_resume(self):
n = 100
ts = np.arange(0, n)
zs = np.random.normal(0, 1, n) + ts

split_point = n//2
initial = zs[:split_point]
remaining = zs[split_point:]

filter = LinearKalmanFilter()\
.setMeasurementCol("measurement")\
.setInitialStateMean(
Vectors.dense([0.0, 0.0]))\
.setInitialStateCovariance(
Matrices.dense(2, 2, [1, 0, 0, 1]))\
.setProcessModel(
Matrices.dense(2, 2, [1, 0, 1, 1]))\
.setProcessNoise(
Matrices.dense(2, 2, [0.01, 0.0, 0.0, 0.01]))\
.setMeasurementNoise(
Matrices.dense(1, 1, [1]))\
.setMeasurementModel(
Matrices.dense(1, 2, [1, 0]))

initial_filter = filter.setInitialStateCovariance(
Matrices.dense(2, 2, [1000.0, 0.0, 0.0, 1000.0]))

def create_df(m):
return self.spark.createDataFrame(
[(Vectors.dense(m[i]), ) for i in range(len(m))],
["measurement"])

initial_measurements = create_df(initial)

complete_measurements = create_df(zs)

initial_state = initial_filter.transform(initial_measurements)\
.filter(f"stateIndex == {len(initial)}")\
.select("stateKey", "state")

complete_state = initial_filter.transform(complete_measurements) \
.filter(f"stateIndex == {len(zs)}")\
.select("stateKey", "state")

restarted_filter = filter\
.setInitialStateDistributionCol("state")

remaining_measurements = create_df(remaining)\
.crossJoin(initial_state)

restarted_state = restarted_filter.transform(remaining_measurements)\
.filter(f"stateIndex == {n - split_point}")\
.select("stateKey", "state")

assert(restarted_state.collect() == complete_state.collect())

def test_multiple_model_adaptive_filter(self):
n = 100
a = 0.27
Expand Down
2 changes: 1 addition & 1 deletion python/artan/tests/test_mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_batch_pmm(self):
assert(mae_weights < 0.1)
for i, dist in enumerate(mixture_model.distributions):
mae_rate = _mae(dist.rate, self.rates[i])
assert(mae_rate < 2)
assert(mae_rate < 4)

def test_persistance(self):
pmm = PoissonMixture() \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ private[artan] trait HasInitialStateDistributionCol extends Params {
*/
final val initialStateDistributionCol: Param[String] = new Param[String](
this,
"initialStateCol",
"initialStateDistributionCol",
"Column name for initial state distribution. It should be a struct column with mean and covariance fields" +
"mean field should be vector, and covariance field should be matrix")

Expand Down

0 comments on commit a31e396

Please sign in to comment.