Skip to content

Commit 2310489

Browse files
authored
Use adam_core's new initialize_use_ray utility (#130)
1 parent f623dd8 commit 2310489

13 files changed

+53
-110
lines changed

.gitignore

+1
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,7 @@ venv/
132132
ENV/
133133
env.bak/
134134
venv.bak/
135+
.python-version
135136

136137
# Spyder project settings
137138
.spyderproject

.pre-commit-config.yaml

+17-19
Original file line number | Diff line number | Diff line change
@@ -1,32 +1,30 @@
11
# See https://pre-commit.com for more information
22
# See https://pre-commit.com/hooks.html for more hooks
3-
default_language_version:
4-
python: python3.10
53
repos:
6-
- repo: https://github.com/pre-commit/pre-commit-hooks
4+
- repo: https://github.com/pre-commit/pre-commit-hooks
75
rev: v3.2.0
86
hooks:
9-
- id: trailing-whitespace
10-
args: [--markdown-linebreak-ext=md]
11-
- id: end-of-file-fixer
12-
- id: check-yaml
13-
- id: check-added-large-files
14-
- repo: https://github.com/PyCQA/isort
7+
- id: trailing-whitespace
8+
args: [--markdown-linebreak-ext=md]
9+
- id: end-of-file-fixer
10+
- id: check-yaml
11+
- id: check-added-large-files
12+
- repo: https://github.com/PyCQA/isort
1513
rev: 5.12.0
1614
hooks:
17-
- id: isort
18-
additional_dependencies:
19-
- toml
20-
- repo: https://github.com/psf/black
15+
- id: isort
16+
additional_dependencies:
17+
- toml
18+
- repo: https://github.com/psf/black
2119
rev: 22.10.0
2220
hooks:
23-
- id: black
24-
- repo: https://github.com/pre-commit/mirrors-mypy
21+
- id: black
22+
- repo: https://github.com/pre-commit/mirrors-mypy
2523
rev: v1.1.1
2624
hooks:
27-
- id: mypy
25+
- id: mypy
2826
exclude: bench/
2927
additional_dependencies:
30-
- 'types-pyyaml'
31-
- 'types-requests'
32-
- 'types-python-dateutil'
28+
- "types-pyyaml"
29+
- "types-requests"
30+
- "types-python-dateutil"

setup.cfg

+2-1
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ setup_requires =
3131
wheel
3232
setuptools_scm >= 6.0
3333
install_requires =
34-
adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@main
34+
adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@main#egg=adam_core
3535
astropy >= 5.3.1
3636
astroquery
3737
difi
@@ -40,6 +40,7 @@ install_requires =
4040
numpy
4141
numba
4242
pandas
43+
psutil
4344
pyarrow >= 14.0.0
4445
pydantic < 2.0.0
4546
pyyaml >= 5.1

thor/clusters.py

+3-14
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import quivr as qv
1212
import ray
1313
from adam_core.propagator import _iterate_chunks
14+
from adam_core.ray_cluster import initialize_use_ray
1415

1516
from .range_and_transform import TransformedDetections
1617

@@ -507,7 +508,6 @@ def cluster_velocity(
507508
if len(clusters) == 0:
508509
return Clusters.empty(), ClusterMembers.empty()
509510
else:
510-
511511
cluster_ids = []
512512
cluster_num_obs = []
513513
cluster_members_cluster_ids = []
@@ -633,9 +633,6 @@ def cluster_and_link(
633633
Algorithm to use. Can be "dbscan" or "hotspot_2d".
634634
num_jobs : int, optional
635635
Number of jobs to launch.
636-
parallel_backend : str, optional
637-
Which parallelization backend to use {'ray', 'mp', 'cf'}.
638-
Defaults to using Python's concurrent futures module ('cf').
639636
640637
Returns
641638
-------
@@ -691,14 +688,8 @@ def cluster_and_link(
691688
mjd0 = mjd[first][0]
692689
dt = mjd - mjd0
693690

694-
if max_processes is None or max_processes > 1:
695-
696-
if not ray.is_initialized():
697-
logger.info(
698-
f"Ray is not initialized. Initializing with {max_processes}..."
699-
)
700-
ray.init(address="auto", num_cpus=max_processes)
701-
691+
use_ray = initialize_use_ray(num_cpus=max_processes)
692+
if use_ray:
702693
# Put all arrays (which can be large) in ray's
703694
# local object store ahead of time
704695
obs_ids_ref = ray.put(obs_ids)
@@ -742,7 +733,6 @@ def cluster_and_link(
742733
)
743734

744735
else:
745-
746736
for vxi_chunk, vyi_chunk in zip(
747737
_iterate_chunks(vxx, chunk_size), _iterate_chunks(vyy, chunk_size)
748738
):
@@ -776,7 +766,6 @@ def cluster_and_link(
776766
)
777767

778768
else:
779-
780769
clusters = Clusters.empty()
781770
cluster_members = ClusterMembers.empty()
782771

thor/config.py

+1
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
class Config(BaseModel):
1111
max_processes: Optional[int] = None
12+
ray_memory_bytes: int = 0
1213
propagator: Literal["PYOORB"] = "PYOORB"
1314
cell_radius: float = 10
1415
vx_min: float = -0.1

thor/main.py

+4-15
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import quivr as qv
99
import ray
1010
from adam_core.propagator import PYOORB
11+
from adam_core.ray_cluster import initialize_use_ray
1112

1213
from .checkpointing import create_checkpoint_data, load_initial_checkpoint_values
1314
from .clusters import cluster_and_link
@@ -25,20 +26,6 @@
2526
logger = logging.getLogger("thor")
2627

2728

28-
def initialize_use_ray(config: Config) -> bool:
29-
use_ray = False
30-
if config.max_processes is None or config.max_processes > 1:
31-
# Initialize ray
32-
if not ray.is_initialized():
33-
logger.info(
34-
f"Ray is not initialized. Initializing with {config.max_processes} cpus..."
35-
)
36-
ray.init(num_cpus=config.max_processes)
37-
38-
use_ray = True
39-
return use_ray
40-
41-
4229
def initialize_test_orbit(
4330
test_orbit: TestOrbits,
4431
working_dir: Optional[str] = None,
@@ -132,7 +119,9 @@ def link_test_orbit(
132119
else:
133120
raise ValueError(f"Unknown propagator: {config.propagator}")
134121

135-
use_ray = initialize_use_ray(config)
122+
use_ray = initialize_use_ray(
123+
num_cpus=config.max_processes, object_store_bytes=config.ray_memory_bytes
124+
)
136125

137126
refs_to_free = []
138127
if (

thor/observations/filters.py

+3-7
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import quivr as qv
88
import ray
99
from adam_core.coordinates import SphericalCoordinates
10+
from adam_core.ray_cluster import initialize_use_ray
1011

1112
from thor.config import Config
1213
from thor.observations.observations import Observations
@@ -160,13 +161,8 @@ def apply(
160161
ephemeris = test_orbit.generate_ephemeris_from_observations(observations)
161162

162163
filtered_observations_list = []
163-
if max_processes is None or max_processes > 1:
164-
if not ray.is_initialized():
165-
logger.info(
166-
f"Ray is not initialized. Initializing with {max_processes}..."
167-
)
168-
ray.init(num_cpus=max_processes)
169-
164+
use_ray = initialize_use_ray(num_cpus=max_processes)
165+
if use_ray:
170166
refs_to_free = []
171167
if observations_ref is None:
172168
observations_ref = ray.put(observations)

thor/orbit.py

+10-5
Original file line number | Diff line number | Diff line change
@@ -38,15 +38,13 @@
3838

3939

4040
class RangedPointSourceDetections(qv.Table):
41-
4241
id = qv.StringColumn()
4342
exposure_id = qv.StringColumn()
4443
coordinates = SphericalCoordinates.as_column()
4544
state_id = qv.Int64Column()
4645

4746

4847
class TestOrbitEphemeris(qv.Table):
49-
5048
id = qv.Int64Column()
5149
ephemeris = Ephemeris.as_column()
5250
observer = Observers.as_column()
@@ -97,7 +95,6 @@ def range_observations_worker(
9795

9896

9997
class TestOrbits(qv.Table):
100-
10198
orbit_id = qv.StringColumn(default=lambda: uuid.uuid4().hex)
10299
object_id = qv.StringColumn(nullable=True)
103100
bundle_id = qv.Int64Column(nullable=True)
@@ -199,7 +196,11 @@ def propagate(
199196
The test orbit propagated to the given times.
200197
"""
201198
return propagator.propagate_orbits(
202-
self.to_orbits(), times, max_processes=max_processes, chunk_size=1
199+
self.to_orbits(),
200+
times,
201+
max_processes=max_processes,
202+
chunk_size=1,
203+
parallel_backend="ray",
203204
)
204205

205206
def generate_ephemeris(
@@ -226,7 +227,11 @@ def generate_ephemeris(
226227
The ephemeris of the test orbit at the given observers.
227228
"""
228229
return propagator.generate_ephemeris(
229-
self.to_orbits(), observers, max_processes=max_processes, chunk_size=1
230+
self.to_orbits(),
231+
observers,
232+
max_processes=max_processes,
233+
chunk_size=1,
234+
parallel_backend="ray",
230235
)
231236

232237
def generate_ephemeris_from_observations(

thor/orbit_selection.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -232,7 +232,7 @@ def generate_test_orbits(
232232
start_time,
233233
max_processes=max_processes,
234234
parallel_backend="ray",
235-
chunk_size=1000,
235+
chunk_size=500,
236236
)
237237
propagation_end_time = time.perf_counter()
238238
logger.info(

thor/orbits/attribution.py

+4-15
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from adam_core.orbits import Orbits
1313
from adam_core.propagator import PYOORB
1414
from adam_core.propagator.utils import _iterate_chunks
15+
from adam_core.ray_cluster import initialize_use_ray
1516
from sklearn.neighbors import BallTree
1617

1718
from ..observations.observations import Observations
@@ -165,7 +166,6 @@ def attribution_worker(
165166
radius_rad = np.radians(radius)
166167
residuals = []
167168
for _, ephemeris_i, observations_i in linkage.iterate():
168-
169169
# Extract the observation IDs and times
170170
obs_ids = observations_i.id.to_numpy(zero_copy_only=False)
171171
obs_times = observations_i.coordinates.time.mjd().to_numpy(zero_copy_only=False)
@@ -279,12 +279,8 @@ def attribute_observations(
279279
observation_indices = np.arange(0, len(observations))
280280

281281
attributions_list = []
282-
if max_processes is None or max_processes > 1:
283-
284-
if not ray.is_initialized():
285-
logger.info(f"Ray is not initialized. Initializing with {max_processes}...")
286-
ray.init(address="auto", max_processes=max_processes)
287-
282+
use_ray = initialize_use_ray(num_cpus=max_processes)
283+
if use_ray:
288284
refs_to_free = []
289285
if orbits_ref is None:
290286
orbits_ref = ray.put(orbits)
@@ -421,14 +417,8 @@ def merge_and_extend_orbits(
421417
odp_orbits_list = []
422418
odp_orbit_members_list = []
423419
if len(orbits_iter) > 0 and len(observations_iter) > 0:
424-
420+
use_ray = initialize_use_ray(num_cpus=max_processes)
425421
if use_ray:
426-
if not ray.is_initialized():
427-
logger.info(
428-
f"Ray is not initialized. Initializing with {max_processes}..."
429-
)
430-
ray.init(address="auto", max_processes=max_processes)
431-
432422
refs_to_free = []
433423
if observations_ref is None:
434424
observations_ref = ray.put(observations)
@@ -437,7 +427,6 @@ def merge_and_extend_orbits(
437427

438428
converged = False
439429
while not converged:
440-
441430
if use_ray:
442431
# Orbits will change with differential correction so we need to add them
443432
# to the object store at the start of each iteration (we cannot simply

thor/orbits/iod.py

+3-14
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from adam_core.coordinates.residuals import Residuals
1313
from adam_core.propagator import PYOORB, Propagator
1414
from adam_core.propagator.utils import _iterate_chunks
15+
from adam_core.ray_cluster import initialize_use_ray
1516

1617
from ..clusters import ClusterMembers
1718
from ..observations.observations import Observations
@@ -126,13 +127,11 @@ def iod_worker(
126127
propagator: Type[Propagator] = PYOORB,
127128
propagator_kwargs: dict = {},
128129
) -> Tuple[FittedOrbits, FittedOrbitMembers]:
129-
130130
prop = propagator(**propagator_kwargs)
131131

132132
iod_orbits_list = []
133133
iod_orbit_members_list = []
134134
for linkage_id in linkage_ids:
135-
136135
time_start = time.time()
137136
logger.debug(f"Finding initial orbit for linkage {linkage_id}...")
138137

@@ -379,7 +378,6 @@ def iod(
379378
# belonging to one object yield a good initial orbit but the presence of outlier
380379
# observations is skewing the sum total of the residuals and chi2
381380
elif num_outliers > 0:
382-
383381
logger.debug("Attempting to identify possible outliers.")
384382
for o in range(num_outliers):
385383
# Select i highest observations that contribute to
@@ -424,11 +422,9 @@ def iod(
424422
j += 1
425423

426424
if not converged or not processable:
427-
428425
return FittedOrbits.empty(), FittedOrbitMembers.empty()
429426

430427
else:
431-
432428
orbit = FittedOrbits.from_kwargs(
433429
orbit_id=orbit_sol.orbit_id,
434430
object_id=orbit_sol.object_id,
@@ -574,18 +570,11 @@ def initial_orbit_determination(
574570
iod_orbits_list = []
575571
iod_orbit_members_list = []
576572
if len(observations) > 0 and len(linkage_members) > 0:
577-
578573
# Extract linkage IDs
579574
linkage_ids = linkage_members.column(linkage_id_col).unique()
580575

581-
if max_processes is None or max_processes > 1:
582-
583-
if not ray.is_initialized():
584-
logger.info(
585-
f"Ray is not initialized. Initializing with {max_processes}..."
586-
)
587-
ray.init(address="auto", num_cpus=max_processes)
588-
576+
use_ray = initialize_use_ray(num_cpus=max_processes)
577+
if use_ray:
589578
refs_to_free = []
590579
if linkage_members_ref is None:
591580
linkage_members_ref = ray.put(linkage_members)

0 commit comments

Comments
 (0)