Skip to content

Commit

Permalink
Test performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
akoumjian committed Dec 2, 2024
1 parent 4146a30 commit 48c80a7
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 40 deletions.
56 changes: 23 additions & 33 deletions precovery/precovery_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import time
from typing import Optional, Tuple, Type

import numpy as np
Expand Down Expand Up @@ -380,34 +381,25 @@ def find_healpixel_matches(
PropagationTargets
Propagation targets that match the ephemeris
"""
unique_target_times: Timestamp = propagation_targets.time.unique()
filtered_targets = GenericFrame.empty()
for target_time in unique_target_times:
time_propagation_targets = propagation_targets.apply_mask(
propagation_targets.time.equals(target_time)
)
assert len(time_propagation_targets) > 0, "No matching targets found"
time_matching_ephems_mask = ephems.coordinates.time.equals(
target_time, precision="ms"
)
time_matching_ephems = ephems.apply_mask(time_matching_ephems_mask)
assert (
len(time_matching_ephems) > 0
), "No matching ephemeris found for target time"

ephem_healpixels = pa.array(
radec_to_healpixel(
ra=time_matching_ephems.coordinates.lon.to_numpy(),
dec=time_matching_ephems.coordinates.lat.to_numpy(),
nside=nside,
),
type=pa.int64(),
)
# Sort them both by time
propagation_targets = propagation_targets.sort_by(["time.days", "time.nanos"])
ephems = ephems.sort_by(["coordinates.time.days", "coordinates.time.nanos"])

matching_targets = time_propagation_targets.where(
pc.is_in(time_propagation_targets.healpixel, ephem_healpixels)
)
filtered_targets = qv.concatenate([filtered_targets, matching_targets])
# quickly check to make sure times are equal
assert pc.all(
propagation_targets.time.equals(ephems.coordinates.time, precision="ms")
).as_py(), "Propagation targets and ephemeris must have matching times"

# Calculate the healpixels for the ephemeris
ephem_healpixels = radec_to_healpixel(
ephems.coordinates.lon.to_numpy(),
ephems.coordinates.lat.to_numpy(),
nside=nside,
)

# Find the matching healpixels
mask = pc.equal(propagation_targets.healpixel, ephem_healpixels)
filtered_targets = propagation_targets.apply_mask(mask)

return filtered_targets

Expand Down Expand Up @@ -724,26 +716,21 @@ def _check_window(
)
return PrecoveryCandidates.empty(), FrameCandidates.empty()

times = propagation_targets.time.unique()
times = propagation_targets.time

# create our observers
observers = Observers.from_code(obscode, times)

## first propagate with 2_body
propagated_orbits = propagate_2body(orbit, times)

# hotfix: rescale the times back from propagated_orbits
# this behavior should be fixed in adam_core, to always return
# the timescale of the submitted times
propagated_orbits = propagated_orbits.set_column(
"coordinates.time", propagated_orbits.coordinates.time.rescale("utc")
)

assert propagated_orbits.coordinates.time.equals(times)

# generate ephemeris
ephems = generate_ephemeris_2body(propagated_orbits, observers)

frames_to_check = find_healpixel_matches(
propagation_targets, ephems, self.frames.healpix_nside
)
Expand Down Expand Up @@ -863,6 +850,9 @@ def find_matches_in_frame(
observations, repeated_ephem, tolerance
)

if len(matching_observations) == 0:
return PrecoveryCandidates.empty()

candidates = candidates_from_ephem(matching_observations, matching_ephem, frame)
return candidates

Expand Down
10 changes: 3 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ classifiers = [
]

dependencies = [
"adam_core>=0.2.5",
"numpy",
"numba",
"jaxlib",
Expand All @@ -37,11 +36,11 @@ dependencies = [
"pandas",
"healpy",
"requests",
"adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@f843e60fd0434b5280803f4d7d6bd83222bbd77a",
]
[dependency-groups]
test = [
"adam_assist",
"pre-commit",
"pytest",
"black",
"isort",
Expand All @@ -53,9 +52,6 @@ test = [
"astroquery",
"pytest-mock>=3.14.0",
]
dev = [
"ipython>=8.29.0",
]

[project.urls]
"Homepage" = "https://github.com/b612-asteroid-institute/precovery"
Expand All @@ -82,9 +78,9 @@ lint = { composite = [
fix = "ruff ./precovery --fix"
typecheck = "mypy ./precovery"

test = "pytest --benchmark-disable {args}"
test = "pytest --benchmark-skip {args}"
doctest = "pytest --doctest-plus --doctest-only"
benchmark = "pytest --benchmark-only"
benchmark = "pytest --benchmark-only --benchmark-cprofile=cumtime"
coverage = "pytest --cov=precovery --cov-report=xml"

[tool.pdm]
Expand Down
45 changes: 45 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import glob
import os

import numpy as np
import pandas as pd
import pytest
from adam_core.orbits import Orbits

from precovery.frame_db import FrameDB, FrameIndex
from precovery.ingest import index
from precovery.precovery_db import PrecoveryDatabase

# Directory containing this conftest module; the test data paths below are
# built relative to it.
_TESTS_DIR = os.path.dirname(__file__)

# Parquet file with the sample orbits used by the tests.
SAMPLE_ORBITS_FILE = os.path.join(_TESTS_DIR, "data", "sample_orbits.parquet")

# Root directory searched for ``dataset_*`` observation folders.
TEST_OBSERVATIONS_DIR = os.path.join(_TESTS_DIR, "data/index")


@pytest.fixture
def test_db():
Expand Down Expand Up @@ -38,6 +47,42 @@ def precovery_db(tmp_path, frame_db):
yield PrecoveryDatabase.create(str(tmp_path), nside=32)


@pytest.fixture
def precovery_db_with_data(tmp_path):
    """Create a ``PrecoveryDatabase`` in ``tmp_path`` from the bundled test datasets.

    Every ``dataset_*/*.csv`` file under ``TEST_OBSERVATIONS_DIR`` is located,
    its dataset id is taken from the first row of its ``dataset_id`` column,
    and ``index`` is called for the matching ``data/index/<dataset_id>/``
    directory with ``nside=32``.

    Parameters
    ----------
    tmp_path
        pytest's per-test temporary directory, used as the ``out_dir`` for
        ``index`` and as the directory the database is opened from.

    Returns
    -------
    PrecoveryDatabase
        The database opened from ``tmp_path`` with ``mode="r"``.

    Raises
    ------
    IndexError
        If a matched CSV file contains no data rows.
    """
    # glob returns matches in arbitrary order; sort so the datasets are
    # always indexed in the same sequence from run to run.
    observation_files = sorted(
        glob.glob(os.path.join(TEST_OBSERVATIONS_DIR, "dataset_*", "*.csv"))
    )

    for observation_file in observation_files:
        # Only the dataset id is needed from the CSV here, so read a single
        # row of a single column instead of parsing the whole file.
        dataset_id = pd.read_csv(
            observation_file,
            usecols=["dataset_id"],
            dtype={"dataset_id": str},
            nrows=1,
        )["dataset_id"].values[0]

        index(
            out_dir=tmp_path,
            dataset_id=dataset_id,
            dataset_name=dataset_id,
            data_dir=os.path.join(
                os.path.dirname(__file__), f"data/index/{dataset_id}/"
            ),
            nside=32,
        )

    # Read in the frames and observations from the database
    db = PrecoveryDatabase.from_dir(tmp_path, mode="r")
    return db


@pytest.fixture
def sample_orbits():
sample_orbits_file = os.path.join(
Expand Down
10 changes: 10 additions & 0 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,13 @@ def benchmark_case():
propagate_2body(orbit, times)

benchmark(benchmark_case)


@pytest.mark.benchmark(group="precovery")
def test_benchmark_precovery_search(benchmark, precovery_db_with_data, sample_orbits):
    """Benchmark a single ``precover`` call for the first sample orbit.

    The search is executed exactly once (one round, one iteration) through
    ``benchmark.pedantic``; its return value is not inspected.
    """
    target_orbit = sample_orbits[0]

    def run_search():
        # NOTE(review): 5 / 3600 looks like 5 arcseconds expressed in
        # degrees -- confirm against ``precover``'s tolerance units.
        search_options = {
            "tolerance": 5 / 3600,
            "window_size": 7,
            "propagator_class": ASSISTPropagator,
        }
        precovery_db_with_data.precover(target_orbit, **search_options)

    benchmark.pedantic(run_search, iterations=1, rounds=1)

0 comments on commit 48c80a7

Please sign in to comment.