diff --git a/thor/observations/filters.py b/thor/observations/filters.py index ea9d9f2d..298bb4ef 100644 --- a/thor/observations/filters.py +++ b/thor/observations/filters.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Iterator import numpy as np +import pyarrow.compute as pc import quivr as qv from adam_core.observations import PointSourceDetections @@ -39,6 +40,56 @@ def apply( ... +class LinkingWindowObservationFilter(ObservationFilter): + """A LinkingWindowObservationFilter reduces the number of observations by + yield observations within a paricular linking window. + + """ + + def __init__(self, length: int, step: int = 1): + """ + Parameters + ---------- + length : int + The length of the linking window in days. + step : int, optional + The step size of the linking window in days. + """ + self.length = length + self.step = step + + if self.step > self.length: + raise ValueError("step must be less than or equal to length") + + def apply(self, observations, test_orbit) -> Iterator["Observations"]: + """ + Apply the filter to a collection of observations. + + Parameters + ---------- + observations : `~thor.observations.Observations` + The observations to filter. + test_orbit : `~thor.orbit.TestOrbit` + The test orbit to use for filtering. Unused for this + filter. + + Returns + ------- + filtered_observations : `~thor.observations.Observations` + The filtered observations. + """ + times = observations.detections.time.rescale("utc") + days = times.days.unique().sort() + min_day = days[0].as_py() + max_day = days[-1].as_py() + for day in range(min_day, max_day + 1, self.step): + mask = pc.and_( + pc.greater_equal(times.days, day), + pc.less(times.days, day + self.length), + ) + yield observations.apply_mask(mask) + + class TestOrbitRadiusObservationFilter(ObservationFilter): """A TestOrbitRadiusObservationFilter is an ObservationFilter that gathers observations within a fixed radius of the test orbit's diff --git a/thor/observations/tests/test_filters.py b/thor/observations/tests/test_filters.py index 1dc9cd2b..7b37140a 100644 --- a/thor/observations/tests/test_filters.py +++ b/thor/observations/tests/test_filters.py @@ -1,4 +1,3 @@ -import astropy.time import numpy as np import pyarrow as pa import pyarrow.compute as pc @@ -11,7 +10,7 @@ from adam_core.time import Timestamp from ...orbit import TestOrbit -from ..filters import TestOrbitRadiusObservationFilter +from ..filters import LinkingWindowObservationFilter, TestOrbitRadiusObservationFilter from ..observations import Observations @@ -130,3 +129,43 @@ def test_orbit_radius_observation_filter(fixed_test_orbit, fixed_observations): # Should be about pi/4 fraction of the detections (0.785 assert len(have.detections) < 0.80 * len(fixed_observations.detections) assert len(have.detections) > 0.76 * len(fixed_observations.detections) + + +def assert_min_max_time(observations, min, max): + assert pc.min(observations.detections.time.days).as_py() == min + assert pc.max(observations.detections.time.days).as_py() == max + + +def test_LinkingWindowObservationFilter(fixed_test_orbit, fixed_observations): + # Create a LinkingWindowObservationFilter with a window length of 5 days and a step of 5 days + # There should be only one window that covers all the observations + filter = LinkingWindowObservationFilter(length=5, step=5) + filtered_observations = list(filter.apply(fixed_observations, fixed_test_orbit)) + assert len(filtered_observations) == 1 + assert len(filtered_observations[0]) == 50000 + assert_min_max_time(filtered_observations[0], 58849, 58853) + + # Create a LinkingWindowObservationFilter with a window length of 1 day and a step of 1 day + # There should be 5 windows that cover all the observations + filter = LinkingWindowObservationFilter(length=1, step=1) + filtered_observations = list(filter.apply(fixed_observations, fixed_test_orbit)) + assert len(filtered_observations) == 5 + for i in range(5): + assert len(filtered_observations[i]) == 10000 + assert_min_max_time(filtered_observations[i], 58849 + i, 58849 + i) + + # Create a LinkingWindowObservationFilter with a window length of 3 day and a step of 1 days + # There should be 5 windows that cover all the observations + filter = LinkingWindowObservationFilter(length=3, step=1) + filtered_observations = list(filter.apply(fixed_observations, fixed_test_orbit)) + assert len(filtered_observations) == 5 + assert len(filtered_observations[0]) == 30000 + assert_min_max_time(filtered_observations[0], 58849, 58851) + assert len(filtered_observations[1]) == 30000 + assert_min_max_time(filtered_observations[1], 58850, 58852) + assert len(filtered_observations[2]) == 30000 + assert_min_max_time(filtered_observations[2], 58851, 58853) + assert len(filtered_observations[3]) == 20000 + assert_min_max_time(filtered_observations[3], 58852, 58853) + assert len(filtered_observations[4]) == 10000 + assert_min_max_time(filtered_observations[4], 58853, 58853)