Skip to content

Commit

Permalink
Use integer time in all interfaces, including propagators
Browse files Browse the repository at this point in the history
  • Loading branch information
spenczar committed Oct 5, 2023
1 parent dd87996 commit de372bc
Show file tree
Hide file tree
Showing 14 changed files with 79 additions and 82 deletions.
9 changes: 4 additions & 5 deletions adam_core/dynamics/propagation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax.numpy as jnp
import numpy as np
from astropy.time import Time
from jax import config, jit, vmap

from ..constants import Constants as c
Expand Down Expand Up @@ -76,7 +75,7 @@ def _propagate_2body(

def propagate_2body(
orbits: Orbits,
times: Time,
times: Timestamp,
mu: float = MU,
max_iter: int = 1000,
tol: float = 1e-14,
Expand All @@ -88,7 +87,7 @@ def propagate_2body(
----------
orbits : `~jax.numpy.ndarray` (N, 6)
Cartesian orbits with position in units of au and velocity in units of au per day.
times : `~astropy.time.core.Time` (M)
times : Timestamp (M)
Epochs to which to propagate each orbit. If a single epoch is given, all orbits are propagated to this
epoch. If multiple epochs are given, then each orbit to will be propagated to each epoch.
mu : float, optional
Expand All @@ -108,8 +107,8 @@ def propagate_2body(
"""
# Lets extract the cartesian orbits and times from the orbits object
cartesian_orbits = orbits.coordinates.values
t0 = orbits.coordinates.time.to_astropy().tdb.mjd
t1 = times.tdb.mjd
t0 = orbits.coordinates.time.rescale("tdb").mjd()
t1 = times.rescale("tdb").mjd()
orbit_ids = orbits.orbit_id.to_numpy(zero_copy_only=False)
object_ids = orbits.object_id.to_numpy(zero_copy_only=False)

Expand Down
16 changes: 7 additions & 9 deletions adam_core/dynamics/tests/test_propagation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import spiceypy as sp
from astropy import units as u
from astropy.time import Time

from ...constants import Constants as c
from ...coordinates.cartesian import CartesianCoordinates
Expand Down Expand Up @@ -107,19 +106,19 @@ def test_propagate_2body_against_spice_elliptical(orbital_elements):
)

# Set propagation times (same for all orbits)
times = Time(
times = Timestamp.from_mjd(
t0.min() + np.arange(0, 10000, 10),
format="mjd",
scale="tdb",
)

# Propagate orbits with SPICE and accumulate results
spice_propagated = []
times_mjd = times.mjd()
for i, cartesian_i in enumerate(cartesian_elements):

# Calculate dts from t0 for this orbit (each orbit's t0 is different)
# but the final times we are propagating to are the same for all orbits
dts = times.tdb.mjd - t0[i]
dts = times_mjd - t0[i]
spice_propagated_i = np.empty((len(dts), 6))
for j, dt_i in enumerate(dts):
spice_propagated_i[j] = sp.prop2b(
Expand Down Expand Up @@ -177,19 +176,19 @@ def test_propagate_2body_against_spice_hyperbolic(orbital_elements):
)

# Set propagation times (same for all orbits)
times = Time(
times = Timestamp.from_mjd(
t0.min() + np.arange(0, 10000, 10),
format="mjd",
scale="tdb",
)

# Propagate orbits with SPICE and accumulate results
spice_propagated = []
times_mjd = times.mjd()
for i, cartesian_i in enumerate(cartesian_elements):

# Calculate dts from t0 for this orbit (each orbit's t0 is different)
# but the final times we are propagating to are the same for all orbits
dts = times.tdb.mjd - t0[i]
dts = times_mjd - t0[i]
spice_propagated_i = np.empty((len(dts), 6))
for j, dt_i in enumerate(dts):
spice_propagated_i[j] = sp.prop2b(
Expand Down Expand Up @@ -272,9 +271,8 @@ def test_benchmark_propagate_2body(benchmark, orbital_elements):
frame="ecliptic",
),
)
times = Time(
times = Timestamp.from_mjd(
[t0.min() + 1],
format="mjd",
scale="tdb",
)
benchmark(propagate_2body, orbits[0], times=times)
1 change: 0 additions & 1 deletion adam_core/dynamics/tisserand.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
This code generates the dictionary of semi-major axes for the
third body needed for the Tisserand parameter
from astropy.time import Time
from adam_core.orbits.query import _get_horizons_elements
ids = ["199", "299", "399", "499", "599", "699", "799", "899"]
Expand Down
5 changes: 2 additions & 3 deletions adam_core/observations/tests/test_detections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import astropy.time
import healpy

from ...time import Timestamp
Expand All @@ -7,7 +6,7 @@


def test_detections_link_to_exposures():
start_times = astropy.time.Time(
start_times = Timestamp.from_iso8601(
[
"2000-01-01T00:00:00",
"2000-01-02T00:00:00",
Expand All @@ -16,7 +15,7 @@ def test_detections_link_to_exposures():
)
exp = Exposures.from_kwargs(
id=["e1", "e2"],
start_time=Timestamp.from_astropy(start_times),
start_time=start_times,
duration=[60, 30],
filter=["g", "r"],
observatory_code=["I41", "I41"],
Expand Down
29 changes: 13 additions & 16 deletions adam_core/observations/tests/test_exposures.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import pathlib

import astropy.time
import numpy as np
import quivr as qv

Expand All @@ -11,25 +10,24 @@


def test_exposure_midpoints():
start_times = astropy.time.Time(
start_times = Timestamp.from_iso8601(
[
"2000-01-01T00:00:00",
"2000-01-02T00:00:00",
],
scale="utc",
)
exp = Exposures.from_kwargs(
id=["e1", "e2"],
start_time=Timestamp.from_astropy(start_times),
start_time=start_times,
duration=[60, 30],
filter=["g", "r"],
observatory_code=["I41", "I41"],
)

midpoints = exp.midpoint()
midpoints_at = midpoints.to_astropy()
assert midpoints_at[0] == astropy.time.Time("2000-01-01T00:00:30", scale="utc")
assert midpoints_at[1] == astropy.time.Time("2000-01-02T00:00:15", scale="utc")
assert midpoints == Timestamp.from_iso8601(
["2000-01-01T00:00:30", "2000-01-02T00:00:15"]
)


def test_exposure_states():
Expand All @@ -51,15 +49,14 @@ def test_exposure_states():

# Mix up w84 and i41 in one big exposure table
codes = ["W84", "I41", "W84", "I41", "W84"]
state_times = astropy.time.Time(
state_times = qv.concatenate(
[
w84_state_data.time.to_astropy()[0],
i41_state_data.time.to_astropy()[0],
w84_state_data.time.to_astropy()[3],
i41_state_data.time.to_astropy()[1],
w84_state_data.time.to_astropy()[2],
],
scale="tdb",
w84_state_data.time[0],
i41_state_data.time[0],
w84_state_data.time[3],
i41_state_data.time[1],
w84_state_data.time[2],
]
)
expected = qv.concatenate(
[
Expand All @@ -72,7 +69,7 @@ def test_exposure_states():
)
exp = Exposures.from_kwargs(
id=["e1", "e2", "e3", "e4", "e5"],
start_time=Timestamp.from_astropy(state_times),
start_time=state_times,
duration=[0, 0, 0, 0, 0],
filter=["g", "r", "g", "r", "g"],
observatory_code=codes,
Expand Down
4 changes: 2 additions & 2 deletions adam_core/observers/observers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def from_code(cls, code: Union[str, OriginCodes], times: Timestamp) -> Self:
Examples
--------
>>> import numpy as np
>>> from astropy.time import Time
>>> from adam_core.time import Timestamp
>>> from adam_core.observers import Observers
>>> times = Time(np.arange(59000, 59000 + 100), scale="tdb", format="mjd")
>>> times = Timestamp.from_mjd(np.arange(59000, 59000 + 100), scale="tdb")
>>> observers = Observers.from_code("X05", times)
"""
from .state import get_observer_state
Expand Down
7 changes: 0 additions & 7 deletions adam_core/observers/tests/testdata/get_states.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import numpy as np
from astropy.time import Time
from astroquery.jplhorizons import Horizons

from adam_core.coordinates.cartesian import CartesianCoordinates
Expand All @@ -8,11 +6,6 @@
from adam_core.time import Timestamp

observatory_codes = ["I41", "X05", "F51", "W84", "000", "500"]
times = Time(
np.arange(59000, 60000, 10),
format="mjd",
scale="utc",
)

for code in observatory_codes:
for id in ["sun", "ssb"]:
Expand Down
19 changes: 9 additions & 10 deletions adam_core/orbits/query/horizons.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
from astropy.time import Time
from astroquery.jplhorizons import Horizons

from ...coordinates.cartesian import CartesianCoordinates
Expand All @@ -18,7 +17,7 @@

def _get_horizons_vectors(
object_ids: Union[List, npt.ArrayLike],
times: Time,
times: Timestamp,
location: str = "@sun",
id_type: str = "smallbody",
aberrations: str = "geometric",
Expand Down Expand Up @@ -56,7 +55,7 @@ def _get_horizons_vectors(
for i, obj_id in enumerate(object_ids):
obj = Horizons(
id=obj_id,
epochs=times.tdb.mjd,
epochs=times.rescale("tdb").mjd(),
location=location,
id_type=id_type,
)
Expand All @@ -77,7 +76,7 @@ def _get_horizons_vectors(

def _get_horizons_elements(
object_ids: Union[List, npt.ArrayLike],
times: Time,
times: Timestamp,
location: str = "@sun",
id_type: str = "smallbody",
refplane: str = "ecliptic",
Expand Down Expand Up @@ -112,7 +111,7 @@ def _get_horizons_elements(
for i, obj_id in enumerate(object_ids):
obj = Horizons(
id=obj_id,
epochs=times.tdb.mjd,
epochs=times.rescale("tdb").mjd(),
location=location,
id_type=id_type,
)
Expand All @@ -133,7 +132,7 @@ def _get_horizons_elements(

def _get_horizons_ephemeris(
object_ids: Union[List, npt.ArrayLike],
times: Time,
times: Timestamp,
location: str,
id_type: str = "smallbody",
) -> pd.DataFrame:
Expand Down Expand Up @@ -163,7 +162,7 @@ def _get_horizons_ephemeris(
for i, obj_id in enumerate(object_ids):
obj = Horizons(
id=obj_id,
epochs=times.utc.mjd,
epochs=times.rescale("utc").mjd(),
location=location,
id_type=id_type,
)
Expand Down Expand Up @@ -213,7 +212,7 @@ def query_horizons_ephemeris(
for observatory_code, observers_i in observers.iterate_codes():
ephemeris = _get_horizons_ephemeris(
object_ids,
observers_i.coordinates.time.to_astropy(),
observers_i.coordinates.time,
observatory_code,
)
dfs.append(ephemeris)
Expand All @@ -229,7 +228,7 @@ def query_horizons_ephemeris(

def query_horizons(
object_ids: Union[List, npt.ArrayLike],
times: Time,
times: Timestamp,
coordinate_type: str = "cartesian",
location: str = "@sun",
id_type: str = "smallbody",
Expand All @@ -242,7 +241,7 @@ def query_horizons(
----------
object_ids : npt.ArrayLike (N)
Object IDs / designations recognizable by HORIZONS.
times : `~astropy.core.time.Time` (M)
times : Timestamp (M)
Astropy time object at which to gather state vectors.
coordinate_type : {'cartesian', 'keplerian', 'cometary'}
Type of orbital elements to return.
Expand Down
5 changes: 3 additions & 2 deletions adam_core/orbits/query/sbdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
import numpy.typing as npt
from astropy.time import Time
from astroquery.jplsbdb import SBDB

from ...coordinates.cometary import CometaryCoordinates
Expand Down Expand Up @@ -203,7 +202,9 @@ def query_sbdb(ids: npt.ArrayLike) -> Orbits:
coords_cometary[i, 2] = elements["i"].value
coords_cometary[i, 3] = elements["om"].value
coords_cometary[i, 4] = elements["w"].value
coords_cometary[i, 5] = Time(elements["tp"].value, scale="tdb", format="jd").mjd
coords_cometary[i, 5] = (
Timestamp.from_jd([elements["tp"].value], scale="tdb").mjd()[0].as_py()
)

covariances_cometary = _convert_SBDB_covariances(covariances_sbdb)
times = Timestamp.from_jd(times, scale="tdb")
Expand Down
10 changes: 5 additions & 5 deletions adam_core/propagator/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import List, Optional, Union

import quivr as qv
from astropy.time import Time

from ..observers.observers import Observers
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from ..orbits.variants import VariantOrbits
from ..time import Timestamp
from .utils import _iterate_chunks, sort_propagated_orbits

logger = logging.getLogger(__name__)
Expand All @@ -18,7 +18,7 @@


def propagation_worker(
orbits: OrbitType, times: Time, propagator: "Propagator"
orbits: OrbitType, times: Timestamp, propagator: "Propagator"
) -> Orbits:
propagated = propagator._propagate_orbits(orbits, times)
return propagated
Expand All @@ -44,7 +44,7 @@ class Propagator(ABC):
"""

@abstractmethod
def _propagate_orbits(self, orbits: OrbitType, times: Time) -> OrbitType:
def _propagate_orbits(self, orbits: OrbitType, times: Timestamp) -> OrbitType:
"""
Propagate orbits to times.
Expand All @@ -55,7 +55,7 @@ def _propagate_orbits(self, orbits: OrbitType, times: Time) -> OrbitType:
def propagate_orbits(
self,
orbits: Orbits,
times: Time,
times: Timestamp,
covariance: bool = False,
chunk_size: int = 100,
max_processes: Optional[int] = 1,
Expand All @@ -67,7 +67,7 @@ def propagate_orbits(
----------
orbits : `~adam_core.orbits.orbits.Orbits` (N)
Orbits to propagate.
times : `~astropy.time.core.Time` (M)
times : Timestamp (M)
Times to which to propagate orbits.
covariance: bool, optional
Propagate the covariance matrices of the orbits. This is done by sampling the
Expand Down
Loading

0 comments on commit de372bc

Please sign in to comment.