Skip to content

Commit 0f7e98e

Browse files
committed
Clarify checkpoint data types in link_test_orbit
1 parent 92c0cba commit 0f7e98e

File tree

4 files changed

+37
-45
lines changed

4 files changed

+37
-45
lines changed

thor/main.py

+29-29
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33
import pathlib
44
import time
55
from dataclasses import dataclass
6-
from typing import Any, Iterable, Iterator, List, Literal, Optional, Union
6+
from typing import Any, Iterable, Iterator, List, Literal, Optional, Tuple
77

88
import quivr as qv
99
import ray
10-
from adam_core.coordinates import CartesianCoordinates
1110
from adam_core.propagator import PYOORB
12-
from adam_core.time import Timestamp
1311

1412
from .clusters import ClusterMembers, Clusters, cluster_and_link
1513
from .config import Config, initialize_config
@@ -26,6 +24,16 @@
2624

2725
logger = logging.getLogger("thor")
2826

27+
VALID_STAGES = Literal[
28+
"filter_observations",
29+
"range_and_transform",
30+
"cluster_and_link",
31+
"initial_orbit_determination",
32+
"differential_correction",
33+
"recover_orbits",
34+
"complete",
35+
]
36+
2937

3038
def initialize_use_ray(config: Config) -> bool:
3139
use_ray = False
@@ -43,25 +51,17 @@ def initialize_use_ray(config: Config) -> bool:
4351

4452
@dataclass
4553
class CheckpointData:
46-
stage: Literal[
47-
"filter_observations",
48-
"range_and_transform",
49-
"cluster_and_link",
50-
"initial_orbit_determination",
51-
"differential_correction",
52-
"recover_orbits",
53-
"complete",
54-
]
55-
filtered_observations: Optional[Observations] = None
56-
transformed_detections: Optional[TransformedDetections] = None
57-
clusters: Optional[Clusters] = None
58-
cluster_members: Optional[ClusterMembers] = None
59-
iod_orbits: Optional[FittedOrbits] = None
60-
iod_orbit_members: Optional[FittedOrbitMembers] = None
61-
od_orbits: Optional[FittedOrbits] = None
62-
od_orbit_members: Optional[FittedOrbitMembers] = None
63-
recovered_orbits: Optional[FittedOrbits] = None
64-
recovered_orbit_members: Optional[FittedOrbitMembers] = None
54+
stage: VALID_STAGES
55+
filtered_observations: Observations = Observations.empty()
56+
transformed_detections: TransformedDetections = TransformedDetections.empty()
57+
clusters: Clusters = Clusters.empty()
58+
cluster_members: ClusterMembers = ClusterMembers.empty()
59+
iod_orbits: FittedOrbits = FittedOrbits.empty()
60+
iod_orbit_members: FittedOrbitMembers = FittedOrbitMembers.empty()
61+
od_orbits: FittedOrbits = FittedOrbits.empty()
62+
od_orbit_members: FittedOrbitMembers = FittedOrbitMembers.empty()
63+
recovered_orbits: FittedOrbits = FittedOrbits.empty()
64+
recovered_orbit_members: FittedOrbitMembers = FittedOrbitMembers.empty()
6565

6666

6767
def initialize_test_orbit(
@@ -86,7 +86,7 @@ def load_initial_checkpoint_values(
8686
8787
We want to avoid loading objects into memory that are not required.
8888
"""
89-
stage = "filter_observations"
89+
stage: VALID_STAGES = "filter_observations"
9090
# Without a checkpoint directory, we always start at the beginning
9191
if test_orbit_directory is None:
9292
return CheckpointData(stage=stage)
@@ -245,8 +245,8 @@ class LinkTestOrbitStageResult:
245245
"differential_correction",
246246
"recover_orbits",
247247
]
248-
result: Iterable[Any]
249-
path: Optional[Iterable[str]] = None
248+
result: Iterable[qv.AnyTable]
249+
path: Tuple[Optional[str], ...] = (None,)
250250

251251

252252
def link_test_orbit(
@@ -284,11 +284,11 @@ def link_test_orbit(
284284

285285
test_orbit_directory = None
286286
if working_dir is not None:
287-
working_dir = pathlib.Path(working_dir)
287+
working_dir_path = pathlib.Path(working_dir)
288288
logger.info(f"Using working directory: {working_dir}")
289-
test_orbit_directory = pathlib.Path(working_dir, test_orbit.orbit_id)
289+
test_orbit_directory = pathlib.Path(working_dir_path, test_orbit.orbit_id)
290290
test_orbit_directory.mkdir(parents=True, exist_ok=True)
291-
inputs_dir = pathlib.Path(working_dir, "inputs")
291+
inputs_dir = pathlib.Path(working_dir_path, "inputs")
292292
inputs_dir.mkdir(parents=True, exist_ok=True)
293293

294294
initialize_test_orbit(test_orbit, working_dir)
@@ -317,7 +317,7 @@ def link_test_orbit(
317317

318318
if checkpoint.stage == "complete":
319319
logger.info("Found recovered orbits in checkpoint, exiting early...")
320-
path = None
320+
path: Tuple[Optional[str], ...] = (None,)
321321
if test_orbit_directory:
322322
path = (
323323
os.path.join(test_orbit_directory, "recovered_orbits.parquet"),

thor/observations/filters.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,11 @@ def filter_observations(
304304
if len(filtered_observations) > 0:
305305
filtered_observations = qv.defragment(filtered_observations)
306306
filtered_observations = filtered_observations.sort_by(
307-
["coordinates.time.days", "coordinates.time.nanos", "coordinates.origin.code"]
307+
[
308+
"coordinates.time.days",
309+
"coordinates.time.nanos",
310+
"coordinates.origin.code",
311+
]
308312
)
309313

310314
return filtered_observations

thor/observations/states.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ def calculate_state_ids(
2828

2929
# Append index column so we can maintain the original order
3030
table = table.append_column(
31-
pa.field("index", pa.int64()),
32-
pa.array(np.arange(0, len(table)))
31+
pa.field("index", pa.int64()), pa.array(np.arange(0, len(table)))
3332
)
3433

3534
# Select only the relevant columns
@@ -63,6 +62,5 @@ def calculate_state_ids(
6362
unique_time_origins, ["time.days", "time.nanos", "origin.code"]
6463
).sort_by([("index", "ascending")])
6564

66-
6765
# Now return the state IDs
6866
return coordinates_with_states.column("state_id").combine_chunks()

thor/tests/test_main.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,7 @@ def test_link_test_orbit(
238238
else:
239239
integration_config.max_processes = 1
240240

241-
(
242-
test_orbit,
243-
observations,
244-
obs_ids_expected,
245-
integration_config,
246-
) = setup_test_data(
241+
(test_orbit, observations, obs_ids_expected, integration_config,) = setup_test_data(
247242
object_id, orbits, observations, integration_config, max_arc_length=14
248243
)
249244

@@ -270,12 +265,7 @@ def test_benchmark_link_test_orbit(
270265
else:
271266
integration_config.max_processes = 1
272267

273-
(
274-
test_orbit,
275-
observations,
276-
obs_ids_expected,
277-
integration_config,
278-
) = setup_test_data(
268+
(test_orbit, observations, obs_ids_expected, integration_config,) = setup_test_data(
279269
object_id, orbits, observations, integration_config, max_arc_length=14
280270
)
281271

0 commit comments

Comments
 (0)