Skip to content

Commit 3ec04ee

Browse files
committed
Add InputObservations
1 parent a6025b7 commit 3ec04ee

File tree

1 file changed

+119
-105
lines changed

1 file changed

+119
-105
lines changed

thor/observations/observations.py

+119-105
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,101 @@
1010
from adam_core.time import Timestamp
1111

1212
from .photometry import Photometry
13+
from .states import calculate_state_ids
1314

1415

1516
class ObserversWithStates(qv.Table):
1617
state_id = qv.Int64Column()
1718
observers = Observers.as_column()
1819

1920

20-
class Observations(qv.Table):
21-
"""
22-
The Observations table stored invidual point source detections and a state ID for each unique
23-
combination of detection time and observatory code. The state ID is used as reference to a specific
24-
observing geometry.
25-
26-
The recommended constructor to use is `~Observations.from_detections_and_exposures`, as this function
27-
will sort the detections by time and observatory code, and for each unique combination of the two
28-
assign a unique state ID. If not using this constructor, please ensure that the detections are sorted
29-
by time and observatory code and that each unique combination of time and observatory code has a unique
30-
state ID.
31-
"""
21+
class InputObservations(qv.Table):
22+
id = qv.StringColumn()
23+
exposure_id = qv.StringColumn()
24+
time = Timestamp.as_column()
25+
ra = qv.Float64Column()
26+
dec = qv.Float64Column()
27+
ra_sigma = qv.Float64Column(nullable=True)
28+
dec_sigma = qv.Float64Column(nullable=True)
29+
ra_dec_cov = qv.Float64Column(nullable=True)
30+
mag = qv.Float64Column()
31+
mag_sigma = qv.Float64Column(nullable=True)
32+
filter = qv.StringColumn()
33+
observatory_code = qv.StringColumn()
3234

35+
36+
class Observations(qv.Table):
3337
id = qv.StringColumn()
3438
exposure_id = qv.StringColumn()
3539
coordinates = SphericalCoordinates.as_column()
3640
photometry = Photometry.as_column()
3741
state_id = qv.Int64Column()
3842

43+
@classmethod
44+
def from_input_observations(cls, observations: InputObservations) -> "Observations":
45+
"""
46+
Create a THOR observations table from an InputObservations table. The InputObservations table
47+
are sorted by ascending time and observatory code.
48+
49+
Parameters
50+
----------
51+
observations : `~InputObservations`
52+
A table of input observations.
53+
54+
Returns
55+
-------
56+
observations : `~Observations`
57+
A table of THOR observations.
58+
"""
59+
# Sort the observations by time and observatory code
60+
observations_sorted = observations.sort_by(
61+
["time.days", "time.nanos", "observatory_code"]
62+
)
63+
64+
# If the times are not in UTC, convert them to UTC
65+
if observations_sorted.time.scale != "utc":
66+
observations_sorted = observations_sorted.set_column(
67+
"time", observations_sorted.time.rescale("utc")
68+
)
69+
70+
# Extract the sigma and covariance values for RA and Dec
71+
ra_sigma = observations_sorted.ra_sigma.to_numpy(zero_copy_only=False)
72+
dec_sigma = observations_sorted.dec_sigma.to_numpy(zero_copy_only=False)
73+
ra_dec_cov = observations_sorted.ra_dec_cov.to_numpy(zero_copy_only=False)
74+
75+
# Create the covariance matrices
76+
covariance_matrices = np.full((len(observations_sorted), 6, 6), np.nan)
77+
covariance_matrices[:, 1, 1] = ra_sigma**2
78+
covariance_matrices[:, 2, 2] = dec_sigma**2
79+
covariance_matrices[:, 1, 2] = ra_dec_cov
80+
covariance_matrices[:, 2, 1] = ra_dec_cov
81+
covariances = CoordinateCovariances.from_matrix(covariance_matrices)
82+
83+
# Create the coordinates table
84+
coords = SphericalCoordinates.from_kwargs(
85+
lon=observations_sorted.ra,
86+
lat=observations_sorted.dec,
87+
time=observations_sorted.time,
88+
covariance=covariances,
89+
origin=Origin.from_kwargs(code=observations_sorted.observatory_code),
90+
frame="equatorial",
91+
)
92+
93+
# Create the photometry table
94+
photometry = Photometry.from_kwargs(
95+
filter=observations_sorted.filter,
96+
mag=observations_sorted.mag,
97+
mag_sigma=observations_sorted.mag_sigma,
98+
)
99+
100+
return cls.from_kwargs(
101+
id=observations_sorted.id,
102+
exposure_id=observations_sorted.exposure_id,
103+
coordinates=coords,
104+
photometry=photometry,
105+
state_id=calculate_state_ids(coords),
106+
)
107+
39108
@classmethod
40109
def from_detections_and_exposures(
41110
cls, detections: PointSourceDetections, exposures: Exposures
@@ -60,113 +129,58 @@ def from_detections_and_exposures(
60129
# TODO: One thing we could try in the future is truncating observation times to ~1ms and using those to group
61130
# into indvidual states (i.e. if two observations are within 1ms of each other, they are in the same state). Again,
62131
# this only matters for those detections that have times that differ from the midpoint time of the exposure (LSST)
63-
64132
# If the detection times are not in UTC, convert them to UTC
65133
if detections.time.scale != "utc":
66134
detections = detections.set_column("time", detections.time.rescale("utc"))
67135

68-
# Flatten the detections table (i.e. remove the nested columns). Unfortunately joins on tables
69-
# with nested columns are not all supported in pyarrow
70-
detections_flattened = pa.table(
71-
[
72-
detections.id,
73-
detections.exposure_id,
74-
detections.time.days,
75-
detections.time.nanos,
76-
detections.ra,
77-
detections.ra_sigma,
78-
detections.dec,
79-
detections.dec_sigma,
80-
detections.mag,
81-
detections.mag_sigma,
82-
],
83-
names=[
84-
"id",
85-
"exposure_id",
86-
"days",
87-
"nanos",
88-
"ra",
89-
"ra_sigma",
90-
"dec",
91-
"dec_sigma",
92-
"mag",
93-
"mag_sigma",
94-
],
136+
# Join the detections and exposures tables
137+
detections_flattened = detections.flattened_table()
138+
exposures_flattened = exposures.flattened_table()
139+
detections_exposures = detections_flattened.join(
140+
exposures_flattened, ["exposure_id"], right_keys=["id"]
95141
)
96142

97-
# Extract the exposure IDs and the observatory codes from the exposures table
98-
exposure_filters_obscodes = pa.table(
99-
[exposures.id, exposures.filter, exposures.observatory_code],
100-
names=["exposure_id", "filter", "observatory_code"],
143+
# Create covariance matrices
144+
sigmas = np.zeros((len(detections_exposures), 6))
145+
sigmas[:, 1] = detections_exposures["ra_sigma"].to_numpy(zero_copy_only=False)
146+
sigmas[:, 2] = detections_exposures["dec_sigma"].to_numpy(zero_copy_only=False)
147+
covariances = CoordinateCovariances.from_sigmas(sigmas)
148+
149+
# Create the coordinates table
150+
coordinates = SphericalCoordinates.from_kwargs(
151+
lon=detections_exposures["ra"],
152+
lat=detections_exposures["dec"],
153+
time=Timestamp.from_kwargs(
154+
days=detections_exposures["time.days"],
155+
nanos=detections_exposures["time.nanos"],
156+
scale="utc",
157+
),
158+
covariance=covariances,
159+
origin=Origin.from_kwargs(code=detections_exposures["observatory_code"]),
160+
frame="equatorial",
101161
)
102162

103-
# Join the detection times and the exposure IDs so that each detection has an observatory code
104-
obscode_times = detections_flattened.join(
105-
exposure_filters_obscodes, ["exposure_id"]
163+
# Create the photometry table
164+
photometry = Photometry.from_kwargs(
165+
filter=detections_exposures["filter"],
166+
mag=detections_exposures["mag"],
167+
mag_sigma=detections_exposures["mag_sigma"],
106168
)
107169

108-
# Group the detections by the observatory code and the detection times and then grab the unique ones
109-
unique_obscode_times = obscode_times.group_by(
110-
["days", "nanos", "observatory_code"]
111-
).aggregate([])
112-
113-
# Now sort the unique detections by the observatory code and the detection time
114-
unique_obscode_times = unique_obscode_times.sort_by(
170+
return cls.from_kwargs(
171+
id=detections_exposures["id"],
172+
exposure_id=detections_exposures["exposure_id"],
173+
coordinates=coordinates,
174+
photometry=photometry,
175+
state_id=calculate_state_ids(coordinates),
176+
).sort_by(
115177
[
116-
("days", "ascending"),
117-
("nanos", "ascending"),
118-
("observatory_code", "ascending"),
178+
"coordinates.time.days",
179+
"coordinates.time.nanos",
180+
"coordinates.origin.code",
119181
]
120182
)
121183

122-
# For each unique detection time and observatory code assign a unique state ID
123-
unique_obscode_times = unique_obscode_times.add_column(
124-
0,
125-
pa.field("state_id", pa.int64()),
126-
pa.array(np.arange(0, len(unique_obscode_times))),
127-
)
128-
129-
# Join the unique observatory code and detections back to the original detections
130-
detections_with_states = obscode_times.join(
131-
unique_obscode_times, ["days", "nanos", "observatory_code"]
132-
)
133-
134-
# Now sort the detections one final time by state ID
135-
detections_with_states = detections_with_states.sort_by(
136-
[("state_id", "ascending")]
137-
)
138-
139-
sigmas = np.zeros((len(detections_with_states), 6))
140-
sigmas[:, 1] = detections_with_states["ra_sigma"].to_numpy(zero_copy_only=False)
141-
sigmas[:, 2] = detections_with_states["dec_sigma"].to_numpy(
142-
zero_copy_only=False
143-
)
144-
145-
return cls.from_kwargs(
146-
id=detections_with_states["id"],
147-
exposure_id=detections_with_states["exposure_id"],
148-
coordinates=SphericalCoordinates.from_kwargs(
149-
lon=detections_with_states["ra"],
150-
lat=detections_with_states["dec"],
151-
time=Timestamp.from_kwargs(
152-
days=detections_with_states["days"],
153-
nanos=detections_with_states["nanos"],
154-
scale="utc",
155-
),
156-
covariance=CoordinateCovariances.from_sigmas(sigmas),
157-
origin=Origin.from_kwargs(
158-
code=detections_with_states["observatory_code"]
159-
),
160-
frame="equatorial",
161-
),
162-
photometry=Photometry.from_kwargs(
163-
filter=detections_with_states["filter"],
164-
mag=detections_with_states["mag"],
165-
mag_sigma=detections_with_states["mag_sigma"],
166-
),
167-
state_id=detections_with_states["state_id"],
168-
)
169-
170184
def get_observers(self) -> ObserversWithStates:
171185
"""
172186
Get the observers table for these observations. The observers table

0 commit comments

Comments
 (0)