Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always initialize to use existing cluster if it exists, set default o… #130

Merged
merged 6 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ venv/
ENV/
env.bak/
venv.bak/
.python-version

# Spyder project settings
.spyderproject
Expand Down
36 changes: 17 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/PyCQA/isort
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
additional_dependencies:
- toml
- repo: https://github.com/psf/black
- id: isort
additional_dependencies:
- toml
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-mypy
- id: black
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.1.1
hooks:
- id: mypy
- id: mypy
exclude: bench/
additional_dependencies:
- 'types-pyyaml'
- 'types-requests'
- 'types-python-dateutil'
- "types-pyyaml"
- "types-requests"
- "types-python-dateutil"
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ setup_requires =
wheel
setuptools_scm >= 6.0
install_requires =
adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@main
adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@main#egg=adam_core
astropy >= 5.3.1
astroquery
difi
Expand All @@ -40,6 +40,7 @@ install_requires =
numpy
numba
pandas
psutil
pyarrow >= 14.0.0
pydantic < 2.0.0
pyyaml >= 5.1
Expand Down
17 changes: 3 additions & 14 deletions thor/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import quivr as qv
import ray
from adam_core.propagator import _iterate_chunks
from adam_core.ray_cluster import initialize_use_ray

from .range_and_transform import TransformedDetections

Expand Down Expand Up @@ -507,7 +508,6 @@ def cluster_velocity(
if len(clusters) == 0:
return Clusters.empty(), ClusterMembers.empty()
else:

cluster_ids = []
cluster_num_obs = []
cluster_members_cluster_ids = []
Expand Down Expand Up @@ -633,9 +633,6 @@ def cluster_and_link(
Algorithm to use. Can be "dbscan" or "hotspot_2d".
num_jobs : int, optional
Number of jobs to launch.
parallel_backend : str, optional
Which parallelization backend to use {'ray', 'mp', 'cf'}.
Defaults to using Python's concurrent futures module ('cf').

Returns
-------
Expand Down Expand Up @@ -691,14 +688,8 @@ def cluster_and_link(
mjd0 = mjd[first][0]
dt = mjd - mjd0

if max_processes is None or max_processes > 1:

if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {max_processes}..."
)
ray.init(address="auto", num_cpus=max_processes)

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
# Put all arrays (which can be large) in ray's
# local object store ahead of time
obs_ids_ref = ray.put(obs_ids)
Expand Down Expand Up @@ -742,7 +733,6 @@ def cluster_and_link(
)

else:

for vxi_chunk, vyi_chunk in zip(
_iterate_chunks(vxx, chunk_size), _iterate_chunks(vyy, chunk_size)
):
Expand Down Expand Up @@ -776,7 +766,6 @@ def cluster_and_link(
)

else:

clusters = Clusters.empty()
cluster_members = ClusterMembers.empty()

Expand Down
1 change: 1 addition & 0 deletions thor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

class Config(BaseModel):
max_processes: Optional[int] = None
ray_memory_bytes: int = 0
propagator: Literal["PYOORB"] = "PYOORB"
cell_radius: float = 10
vx_min: float = -0.1
Expand Down
19 changes: 4 additions & 15 deletions thor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import quivr as qv
import ray
from adam_core.propagator import PYOORB
from adam_core.ray_cluster import initialize_use_ray

from .checkpointing import create_checkpoint_data, load_initial_checkpoint_values
from .clusters import cluster_and_link
Expand All @@ -25,20 +26,6 @@
logger = logging.getLogger("thor")


def initialize_use_ray(config: Config) -> bool:
use_ray = False
if config.max_processes is None or config.max_processes > 1:
# Initialize ray
if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {config.max_processes} cpus..."
)
ray.init(num_cpus=config.max_processes)

use_ray = True
return use_ray


def initialize_test_orbit(
test_orbit: TestOrbits,
working_dir: Optional[str] = None,
Expand Down Expand Up @@ -132,7 +119,9 @@ def link_test_orbit(
else:
raise ValueError(f"Unknown propagator: {config.propagator}")

use_ray = initialize_use_ray(config)
use_ray = initialize_use_ray(
num_cpus=config.max_processes, object_store_bytes=config.ray_memory_bytes
)

refs_to_free = []
if (
Expand Down
10 changes: 3 additions & 7 deletions thor/observations/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import quivr as qv
import ray
from adam_core.coordinates import SphericalCoordinates
from adam_core.ray_cluster import initialize_use_ray

from thor.config import Config
from thor.observations.observations import Observations
Expand Down Expand Up @@ -160,13 +161,8 @@ def apply(
ephemeris = test_orbit.generate_ephemeris_from_observations(observations)

filtered_observations_list = []
if max_processes is None or max_processes > 1:
if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {max_processes}..."
)
ray.init(num_cpus=max_processes)

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
refs_to_free = []
if observations_ref is None:
observations_ref = ray.put(observations)
Expand Down
15 changes: 10 additions & 5 deletions thor/orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,13 @@


class RangedPointSourceDetections(qv.Table):

id = qv.StringColumn()
exposure_id = qv.StringColumn()
coordinates = SphericalCoordinates.as_column()
state_id = qv.Int64Column()


class TestOrbitEphemeris(qv.Table):

id = qv.Int64Column()
ephemeris = Ephemeris.as_column()
observer = Observers.as_column()
Expand Down Expand Up @@ -97,7 +95,6 @@ def range_observations_worker(


class TestOrbits(qv.Table):

orbit_id = qv.StringColumn(default=lambda: uuid.uuid4().hex)
object_id = qv.StringColumn(nullable=True)
bundle_id = qv.Int64Column(nullable=True)
Expand Down Expand Up @@ -199,7 +196,11 @@ def propagate(
The test orbit propagated to the given times.
"""
return propagator.propagate_orbits(
self.to_orbits(), times, max_processes=max_processes, chunk_size=1
self.to_orbits(),
times,
max_processes=max_processes,
chunk_size=1,
parallel_backend="ray",
)

def generate_ephemeris(
Expand All @@ -226,7 +227,11 @@ def generate_ephemeris(
The ephemeris of the test orbit at the given observers.
"""
return propagator.generate_ephemeris(
self.to_orbits(), observers, max_processes=max_processes, chunk_size=1
self.to_orbits(),
observers,
max_processes=max_processes,
chunk_size=1,
parallel_backend="ray",
)

def generate_ephemeris_from_observations(
Expand Down
2 changes: 1 addition & 1 deletion thor/orbit_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def generate_test_orbits(
start_time,
max_processes=max_processes,
parallel_backend="ray",
chunk_size=1000,
chunk_size=500,
)
propagation_end_time = time.perf_counter()
logger.info(
Expand Down
19 changes: 4 additions & 15 deletions thor/orbits/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from adam_core.orbits import Orbits
from adam_core.propagator import PYOORB
from adam_core.propagator.utils import _iterate_chunks
from adam_core.ray_cluster import initialize_use_ray
from sklearn.neighbors import BallTree

from ..observations.observations import Observations
Expand Down Expand Up @@ -165,7 +166,6 @@ def attribution_worker(
radius_rad = np.radians(radius)
residuals = []
for _, ephemeris_i, observations_i in linkage.iterate():

# Extract the observation IDs and times
obs_ids = observations_i.id.to_numpy(zero_copy_only=False)
obs_times = observations_i.coordinates.time.mjd().to_numpy(zero_copy_only=False)
Expand Down Expand Up @@ -279,12 +279,8 @@ def attribute_observations(
observation_indices = np.arange(0, len(observations))

attributions_list = []
if max_processes is None or max_processes > 1:

if not ray.is_initialized():
logger.info(f"Ray is not initialized. Initializing with {max_processes}...")
ray.init(address="auto", max_processes=max_processes)

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
refs_to_free = []
if orbits_ref is None:
orbits_ref = ray.put(orbits)
Expand Down Expand Up @@ -421,14 +417,8 @@ def merge_and_extend_orbits(
odp_orbits_list = []
odp_orbit_members_list = []
if len(orbits_iter) > 0 and len(observations_iter) > 0:

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {max_processes}..."
)
ray.init(address="auto", max_processes=max_processes)

refs_to_free = []
if observations_ref is None:
observations_ref = ray.put(observations)
Expand All @@ -437,7 +427,6 @@ def merge_and_extend_orbits(

converged = False
while not converged:

if use_ray:
# Orbits will change with differential correction so we need to add them
# to the object store at the start of each iteration (we cannot simply
Expand Down
17 changes: 3 additions & 14 deletions thor/orbits/iod.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from adam_core.coordinates.residuals import Residuals
from adam_core.propagator import PYOORB, Propagator
from adam_core.propagator.utils import _iterate_chunks
from adam_core.ray_cluster import initialize_use_ray

from ..clusters import ClusterMembers
from ..observations.observations import Observations
Expand Down Expand Up @@ -126,13 +127,11 @@ def iod_worker(
propagator: Type[Propagator] = PYOORB,
propagator_kwargs: dict = {},
) -> Tuple[FittedOrbits, FittedOrbitMembers]:

prop = propagator(**propagator_kwargs)

iod_orbits_list = []
iod_orbit_members_list = []
for linkage_id in linkage_ids:

time_start = time.time()
logger.debug(f"Finding initial orbit for linkage {linkage_id}...")

Expand Down Expand Up @@ -379,7 +378,6 @@ def iod(
# belonging to one object yield a good initial orbit but the presence of outlier
# observations is skewing the sum total of the residuals and chi2
elif num_outliers > 0:

logger.debug("Attempting to identify possible outliers.")
for o in range(num_outliers):
# Select i highest observations that contribute to
Expand Down Expand Up @@ -424,11 +422,9 @@ def iod(
j += 1

if not converged or not processable:

return FittedOrbits.empty(), FittedOrbitMembers.empty()

else:

orbit = FittedOrbits.from_kwargs(
orbit_id=orbit_sol.orbit_id,
object_id=orbit_sol.object_id,
Expand Down Expand Up @@ -574,18 +570,11 @@ def initial_orbit_determination(
iod_orbits_list = []
iod_orbit_members_list = []
if len(observations) > 0 and len(linkage_members) > 0:

# Extract linkage IDs
linkage_ids = linkage_members.column(linkage_id_col).unique()

if max_processes is None or max_processes > 1:

if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {max_processes}..."
)
ray.init(address="auto", num_cpus=max_processes)

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
refs_to_free = []
if linkage_members_ref is None:
linkage_members_ref = ray.put(linkage_members)
Expand Down
Loading