Skip to content

Commit 7ab1230

Browse files
committed
Fix checkpoint typing
1 parent 840bcf0 commit 7ab1230

File tree

5 files changed

+316
-213
lines changed

5 files changed

+316
-213
lines changed

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ install_requires =
4141
numba
4242
pandas
4343
pyarrow >= 14.0.0
44-
pydantic
44+
pydantic < 2.0.0
4545
pyyaml >= 5.1
4646
quivr @ git+https://github.com/moeyensj/quivr@concatenate-empty-attributes
4747
ray[default]

thor/checkpointing.py

+273
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
import logging
2+
import pathlib
3+
from typing import Annotated, Dict, Literal, Optional, Type, Union
4+
5+
import pydantic
6+
import quivr as qv
7+
import ray
8+
9+
from thor.clusters import ClusterMembers, Clusters
10+
from thor.observations.observations import Observations
11+
from thor.orbit_determination.fitted_orbits import FittedOrbitMembers, FittedOrbits
12+
from thor.range_and_transform import TransformedDetections
13+
14+
logger = logging.getLogger("thor")
15+
16+
17+
# Names of the pipeline stages a checkpoint can resume from, listed in the
# order the stages appear in this module. "complete" marks a finished run.
VALID_STAGES = Literal[
    "filter_observations",
    "range_and_transform",
    "cluster_and_link",
    "initial_orbit_determination",
    "differential_correction",
    "recover_orbits",
    "complete",
]
26+
27+
28+
class FilterObservations(pydantic.BaseModel):
    # Checkpoint for a run that starts at the first stage: it carries no
    # payload, only the discriminating stage tag.
    stage: Literal["filter_observations"]
30+
31+
32+
class RangeAndTransform(pydantic.BaseModel):
    # Checkpoint for resuming at the range_and_transform stage.

    # Discriminating stage tag.
    stage: Literal["range_and_transform"]
    # Either the observations table itself or a ray object reference.
    filtered_observations: Union[Observations, ray.ObjectRef]

    class Config:
        # The field types above are not pydantic models, so pydantic has to
        # be told to accept them as-is.
        arbitrary_types_allowed = True
38+
39+
40+
class ClusterAndLink(pydantic.BaseModel):
    # Checkpoint for resuming at the cluster_and_link stage.

    # Discriminating stage tag.
    stage: Literal["cluster_and_link"]
    # Either the observations table itself or a ray object reference.
    filtered_observations: Union[Observations, ray.ObjectRef]
    # Transformed detections table carried into this stage.
    transformed_detections: TransformedDetections

    class Config:
        # The field types above are not pydantic models, so pydantic has to
        # be told to accept them as-is.
        arbitrary_types_allowed = True
47+
48+
49+
class InitialOrbitDetermination(pydantic.BaseModel):
    # Checkpoint for resuming at the initial_orbit_determination stage.

    # Discriminating stage tag.
    stage: Literal["initial_orbit_determination"]
    # Observations table carried into this stage.
    filtered_observations: Observations
    # Clusters table and its companion membership table.
    clusters: Clusters
    cluster_members: ClusterMembers

    class Config:
        # The field types above are not pydantic models, so pydantic has to
        # be told to accept them as-is.
        arbitrary_types_allowed = True
57+
58+
59+
class DifferentialCorrection(pydantic.BaseModel):
    # Checkpoint for resuming at the differential_correction stage.

    # Discriminating stage tag.
    stage: Literal["differential_correction"]
    # Observations table carried into this stage.
    filtered_observations: Observations
    # IOD orbits table and its companion membership table.
    iod_orbits: FittedOrbits
    iod_orbit_members: FittedOrbitMembers

    class Config:
        # The field types above are not pydantic models, so pydantic has to
        # be told to accept them as-is.
        arbitrary_types_allowed = True
67+
68+
69+
class RecoverOrbits(pydantic.BaseModel):
    # Checkpoint for resuming at the recover_orbits stage.

    # Discriminating stage tag.
    stage: Literal["recover_orbits"]
    # Observations table carried into this stage.
    filtered_observations: Observations
    # OD orbits table and its companion membership table.
    od_orbits: FittedOrbits
    od_orbit_members: FittedOrbitMembers

    class Config:
        # The field types above are not pydantic models, so pydantic has to
        # be told to accept them as-is.
        arbitrary_types_allowed = True
77+
78+
79+
class Complete(pydantic.BaseModel):
    # Checkpoint for a run that has already finished. Unlike the other
    # payload-carrying stages it holds no filtered observations.

    # Discriminating stage tag.
    stage: Literal["complete"]
    # Recovered orbits table and its companion membership table.
    recovered_orbits: FittedOrbits
    recovered_orbit_members: FittedOrbitMembers

    class Config:
        # The field types above are not pydantic models, so pydantic has to
        # be told to accept them as-is.
        arbitrary_types_allowed = True
86+
87+
88+
# Tagged union over every per-stage checkpoint model. The ``stage`` field is
# declared as the discriminator, so its value identifies which member of the
# union a given checkpoint is.
CheckpointData = Annotated[
    Union[
        FilterObservations,
        RangeAndTransform,
        ClusterAndLink,
        InitialOrbitDetermination,
        DifferentialCorrection,
        RecoverOrbits,
        Complete,
    ],
    pydantic.Field(discriminator="stage"),
]
100+
101+
# Lookup table used to pick the checkpoint model class for a stage name.
# Each key equals the ``stage`` literal declared on the class it maps to.
stage_to_model: Dict[str, Type[pydantic.BaseModel]] = {
    "filter_observations": FilterObservations,
    "range_and_transform": RangeAndTransform,
    "cluster_and_link": ClusterAndLink,
    "initial_orbit_determination": InitialOrbitDetermination,
    "differential_correction": DifferentialCorrection,
    "recover_orbits": RecoverOrbits,
    "complete": Complete,
}
111+
112+
113+
def create_checkpoint_data(stage: VALID_STAGES, **data) -> CheckpointData:
    """
    Build the checkpoint model that corresponds to ``stage``.

    Parameters
    ----------
    stage
        Name of the stage; it is looked up in ``stage_to_model``.
    **data
        Keyword arguments forwarded unchanged to the model constructor.

    Returns
    -------
    An instance of the model class registered for ``stage``.

    Raises
    ------
    ValueError
        If ``stage`` has no entry in ``stage_to_model``.
    """
    model_cls = stage_to_model.get(stage)
    # Guard clause: reject unknown stages up front.
    if not model_cls:
        raise ValueError(f"Invalid stage: {stage}")
    return model_cls(stage=stage, **data)
121+
122+
123+
def load_initial_checkpoint_values(
    test_orbit_directory: Optional[pathlib.Path] = None,
) -> CheckpointData:
    """
    Check for completed stages and return values from disk if they exist.

    We want to avoid loading objects into memory that are not required, so
    the checkpoint files are probed from the latest stage backwards and only
    the tables needed by the stage being resumed are read. In particular,
    a finished run (both recovered-orbit files present) returns without
    reading the filtered observations at all.

    Parameters
    ----------
    test_orbit_directory
        Directory holding the checkpoint parquet files. When ``None`` there
        is nothing to resume from.

    Returns
    -------
    The checkpoint data for the stage the pipeline should start at:
    ``filter_observations`` when no directory is given or
    ``filtered_observations.parquet`` is missing, otherwise the latest stage
    whose input files all exist.
    """
    stage: VALID_STAGES = "filter_observations"
    # Without a checkpoint directory, we always start at the beginning
    if test_orbit_directory is None:
        return create_checkpoint_data(stage)

    def checkpoint_file(name: str) -> pathlib.Path:
        # pathlib.Path(...) rather than "/" so a plain string directory
        # keeps working exactly as before.
        return pathlib.Path(test_orbit_directory, name)

    # Every later stage requires filtered_observations to have been written.
    # If it doesn't exist, start at the beginning.
    filtered_observations_path = checkpoint_file("filtered_observations.parquet")
    if not filtered_observations_path.exists():
        return create_checkpoint_data(stage)
    logger.info("Found filtered observations")

    # If the pipeline was started but we have recovered_orbits already, we
    # are done and should exit early. The "complete" checkpoint carries no
    # observations, so this check runs before they are loaded.
    recovered_orbits_path = checkpoint_file("recovered_orbits.parquet")
    recovered_orbit_members_path = checkpoint_file("recovered_orbit_members.parquet")
    if recovered_orbits_path.exists() and recovered_orbit_members_path.exists():
        logger.info("Found recovered orbits in checkpoint")
        recovered_orbits, recovered_orbit_members = _load_fitted_orbits(
            recovered_orbits_path, recovered_orbit_members_path
        )
        return create_checkpoint_data(
            "complete",
            recovered_orbits=recovered_orbits,
            recovered_orbit_members=recovered_orbit_members,
        )

    # Every remaining stage needs the filtered observations.
    filtered_observations = _load_filtered_observations(filtered_observations_path)

    # Now with filtered_observations available, we can check for the later
    # stages in reverse order.
    od_orbits_path = checkpoint_file("od_orbits.parquet")
    od_orbit_members_path = checkpoint_file("od_orbit_members.parquet")
    if od_orbits_path.exists() and od_orbit_members_path.exists():
        logger.info("Found OD orbits in checkpoint")
        od_orbits, od_orbit_members = _load_fitted_orbits(
            od_orbits_path, od_orbit_members_path
        )
        return create_checkpoint_data(
            "recover_orbits",
            filtered_observations=filtered_observations,
            od_orbits=od_orbits,
            od_orbit_members=od_orbit_members,
        )

    iod_orbits_path = checkpoint_file("iod_orbits.parquet")
    iod_orbit_members_path = checkpoint_file("iod_orbit_members.parquet")
    if iod_orbits_path.exists() and iod_orbit_members_path.exists():
        logger.info("Found IOD orbits")
        iod_orbits, iod_orbit_members = _load_fitted_orbits(
            iod_orbits_path, iod_orbit_members_path
        )
        return create_checkpoint_data(
            "differential_correction",
            filtered_observations=filtered_observations,
            iod_orbits=iod_orbits,
            iod_orbit_members=iod_orbit_members,
        )

    clusters_path = checkpoint_file("clusters.parquet")
    cluster_members_path = checkpoint_file("cluster_members.parquet")
    if clusters_path.exists() and cluster_members_path.exists():
        logger.info("Found clusters")
        clusters = Clusters.from_parquet(clusters_path)
        cluster_members = ClusterMembers.from_parquet(cluster_members_path)
        return create_checkpoint_data(
            "initial_orbit_determination",
            filtered_observations=filtered_observations,
            clusters=clusters,
            cluster_members=cluster_members,
        )

    transformed_detections_path = checkpoint_file("transformed_detections.parquet")
    if transformed_detections_path.exists():
        logger.info("Found transformed detections")
        transformed_detections = TransformedDetections.from_parquet(
            transformed_detections_path
        )
        return create_checkpoint_data(
            "cluster_and_link",
            filtered_observations=filtered_observations,
            transformed_detections=transformed_detections,
        )

    # Only the filtered observations exist: resume right after filtering.
    return create_checkpoint_data(
        "range_and_transform", filtered_observations=filtered_observations
    )


def _load_filtered_observations(path: pathlib.Path) -> Observations:
    """Read the observations table, defragment it and sort by time and origin code."""
    observations = Observations.from_parquet(path)
    observations = qv.defragment(observations)
    return observations.sort_by(
        [
            "coordinates.time.days",
            "coordinates.time.nanos",
            "coordinates.origin.code",
        ]
    )


def _load_fitted_orbits(
    orbits_path: pathlib.Path,
    orbit_members_path: pathlib.Path,
) -> tuple[FittedOrbits, FittedOrbitMembers]:
    """Read an orbits/members pair; the orbits are defragmented and sorted by time."""
    orbits = FittedOrbits.from_parquet(orbits_path)
    orbit_members = FittedOrbitMembers.from_parquet(orbit_members_path)
    # Only the orbits table is defragmented and sorted; the members table is
    # returned exactly as read.
    orbits = qv.defragment(orbits)
    orbits = orbits.sort_by(
        [
            "coordinates.time.days",
            "coordinates.time.nanos",
        ]
    )
    return orbits, orbit_members

0 commit comments

Comments
 (0)