Skip to content

Commit

Permalink
Add LinkingWindowObservationFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
moeyensj committed Oct 19, 2023
1 parent 8f3cca3 commit 5e3996b
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 2 deletions.
51 changes: 51 additions & 0 deletions thor/observations/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Iterator

import numpy as np
import pyarrow.compute as pc
import quivr as qv
from adam_core.observations import PointSourceDetections

Expand Down Expand Up @@ -39,6 +40,56 @@ def apply(
...


class LinkingWindowObservationFilter(ObservationFilter):
"""A LinkingWindowObservationFilter reduces the number of observations by
yield observations within a paricular linking window.
"""

def __init__(self, length: int, step: int = 1):
"""
Parameters
----------
length : int
The length of the linking window in days.
step : int, optional
The step size of the linking window in days.
"""
self.length = length
self.step = step

if self.step > self.length:
raise ValueError("step must be less than or equal to length")

def apply(self, observations, test_orbit) -> Iterator["Observations"]:
"""
Apply the filter to a collection of observations.
Parameters
----------
observations : `~thor.observations.Observations`
The observations to filter.
test_orbit : `~thor.orbit.TestOrbit`
The test orbit to use for filtering. Unused for this
filter.
Returns
-------
filtered_observations : `~thor.observations.Observations`
The filtered observations.
"""
times = observations.detections.time.rescale("utc")
days = times.days.unique().sort()
min_day = days[0].as_py()
max_day = days[-1].as_py()
for day in range(min_day, max_day + 1, self.step):
mask = pc.and_(
pc.greater_equal(times.days, day),
pc.less(times.days, day + self.length),
)
yield observations.apply_mask(mask)


class TestOrbitRadiusObservationFilter(ObservationFilter):
"""A TestOrbitRadiusObservationFilter is an ObservationFilter that
gathers observations within a fixed radius of the test orbit's
Expand Down
43 changes: 41 additions & 2 deletions thor/observations/tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import astropy.time
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
Expand All @@ -11,7 +10,7 @@
from adam_core.time import Timestamp

from ...orbit import TestOrbit
from ..filters import TestOrbitRadiusObservationFilter
from ..filters import LinkingWindowObservationFilter, TestOrbitRadiusObservationFilter
from ..observations import Observations


Expand Down Expand Up @@ -130,3 +129,43 @@ def test_orbit_radius_observation_filter(fixed_test_orbit, fixed_observations):
# Should be about pi/4 fraction of the detections (0.785
assert len(have.detections) < 0.80 * len(fixed_observations.detections)
assert len(have.detections) > 0.76 * len(fixed_observations.detections)


def assert_min_max_time(observations, min, max):
assert pc.min(observations.detections.time.days).as_py() == min
assert pc.max(observations.detections.time.days).as_py() == max


def test_LinkingWindowObservationFilter(fixed_test_orbit, fixed_observations):
# Create a LinkingWindowObservationFilter with a window length of 5 days and a step of 5 days
# There should be only one window that covers all the observations
filter = LinkingWindowObservationFilter(length=5, step=5)
filtered_observations = list(filter.apply(fixed_observations, fixed_test_orbit))
assert len(filtered_observations) == 1
assert len(filtered_observations[0]) == 50000
assert_min_max_time(filtered_observations[0], 58849, 58853)

# Create a LinkingWindowObservationFilter with a window length of 1 day and a step of 1 day
# There should be 5 windows that cover all the observations
filter = LinkingWindowObservationFilter(length=1, step=1)
filtered_observations = list(filter.apply(fixed_observations, fixed_test_orbit))
assert len(filtered_observations) == 5
for i in range(5):
assert len(filtered_observations[i]) == 10000
assert_min_max_time(filtered_observations[i], 58849 + i, 58849 + i)

# Create a LinkingWindowObservationFilter with a window length of 3 day and a step of 1 days
# There should be 5 windows that cover all the observations
filter = LinkingWindowObservationFilter(length=3, step=1)
filtered_observations = list(filter.apply(fixed_observations, fixed_test_orbit))
assert len(filtered_observations) == 5
assert len(filtered_observations[0]) == 30000
assert_min_max_time(filtered_observations[0], 58849, 58851)
assert len(filtered_observations[1]) == 30000
assert_min_max_time(filtered_observations[1], 58850, 58852)
assert len(filtered_observations[2]) == 30000
assert_min_max_time(filtered_observations[2], 58851, 58853)
assert len(filtered_observations[3]) == 20000
assert_min_max_time(filtered_observations[3], 58852, 58853)
assert len(filtered_observations[4]) == 10000
assert_min_max_time(filtered_observations[4], 58853, 58853)

0 comments on commit 5e3996b

Please sign in to comment.