From 48c80a721576d430fd1ceecbe91c6dcb3140faca Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Mon, 2 Dec 2024 13:40:15 -0500 Subject: [PATCH] Test performance improvements --- precovery/precovery_db.py | 56 ++++++++++++++++----------------------- pyproject.toml | 10 +++---- tests/conftest.py | 45 +++++++++++++++++++++++++++++++ tests/test_benchmarks.py | 10 +++++++ 4 files changed, 81 insertions(+), 40 deletions(-) diff --git a/precovery/precovery_db.py b/precovery/precovery_db.py index f18f0d5..dc9ee47 100644 --- a/precovery/precovery_db.py +++ b/precovery/precovery_db.py @@ -1,5 +1,6 @@ import logging import os +import time from typing import Optional, Tuple, Type import numpy as np @@ -380,34 +381,25 @@ def find_healpixel_matches( PropagationTargets Propagation targets that match the ephemeris """ - unique_target_times: Timestamp = propagation_targets.time.unique() - filtered_targets = GenericFrame.empty() - for target_time in unique_target_times: - time_propagation_targets = propagation_targets.apply_mask( - propagation_targets.time.equals(target_time) - ) - assert len(time_propagation_targets) > 0, "No matching targets found" - time_matching_ephems_mask = ephems.coordinates.time.equals( - target_time, precision="ms" - ) - time_matching_ephems = ephems.apply_mask(time_matching_ephems_mask) - assert ( - len(time_matching_ephems) > 0 - ), "No matching ephemeris found for target time" - - ephem_healpixels = pa.array( - radec_to_healpixel( - ra=time_matching_ephems.coordinates.lon.to_numpy(), - dec=time_matching_ephems.coordinates.lat.to_numpy(), - nside=nside, - ), - type=pa.int64(), - ) + # Sort them both by time + propagation_targets = propagation_targets.sort_by(["time.days", "time.nanos"]) + ephems = ephems.sort_by(["coordinates.time.days", "coordinates.time.nanos"]) - matching_targets = time_propagation_targets.where( - pc.is_in(time_propagation_targets.healpixel, ephem_healpixels) - ) - filtered_targets = qv.concatenate([filtered_targets, matching_targets]) + # quickly check to make sure times are equal + assert pc.all( + propagation_targets.time.equals(ephems.coordinates.time, precision="ms") + ).as_py(), "Propagation targets and ephemeris must have matching times" + + # Calculate the healpixels for the ephemeris + ephem_healpixels = radec_to_healpixel( + ephems.coordinates.lon.to_numpy(), + ephems.coordinates.lat.to_numpy(), + nside=nside, + ) + + # Find the matching healpixels + mask = pc.equal(propagation_targets.healpixel, ephem_healpixels) + filtered_targets = propagation_targets.apply_mask(mask) return filtered_targets @@ -724,14 +716,12 @@ def _check_window( ) return PrecoveryCandidates.empty(), FrameCandidates.empty() - times = propagation_targets.time.unique() + times = propagation_targets.time # create our observers observers = Observers.from_code(obscode, times) - ## first propagate with 2_body propagated_orbits = propagate_2body(orbit, times) - # hotfix: rescale the times back from propagated_orbits # this behavior should be fixed in adam_core, to always return # the timescale of the submitted times @@ -739,11 +729,8 @@ def _check_window( "coordinates.time", propagated_orbits.coordinates.time.rescale("utc") ) - assert propagated_orbits.coordinates.time.equals(times) - # generate ephemeris ephems = generate_ephemeris_2body(propagated_orbits, observers) - frames_to_check = find_healpixel_matches( propagation_targets, ephems, self.frames.healpix_nside ) @@ -863,6 +850,9 @@ def find_matches_in_frame( observations, repeated_ephem, tolerance ) + if len(matching_observations) == 0: + return PrecoveryCandidates.empty() + candidates = candidates_from_ephem(matching_observations, matching_ephem, frame) return candidates diff --git a/pyproject.toml b/pyproject.toml index ce86ca1..3089604 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ classifiers = [ ] dependencies = [ - "adam_core>=0.2.5", "numpy", "numba", "jaxlib", @@ -37,11 +36,11 @@ dependencies = [ "pandas", "healpy", "requests", + "adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@f843e60fd0434b5280803f4d7d6bd83222bbd77a", ] [dependency-groups] test = [ "adam_assist", - "pre-commit", "pytest", "black", "isort", @@ -53,9 +52,6 @@ test = [ "astroquery", "pytest-mock>=3.14.0", ] -dev = [ - "ipython>=8.29.0", -] [project.urls] "Homepage" = "https://github.com/b612-asteroid-institute/precovery" @@ -82,9 +78,9 @@ lint = { composite = [ fix = "ruff ./precovery --fix" typecheck = "mypy ./precovery" -test = "pytest --benchmark-disable {args}" +test = "pytest --benchmark-skip {args}" doctest = "pytest --doctest-plus --doctest-only" -benchmark = "pytest --benchmark-only" +benchmark = "pytest --benchmark-only --benchmark-cprofile=cumtime" coverage = "pytest --cov=precovery --cov-report=xml" [tool.pdm] diff --git a/tests/conftest.py b/tests/conftest.py index 93ca560..1aa7b73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,20 @@ +import glob import os +import numpy as np +import pandas as pd import pytest from adam_core.orbits import Orbits from precovery.frame_db import FrameDB, FrameIndex +from precovery.ingest import index from precovery.precovery_db import PrecoveryDatabase +SAMPLE_ORBITS_FILE = os.path.join( + os.path.dirname(__file__), "data", "sample_orbits.parquet" +) +TEST_OBSERVATIONS_DIR = os.path.join(os.path.dirname(__file__), "data/index") + @pytest.fixture def test_db(): @@ -38,6 +47,42 @@ def precovery_db(tmp_path, frame_db): yield PrecoveryDatabase.create(str(tmp_path), nside=32) +@pytest.fixture +def precovery_db_with_data(tmp_path): + observation_files = glob.glob( + os.path.join(TEST_OBSERVATIONS_DIR, "dataset_*", "*.csv") + ) + observations_dfs = [] + for observation_file in observation_files: + observations_df_i = pd.read_csv( + observation_file, + float_precision="round_trip", + dtype={ + "dataset_id": str, + "observatory_code": str, + "filter": str, + "exposure_duration": np.float64, + }, + ) + observations_dfs.append(observations_df_i) + + dataset_id = observations_df_i["dataset_id"].values[0] + + index( + out_dir=tmp_path, + dataset_id=dataset_id, + dataset_name=dataset_id, + data_dir=os.path.join( + os.path.dirname(__file__), f"data/index/{dataset_id}/" + ), + nside=32, + ) + + # Read in the frames and observations from the database + db = PrecoveryDatabase.from_dir(tmp_path, mode="r") + return db + + @pytest.fixture def sample_orbits(): sample_orbits_file = os.path.join( diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 8f83f1b..d0961b0 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -78,3 +78,13 @@ def benchmark_case(): propagate_2body(orbit, times) benchmark(benchmark_case) + + +@pytest.mark.benchmark(group="precovery") +def test_benchmark_precovery_search(benchmark, precovery_db_with_data, sample_orbits): + + orbit = sample_orbits[0] + + def benchmark_case(): + precovery_db_with_data.precover(orbit, tolerance=5 / 3600, window_size=7, propagator_class=ASSISTPropagator) + benchmark.pedantic(benchmark_case, iterations=1, rounds=1)