Skip to content

Commit

Permalink
consistent usage of propagator_class across precovery fns
Browse files Browse the repository at this point in the history
  • Loading branch information
ntellis committed Sep 24, 2024
1 parent 88b47cf commit 399f267
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 19 deletions.
15 changes: 7 additions & 8 deletions precovery/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def precover_many(
allow_version_mismatch: bool = False,
datasets: Optional[set[str]] = None,
n_workers: int = multiprocessing.cpu_count(),
propagator: Optional[Type[Propagator]] = None,
propagator_class: Optional[Type[Propagator]] = None,
) -> Tuple[PrecoveryCandidates, FrameCandidates]:
"""
Run a precovery search algorithm against many orbits at once.
Expand All @@ -38,7 +38,7 @@ def precover_many(
window_size,
allow_version_mismatch,
datasets,
propagator,
propagator_class,
)
for o in orbits
]
Expand Down Expand Up @@ -72,14 +72,13 @@ def precover_worker(
window_size: int = 7,
allow_version_mismatch: bool = False,
datasets: Optional[set[str]] = None,
propagator: Optional[Type[Propagator]] = None,
propagator_class: Optional[Type[Propagator]] = None,
) -> Tuple[PrecoveryCandidates, FrameCandidates]:
"""
Wraps the precover function to return the orbit_id for mapping.
"""

# initialize our propagator
propagator_instance = propagator() if propagator is not None else None
precovery_candidates, frame_candidates = precover(
orbit,
database_directory,
Expand All @@ -89,7 +88,7 @@ def precover_worker(
window_size,
allow_version_mismatch,
datasets,
propagator=propagator_instance,
propagator_class=propagator_class,
)

return (
Expand All @@ -107,7 +106,7 @@ def precover(
window_size: int = 7,
allow_version_mismatch: bool = False,
datasets: Optional[set[str]] = None,
propagator: Optional[Type[Propagator]] = None,
propagator_class: Optional[Type[Propagator]] = None,
) -> Tuple[PrecoveryCandidates, FrameCandidates]:
"""
Connect to database directory and run precovery for the input orbit.
Expand Down Expand Up @@ -138,7 +137,7 @@ def precover(
Allows using a precovery db version that does not match the library version.
datasets : set[str], optional
Filter down searches to only scan selected datasets
propagator : Type[Propagator], optional
propagator_class : Type[Propagator], optional
An adam_core.propagator.Propagator subclass to use for propagating the orbit.
Returns
Expand All @@ -163,7 +162,7 @@ def precover(
end_mjd=end_mjd,
window_size=window_size,
datasets=datasets,
propagator=propagator,
propagator_class=propagator_class,
)

return precovery_candidates, frame_candidates
6 changes: 3 additions & 3 deletions precovery/precovery_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def precover(
end_mjd: Optional[float] = None,
window_size: int = 7,
datasets: Optional[set[str]] = None,
propagator: Optional[Type[Propagator]] = None,
propagator_class: Optional[Type[Propagator]] = None,
) -> Tuple[PrecoveryCandidates, FrameCandidates]:
"""
Find observations which match orbit in the database. Observations are
Expand Down Expand Up @@ -446,10 +446,10 @@ def precover(
yield match
"""

if propagator is None:
if propagator_class is None:
raise ValueError("A propagator must be provided to run precovery")

propagator_instance = propagator()
propagator_instance = propagator_class()
orbit_id = orbit.orbit_id[0].as_py()

if datasets is not None:
Expand Down
6 changes: 3 additions & 3 deletions scripts/precovery-test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

if __name__ == "__main__":

propagator = ASSISTPropagator()

parser = argparse.ArgumentParser(
description="Run precovery using test_orbits.csv in this directory"
)
Expand Down Expand Up @@ -57,7 +55,9 @@
# Select a single orbit
orbit = orbits[i]

candidates, frame_candidates = db.precover(orbit, tolerance=1 / 3600)
candidates, frame_candidates = db.precover(
orbit, tolerance=1 / 3600, propagator_class=ASSISTPropagator
)

print(
f"Found {len(candidates)} potential matches for orbit ID: {orbit.object_id[0].as_py()}"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_precovery_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_precovery(test_db_dir):
test_db_dir,
tolerance=1 / 3600,
window_size=1,
propagator=ASSISTPropagator(),
propagator_class=ASSISTPropagator,
)

object_observations = observations_df[
Expand Down
9 changes: 5 additions & 4 deletions tests/test_precovery_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_precover(precovery_db, sample_orbits):
precovery_db.frames.add_frames(ds_id, frames)

# Do the search. We should find the three observations we inserted.
matches, misses = precovery_db.precover(orbit, propagator=ASSISTPropagator())
matches, misses = precovery_db.precover(orbit, propagator_class=ASSISTPropagator)
assert len(matches) == 3
assert len(misses) == 0

Expand All @@ -41,7 +41,6 @@ def test_precover(precovery_db, sample_orbits):
@requires_jpl_ephem_data
def test_precover_dataset_filter(precovery_db, sample_orbits):
# Make two datasets which contain something we're looking for.
propagator = ASSISTPropagator()

orbit = sample_orbits[0]
timestamps = [50000.0, 50001.0, 50002.0]
Expand All @@ -62,7 +61,7 @@ def test_precover_dataset_filter(precovery_db, sample_orbits):

# Do the search with no dataset filters. We should find all six
# observations we inserted.
matches, misses = precovery_db.precover(orbit, propagator=propagator)
matches, misses = precovery_db.precover(orbit, propagator_class=ASSISTPropagator)
assert len(matches) == 6

have_ids = set(matches.observation_id.to_pylist())
Expand All @@ -72,7 +71,9 @@ def test_precover_dataset_filter(precovery_db, sample_orbits):
# Now repeat the search, but filter to just one dataset. We should
# only find that dataset's observations.
matches, misses = list(
precovery_db.precover(orbit, datasets={ds1_id}, propagator=propagator)
precovery_db.precover(
orbit, datasets={ds1_id}, propagator_class=ASSISTPropagator
)
)
assert len(matches) == 3

Expand Down

0 comments on commit 399f267

Please sign in to comment.