Skip to content

Commit

Permalink
Implement Observers.from_codes
Browse files Browse the repository at this point in the history
  • Loading branch information
moeyensj committed Sep 18, 2024
1 parent 6af3e0e commit aff459b
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 6 deletions.
106 changes: 100 additions & 6 deletions src/adam_core/observers/observers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import warnings
from typing import Union

import numpy as np
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import quivr as qv
from mpc_obscodes import mpc_obscodes
from typing_extensions import Self
Expand Down Expand Up @@ -54,15 +57,18 @@ class Observers(qv.Table):
coordinates = CartesianCoordinates.as_column()

@classmethod
def from_codes(cls, codes: pa.Array, times: Timestamp) -> Self:
def from_codes(
cls, codes: Union[list, npt.NDArray[np.str_], pa.Array], times: Timestamp
) -> Self:
"""
Create an Observers table from a list of codes and times. The codes and times
do not need to be unique. The observer state will be calculated for each time
and correctly matched to the input times and replicated for duplicate times.
do not need to be unique and are assumed to belong to each other in an element-wise fashion.
The observer state will be calculated correctly matched to the input times and
replicated for duplicate times.
Parameters
----------
codes : pa.Array (N)
codes : Union[list, npt.NDArray[np.str], pa.Array] (N)
MPC observatory codes for which to find the states.
times : Timestamp (N)
Epochs for which to find the observatory locations.
Expand All @@ -72,8 +78,96 @@ def from_codes(cls, codes: pa.Array, times: Timestamp) -> Self:
observers : `~adam_core.observers.observers.Observers` (N)
The observer and its state at each time.
"""
assert len(codes) == len(times)
raise NotImplementedError
if len(codes) != len(times):
raise ValueError("codes and times must have the same length.")

if not isinstance(codes, pa.Array):
codes = pa.array(codes, type=pa.large_string())

# Create a table with the codes and times and add
# and index column to track the original order
table = pa.Table.from_pydict(
{
"index": pa.array(range(len(codes)), type=pa.uint64()),
"code": codes,
"times.days": times.days,
"times.nanos": times.nanos,
}
)

# Expected observers schema with the addition of a
# column that tracks the original index
observers_schema = pa.schema(
[
pa.field("code", pa.large_string(), nullable=False),
pa.field(
"coordinates",
pa.struct(
[
pa.field("x", pa.float64()),
pa.field("y", pa.float64()),
pa.field("z", pa.float64()),
pa.field("vx", pa.float64()),
pa.field("vy", pa.float64()),
pa.field("vz", pa.float64()),
pa.field(
"time",
pa.struct(
[
pa.field("days", pa.int64()),
pa.field("nanos", pa.int64()),
]
),
),
pa.field(
"covariance",
pa.struct(
[pa.field("values", pa.large_list(pa.float64()))]
),
),
pa.field(
"origin",
pa.struct([pa.field("code", pa.large_string())]),
),
]
),
),
pa.field("index", pa.uint64()),
],
metadata={
"coordinates.time.scale": times.scale,
"coordinates.frame": "ecliptic",
},
)

# Create an empty table with the expected schema
observers_table = observers_schema.empty_table()

# Loop through each unique code and calculate the observer's
# state for each time (these can be non-unique as cls.from_code
# will handle this)
for code in table["code"].unique():

times_code = table.filter(pc.equal(table["code"], code))

observers = cls.from_code(
code.as_py(),
Timestamp.from_kwargs(
days=times_code["times.days"],
nanos=times_code["times.nanos"],
scale=times.scale,
),
)

observers_table_i = observers.table.append_column(
"index", times_code["index"]
)
observers_table = pa.concat_tables(
[observers_table, observers_table_i]
).combine_chunks()

observers_table = observers_table.sort_by(("index")).drop_columns(["index"])
return cls.from_pyarrow(observers_table)

@classmethod
def from_code(cls, code: Union[str, OriginCodes], times: Timestamp) -> Self:
Expand Down
61 changes: 61 additions & 0 deletions src/adam_core/observers/tests/test_observers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pyarrow as pa
import pyarrow.compute as pc
import pytest

from ...time import Timestamp
from ..observers import Observers


@pytest.fixture
def codes_times() -> tuple[pa.Array, Timestamp]:
codes = pa.array(
["500", "X05", "I41", "X05", "I41", "W84", "500"],
)

times = Timestamp.from_kwargs(
days=[59000, 59001, 59002, 59003, 59004, 59005, 59006],
nanos=[0, 0, 0, 0, 0, 0, 0],
scale="tdb",
)
return codes, times


def test_Observers_from_codes(codes_times) -> None:
# Test that observers from code returns the correct number of observers
# and in the order that they were requested
codes, times = codes_times

observers = Observers.from_codes(codes, times)
assert len(observers) == 7
assert pc.all(pc.equal(observers.code, codes)).as_py()
assert pc.all(pc.equal(observers.coordinates.time.days, times.days)).as_py()
assert pc.all(pc.equal(observers.coordinates.time.nanos, times.nanos)).as_py()


def test_Observers_from_codes_non_pyarrow(codes_times) -> None:
# Test that observers from code returns the correct number of observers
# and in the order that they were requested
codes, times = codes_times

observers = Observers.from_codes(codes.to_numpy(zero_copy_only=False), times)
assert len(observers) == 7
assert pc.all(pc.equal(observers.code, codes)).as_py()
assert pc.all(pc.equal(observers.coordinates.time.days, times.days)).as_py()
assert pc.all(pc.equal(observers.coordinates.time.nanos, times.nanos)).as_py()

observers = Observers.from_codes(codes.to_pylist(), times)
assert len(observers) == 7
assert pc.all(pc.equal(observers.code, codes)).as_py()
assert pc.all(pc.equal(observers.coordinates.time.days, times.days)).as_py()
assert pc.all(pc.equal(observers.coordinates.time.nanos, times.nanos)).as_py()


def test_Observers_from_codes_raises(codes_times) -> None:
# Test that observers from code raises an error if the codes and times
# are not the same length
codes, times = codes_times

with pytest.raises(ValueError, match="codes and times must have the same length."):
Observers.from_codes(codes[:3], times)
with pytest.raises(ValueError, match="codes and times must have the same length."):
Observers.from_codes(codes, times[:3])

0 comments on commit aff459b

Please sign in to comment.