3
3
import pathlib
4
4
import time
5
5
from dataclasses import dataclass
6
- from typing import Any , Iterable , Iterator , List , Literal , Optional , Union
6
+ from typing import Any , Iterable , Iterator , List , Literal , Optional , Tuple
7
7
8
8
import quivr as qv
9
9
import ray
10
- from adam_core .coordinates import CartesianCoordinates
11
10
from adam_core .propagator import PYOORB
12
- from adam_core .time import Timestamp
13
11
14
12
from .clusters import ClusterMembers , Clusters , cluster_and_link
15
13
from .config import Config , initialize_config
26
24
27
25
logger = logging .getLogger ("thor" )
28
26
27
+ VALID_STAGES = Literal [
28
+ "filter_observations" ,
29
+ "range_and_transform" ,
30
+ "cluster_and_link" ,
31
+ "initial_orbit_determination" ,
32
+ "differential_correction" ,
33
+ "recover_orbits" ,
34
+ "complete" ,
35
+ ]
36
+
29
37
30
38
def initialize_use_ray (config : Config ) -> bool :
31
39
use_ray = False
@@ -43,25 +51,17 @@ def initialize_use_ray(config: Config) -> bool:
43
51
44
52
@dataclass
45
53
class CheckpointData :
46
- stage : Literal [
47
- "filter_observations" ,
48
- "range_and_transform" ,
49
- "cluster_and_link" ,
50
- "initial_orbit_determination" ,
51
- "differential_correction" ,
52
- "recover_orbits" ,
53
- "complete" ,
54
- ]
55
- filtered_observations : Optional [Observations ] = None
56
- transformed_detections : Optional [TransformedDetections ] = None
57
- clusters : Optional [Clusters ] = None
58
- cluster_members : Optional [ClusterMembers ] = None
59
- iod_orbits : Optional [FittedOrbits ] = None
60
- iod_orbit_members : Optional [FittedOrbitMembers ] = None
61
- od_orbits : Optional [FittedOrbits ] = None
62
- od_orbit_members : Optional [FittedOrbitMembers ] = None
63
- recovered_orbits : Optional [FittedOrbits ] = None
64
- recovered_orbit_members : Optional [FittedOrbitMembers ] = None
54
+ stage : VALID_STAGES
55
+ filtered_observations : Observations = Observations .empty ()
56
+ transformed_detections : TransformedDetections = TransformedDetections .empty ()
57
+ clusters : Clusters = Clusters .empty ()
58
+ cluster_members : ClusterMembers = ClusterMembers .empty ()
59
+ iod_orbits : FittedOrbits = FittedOrbits .empty ()
60
+ iod_orbit_members : FittedOrbitMembers = FittedOrbitMembers .empty ()
61
+ od_orbits : FittedOrbits = FittedOrbits .empty ()
62
+ od_orbit_members : FittedOrbitMembers = FittedOrbitMembers .empty ()
63
+ recovered_orbits : FittedOrbits = FittedOrbits .empty ()
64
+ recovered_orbit_members : FittedOrbitMembers = FittedOrbitMembers .empty ()
65
65
66
66
67
67
def initialize_test_orbit (
@@ -86,7 +86,7 @@ def load_initial_checkpoint_values(
86
86
87
87
We want to avoid loading objects into memory that are not required.
88
88
"""
89
- stage = "filter_observations"
89
+ stage : VALID_STAGES = "filter_observations"
90
90
# Without a checkpoint directory, we always start at the beginning
91
91
if test_orbit_directory is None :
92
92
return CheckpointData (stage = stage )
@@ -245,8 +245,8 @@ class LinkTestOrbitStageResult:
245
245
"differential_correction" ,
246
246
"recover_orbits" ,
247
247
]
248
- result : Iterable [Any ]
249
- path : Optional [Iterable [ str ]] = None
248
+ result : Iterable [qv . AnyTable ]
249
+ path : Tuple [ Optional [str ], ... ] = ( None ,)
250
250
251
251
252
252
def link_test_orbit (
@@ -284,11 +284,11 @@ def link_test_orbit(
284
284
285
285
test_orbit_directory = None
286
286
if working_dir is not None :
287
- working_dir = pathlib .Path (working_dir )
287
+ working_dir_path = pathlib .Path (working_dir )
288
288
logger .info (f"Using working directory: { working_dir } " )
289
- test_orbit_directory = pathlib .Path (working_dir , test_orbit .orbit_id )
289
+ test_orbit_directory = pathlib .Path (working_dir_path , test_orbit .orbit_id )
290
290
test_orbit_directory .mkdir (parents = True , exist_ok = True )
291
- inputs_dir = pathlib .Path (working_dir , "inputs" )
291
+ inputs_dir = pathlib .Path (working_dir_path , "inputs" )
292
292
inputs_dir .mkdir (parents = True , exist_ok = True )
293
293
294
294
initialize_test_orbit (test_orbit , working_dir )
@@ -317,7 +317,7 @@ def link_test_orbit(
317
317
318
318
if checkpoint .stage == "complete" :
319
319
logger .info ("Found recovered orbits in checkpoint, exiting early..." )
320
- path = None
320
+ path : Tuple [ Optional [ str ], ...] = ( None ,)
321
321
if test_orbit_directory :
322
322
path = (
323
323
os .path .join (test_orbit_directory , "recovered_orbits.parquet" ),
0 commit comments