Skip to content

Commit ce8f08c

Browse files
committed
Use SphericalCoordinates as a column instead of PointSourceDetectinons in Observations
1 parent f61d63c commit ce8f08c

File tree

10 files changed

+93
-135
lines changed

10 files changed

+93
-135
lines changed

thor/config.py

-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from dataclasses import dataclass
22
from typing import Literal, Optional
33

4-
import numpy as np
5-
import numpy.typing as npt
6-
74

85
@dataclass
96
class Config:

thor/observations/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
# flake8: noqa: F401
22
from .observations import Observations
3+
from .photometry import Photometry

thor/observations/filters.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import quivr as qv
88
import ray
9-
from adam_core.observations import PointSourceDetections
9+
from adam_core.coordinates import SphericalCoordinates
1010

1111
from ..orbit import TestOrbit, TestOrbitEphemeris
1212

@@ -45,7 +45,7 @@ def TestOrbitRadiusObservationFilter_worker(
4545
# Select the ephemeris and observations for this state
4646
ephemeris_state = ephemeris.select("id", state_id)
4747
observations_state = observations.select("state_id", state_id)
48-
detections_state = observations_state.detections
48+
coordinates_state = observations_state.coordinates
4949

5050
assert (
5151
len(ephemeris_state) == 1
@@ -56,7 +56,7 @@ def TestOrbitRadiusObservationFilter_worker(
5656

5757
# Return the observations within the radius for this particular state
5858
return observations_state.apply_mask(
59-
_within_radius(detections_state, ephem_ra, ephem_dec, radius)
59+
_within_radius(coordinates_state, ephem_ra, ephem_dec, radius)
6060
)
6161

6262

@@ -199,7 +199,11 @@ def apply(
199199

200200
observations_filtered = qv.concatenate(filtered_observations_list)
201201
observations_filtered = observations_filtered.sort_by(
202-
["detections.time.days", "detections.time.nanos", "observatory_code"]
202+
[
203+
"coordinates.time.days",
204+
"coordinates.time.nanos",
205+
"coordinates.origin.code",
206+
]
203207
)
204208

205209
time_end = time.perf_counter()
@@ -213,19 +217,19 @@ def apply(
213217

214218

215219
def _within_radius(
216-
detections: PointSourceDetections,
220+
coords: SphericalCoordinates,
217221
ra: float,
218222
dec: float,
219223
radius: float,
220224
) -> np.array:
221225
"""
222226
Return a boolean mask that identifies which of
223-
the detections are within a given radius of a given ra and dec.
227+
the coords are within a given radius of a given ra and dec.
224228
225229
Parameters
226230
----------
227-
detections : `~adam_core.observations.detections.PointSourceDetections`
228-
The detections to filter.
231+
coords : `~adam_core.coordinates.spherical.SphericalCoordinates`
232+
The coords to filter.
229233
ra : float
230234
The right ascension of the center of the radius in degrees.
231235
dec : float
@@ -236,11 +240,11 @@ def _within_radius(
236240
Returns
237241
-------
238242
mask : `~numpy.ndarray`
239-
A boolean mask that identifies which of the detections are within
243+
A boolean mask that identifies which of the coords are within
240244
the radius.
241245
"""
242-
det_ra = np.deg2rad(detections.ra.to_numpy())
243-
det_dec = np.deg2rad(detections.dec.to_numpy())
246+
det_ra = np.deg2rad(coords.lon.to_numpy())
247+
det_dec = np.deg2rad(coords.lat.to_numpy())
244248

245249
center_ra = np.deg2rad(ra)
246250
center_dec = np.deg2rad(dec)

thor/observations/observations.py

+36-50
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from adam_core.observers import Observers
1010
from adam_core.time import Timestamp
1111

12+
from .photometry import Photometry
13+
1214

1315
class ObserversWithStates(qv.Table):
1416
state_id = qv.Int64Column()
@@ -26,19 +28,12 @@ class Observations(qv.Table):
2628
assign a unique state ID. If not using this constructor, please ensure that the detections are sorted
2729
by time and observatory code and that each unique combination of time and observatory code has a unique
2830
state ID.
29-
30-
Columns
31-
-------
32-
detections : `~adam_core.observations.detections.PointSourceDetections`
33-
A table of point source detections.
34-
observatory_code : `~qv.StringColumn`
35-
The observatory code for each detection.
36-
state_id : `~qv.Int64Column`
37-
The state ID for each detection.
3831
"""
3932

40-
detections = PointSourceDetections.as_column()
41-
observatory_code = qv.StringColumn()
33+
id = qv.StringColumn()
34+
exposure_id = qv.StringColumn()
35+
coordinates = SphericalCoordinates.as_column()
36+
photometry = Photometry.as_column()
4237
state_id = qv.Int64Column()
4338

4439
@classmethod
@@ -100,13 +95,15 @@ def from_detections_and_exposures(
10095
)
10196

10297
# Extract the exposure IDs and the observatory codes from the exposures table
103-
exposure_obscodes = pa.table(
104-
[exposures.id, exposures.observatory_code],
105-
names=["exposure_id", "observatory_code"],
98+
exposure_filters_obscodes = pa.table(
99+
[exposures.id, exposures.filter, exposures.observatory_code],
100+
names=["exposure_id", "filter", "observatory_code"],
106101
)
107102

108103
# Join the detection times and the exposure IDs so that each detection has an observatory code
109-
obscode_times = detections_flattened.join(exposure_obscodes, ["exposure_id"])
104+
obscode_times = detections_flattened.join(
105+
exposure_filters_obscodes, ["exposure_id"]
106+
)
110107

111108
# Group the detections by the observatory code and the detection times and then grab the unique ones
112109
unique_obscode_times = obscode_times.group_by(
@@ -139,23 +136,34 @@ def from_detections_and_exposures(
139136
[("state_id", "ascending")]
140137
)
141138

139+
sigmas = np.zeros((len(detections_with_states), 6))
140+
sigmas[:, 1] = detections_with_states["ra_sigma"].to_numpy(zero_copy_only=False)
141+
sigmas[:, 2] = detections_with_states["dec_sigma"].to_numpy(
142+
zero_copy_only=False
143+
)
144+
142145
return cls.from_kwargs(
143-
detections=PointSourceDetections.from_kwargs(
144-
id=detections_with_states["id"],
145-
exposure_id=detections_with_states["exposure_id"],
146+
id=detections_with_states["id"],
147+
exposure_id=detections_with_states["exposure_id"],
148+
coordinates=SphericalCoordinates.from_kwargs(
149+
lon=detections_with_states["ra"],
150+
lat=detections_with_states["dec"],
146151
time=Timestamp.from_kwargs(
147152
days=detections_with_states["days"],
148153
nanos=detections_with_states["nanos"],
149154
scale="utc",
150155
),
151-
ra=detections_with_states["ra"],
152-
ra_sigma=detections_with_states["ra_sigma"],
153-
dec=detections_with_states["dec"],
154-
dec_sigma=detections_with_states["dec_sigma"],
156+
covariance=CoordinateCovariances.from_sigmas(sigmas),
157+
origin=Origin.from_kwargs(
158+
code=detections_with_states["observatory_code"]
159+
),
160+
frame="equatorial",
161+
),
162+
photometry=Photometry.from_kwargs(
163+
filter=detections_with_states["filter"],
155164
mag=detections_with_states["mag"],
156165
mag_sigma=detections_with_states["mag_sigma"],
157166
),
158-
observatory_code=detections_with_states["observatory_code"],
159167
state_id=detections_with_states["state_id"],
160168
)
161169

@@ -174,7 +182,7 @@ def get_observers(self) -> ObserversWithStates:
174182
for code, observations_i in self.group_by_observatory_code():
175183
# Extract unique times and make sure they are sorted
176184
# by time in ascending order
177-
unique_times = observations_i.detections.time.unique()
185+
unique_times = observations_i.coordinates.time.unique()
178186
unique_times = unique_times.sort_by(["days", "nanos"])
179187

180188
# States are defined by unique times and observatory codes and
@@ -194,28 +202,6 @@ def get_observers(self) -> ObserversWithStates:
194202
observers = qv.concatenate(observers_with_states)
195203
return observers.sort_by("state_id")
196204

197-
def to_spherical_coordinates(self) -> SphericalCoordinates:
198-
"""
199-
Convert the observations to spherical coordinates which can be used
200-
to calculate residuals.
201-
202-
Returns
203-
-------
204-
coordinates : `~adam_core.coordinates.spherical.SphericalCoordinates`
205-
The detections represented as spherical coordinates.
206-
"""
207-
sigmas = np.zeros((len(self), 6))
208-
sigmas[:, 1] = self.detections.ra_sigma.to_numpy(zero_copy_only=False)
209-
sigmas[:, 2] = self.detections.dec_sigma.to_numpy(zero_copy_only=False)
210-
return SphericalCoordinates.from_kwargs(
211-
lon=self.detections.ra,
212-
lat=self.detections.dec,
213-
time=self.detections.time,
214-
covariance=CoordinateCovariances.from_sigmas(sigmas),
215-
origin=Origin.from_kwargs(code=self.observatory_code),
216-
frame="equatorial",
217-
)
218-
219205
def select_exposure(self, exposure_id: int) -> "Observations":
220206
"""
221207
Select observations from a single exposure.
@@ -230,7 +216,7 @@ def select_exposure(self, exposure_id: int) -> "Observations":
230216
observations : `~Observations`
231217
Observations from the specified exposure.
232218
"""
233-
return self.apply_mask(pc.equal(self.detections.exposure_id, exposure_id))
219+
return self.apply_mask(pc.equal(self.exposure_id, exposure_id))
234220

235221
def group_by_exposure(self) -> Iterator[Tuple[str, "Observations"]]:
236222
"""
@@ -242,7 +228,7 @@ def group_by_exposure(self) -> Iterator[Tuple[str, "Observations"]]:
242228
observations : Iterator[`~thor.observations.observations.Observations`]
243229
Observations belonging to individual exposures.
244230
"""
245-
exposure_ids = self.detections.exposure_id
231+
exposure_ids = self.exposure_id
246232
for exposure_id in exposure_ids.unique().sort():
247233
yield exposure_id.as_py(), self.select_exposure(exposure_id)
248234

@@ -260,7 +246,7 @@ def select_observatory_code(self, observatory_code) -> "Observations":
260246
observations : `~Observations`
261247
Observations from the specified observatory.
262248
"""
263-
return self.apply_mask(pc.equal(self.observatory_code, observatory_code))
249+
return self.apply_mask(pc.equal(self.coordinates.origin.code, observatory_code))
264250

265251
def group_by_observatory_code(self) -> Iterator[Tuple[str, "Observations"]]:
266252
"""
@@ -271,7 +257,7 @@ def group_by_observatory_code(self) -> Iterator[Tuple[str, "Observations"]]:
271257
observations : Iterator[`~thor.observations.observations.Observations`]
272258
Observations belonging to individual observatories.
273259
"""
274-
observatory_codes = self.observatory_code
260+
observatory_codes = self.coordinates.origin.code
275261
for observatory_code in observatory_codes.unique().sort():
276262
yield observatory_code.as_py(), self.select_observatory_code(
277263
observatory_code

thor/orbit.py

+7-38
Original file line numberDiff line numberDiff line change
@@ -76,48 +76,19 @@ def range_observations_worker(
7676
"""
7777
observations_state = observations.select("state_id", state_id)
7878
ephemeris_state = ephemeris.select("id", state_id)
79-
detections_state = observations_state.detections
8079

8180
# Get the heliocentric position vector of the object at the time of the exposure
8281
r = ephemeris_state.ephemeris.aberrated_coordinates.r[0]
8382

8483
# Get the observer's heliocentric coordinates
8584
observer_i = ephemeris_state.observer
8685

87-
# Create an array of observatory codes for the detections
88-
num_detections = len(observations_state)
89-
observatory_codes = np.repeat(
90-
observations_state.observatory_code[0].as_py(), num_detections
91-
)
92-
93-
# The following can be replaced with:
94-
# coords = observations_state.to_spherical(observatory_codes)
95-
# Start replacement:
96-
sigma_data = np.vstack(
97-
[
98-
pa.nulls(num_detections, pa.float64()),
99-
detections_state.ra_sigma.to_numpy(zero_copy_only=False),
100-
detections_state.dec_sigma.to_numpy(zero_copy_only=False),
101-
pa.nulls(num_detections, pa.float64()),
102-
pa.nulls(num_detections, pa.float64()),
103-
pa.nulls(num_detections, pa.float64()),
104-
]
105-
).T
106-
coords = SphericalCoordinates.from_kwargs(
107-
lon=detections_state.ra,
108-
lat=detections_state.dec,
109-
time=detections_state.time,
110-
covariance=CoordinateCovariances.from_sigmas(sigma_data),
111-
origin=Origin.from_kwargs(code=observatory_codes),
112-
frame="equatorial",
113-
)
114-
# End replacement (only once
115-
# https://github.com/B612-Asteroid-Institute/adam_core/pull/45 is merged)
116-
11786
return RangedPointSourceDetections.from_kwargs(
118-
id=detections_state.id,
119-
exposure_id=detections_state.exposure_id,
120-
coordinates=assume_heliocentric_distance(r, coords, observer_i.coordinates),
87+
id=observations_state.id,
88+
exposure_id=observations_state.exposure_id,
89+
coordinates=assume_heliocentric_distance(
90+
r, observations_state.coordinates, observer_i.coordinates
91+
),
12192
state_id=observations_state.state_id,
12293
)
12394

@@ -205,9 +176,7 @@ def _is_cache_fresh(self, observations: Observations) -> bool:
205176
if self._cached_ephemeris is None or self._cached_observation_ids is None:
206177
return False
207178
elif pc.all(
208-
pc.is_in(
209-
observations.detections.id.sort(), self._cached_observation_ids.sort()
210-
)
179+
pc.is_in(observations.id.sort(), self._cached_observation_ids.sort())
211180
).as_py():
212181
return True
213182
else:
@@ -231,7 +200,7 @@ def _cache_ephemeris(
231200
None
232201
"""
233202
self._cached_ephemeris = ephemeris
234-
self._cached_observation_ids = observations.detections.id
203+
self._cached_observation_ids = observations.id
235204

236205
def propagate(
237206
self,

0 commit comments

Comments
 (0)