From 3d2a90dabb88ef0a7a9a73d8e76e1fb7f1f56c1a Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Mon, 9 Dec 2024 10:07:50 -0500 Subject: [PATCH] Ak/more perf improvements (#130) * this seems like a better chunk size * Change chunking --- src/adam_core/coordinates/transform.py | 32 +++++++++++++------------- src/adam_core/dynamics/ephemeris.py | 19 ++++++++------- src/adam_core/dynamics/propagation.py | 11 ++++----- src/adam_core/observers/state.py | 6 +++-- src/adam_core/time/time.py | 1 - src/adam_core/utils/chunking.py | 4 ++-- src/adam_core/utils/spice.py | 9 ++++---- 7 files changed, 40 insertions(+), 42 deletions(-) diff --git a/src/adam_core/coordinates/transform.py b/src/adam_core/coordinates/transform.py index eb4b7855..03f65b18 100644 --- a/src/adam_core/coordinates/transform.py +++ b/src/adam_core/coordinates/transform.py @@ -131,9 +131,7 @@ def _cartesian_to_spherical( ) -def cartesian_to_spherical( - coords_cartesian: Union[np.ndarray, jnp.ndarray] -) -> jnp.ndarray: +def cartesian_to_spherical(coords_cartesian: np.ndarray) -> np.ndarray: """ Convert Cartesian coordinates to a spherical coordinates. @@ -163,16 +161,17 @@ def cartesian_to_spherical( (same unit of time as the x, y, and z velocities). """ # Define chunk size - chunk_size = 50 + chunk_size = 200 - # Process in chunks - coords_spherical_chunks = [] + # Process in chunk + coords_spherical: np.ndarray = np.empty((0, 6)) for cartesian_chunk in process_in_chunks(coords_cartesian, chunk_size): coords_spherical_chunk = _cartesian_to_spherical_vmap(cartesian_chunk) - coords_spherical_chunks.append(coords_spherical_chunk) + coords_spherical = np.concatenate( + (coords_spherical, np.asarray(coords_spherical_chunk)) + ) # Concatenate chunks and remove padding - coords_spherical = jnp.concatenate(coords_spherical_chunks, axis=0) coords_spherical = coords_spherical[: len(coords_cartesian)] return coords_spherical @@ -290,16 +289,17 @@ def spherical_to_cartesian( vz : z-velocity in the same units of z per arbitrary unit of time. """ # Define chunk size - chunk_size = 50 + chunk_size = 200 # Process in chunks - coords_cartesian_chunks = [] + coords_cartesian: np.ndarray = np.empty((0, 6)) for spherical_chunk in process_in_chunks(coords_spherical, chunk_size): coords_cartesian_chunk = _spherical_to_cartesian_vmap(spherical_chunk) - coords_cartesian_chunks.append(coords_cartesian_chunk) + coords_cartesian = np.concatenate( + (coords_cartesian, np.asarray(coords_cartesian_chunk)) + ) - # Concatenate chunks and remove padding - coords_cartesian = jnp.concatenate(coords_cartesian_chunks, axis=0) + # Remove padding coords_cartesian = coords_cartesian[: len(coords_spherical)] return coords_cartesian @@ -585,7 +585,7 @@ def cartesian_to_keplerian( tp : time of periapsis passage in days. """ # Define chunk size - chunk_size = 50 + chunk_size = 200 # Process in chunks coords_keplerian_chunks = [] @@ -989,7 +989,7 @@ def keplerian_to_cartesian( raise ValueError(err) # Define chunk size - chunk_size = 50 + chunk_size = 200 # Process in chunks coords_cartesian_chunks = [] @@ -1247,7 +1247,7 @@ def cometary_to_cartesian( vz : z-velocity in units of au per day. """ # Define chunk size - chunk_size = 50 + chunk_size = 200 # Process in chunks coords_cartesian_chunks = [] diff --git a/src/adam_core/dynamics/ephemeris.py b/src/adam_core/dynamics/ephemeris.py index 5d6cf2ee..d4178527 100644 --- a/src/adam_core/dynamics/ephemeris.py +++ b/src/adam_core/dynamics/ephemeris.py @@ -208,11 +208,11 @@ def generate_ephemeris_2body( times = propagated_orbits.coordinates.time.mjd().to_numpy(zero_copy_only=False) # Define chunk size - chunk_size = 50 + chunk_size = 200 # Process in chunks - ephemeris_chunks = [] - light_time_chunks = [] + ephemeris_spherical: np.ndarray = np.empty((0, 6)) + light_time: np.ndarray = np.empty((0,)) for orbits_chunk, times_chunk, observer_coords_chunk, mu_chunk in zip( process_in_chunks(propagated_orbits_barycentric.coordinates.values, chunk_size), @@ -230,15 +230,14 @@ def generate_ephemeris_2body( tol, stellar_aberration, ) - ephemeris_chunks.append(ephemeris_chunk) - light_time_chunks.append(light_time_chunk) + ephemeris_spherical = np.concatenate( + (ephemeris_spherical, np.asarray(ephemeris_chunk)) + ) + light_time = np.concatenate((light_time, np.asarray(light_time_chunk))) # Concatenate chunks and remove padding - ephemeris_spherical = jnp.concatenate(ephemeris_chunks, axis=0)[:num_entries] - light_time = jnp.concatenate(light_time_chunks, axis=0)[:num_entries] - - ephemeris_spherical = np.array(ephemeris_spherical) - light_time = np.array(light_time) + ephemeris_spherical = np.array(ephemeris_spherical)[:num_entries] + light_time = np.array(light_time)[:num_entries] if not propagated_orbits.coordinates.covariance.is_all_nan(): diff --git a/src/adam_core/dynamics/propagation.py b/src/adam_core/dynamics/propagation.py index c05dbfc0..75fdf4f7 100644 --- a/src/adam_core/dynamics/propagation.py +++ b/src/adam_core/dynamics/propagation.py @@ -107,7 +107,7 @@ def propagate_2body( object_ids = orbits.object_id.to_numpy(zero_copy_only=False) # Define chunk size - chunk_size = 50 # Example chunk size + chunk_size = 200 # Changed from 1000 # Prepare arrays for chunk processing # This creates a n x m matrix where n is the number of orbits and m is the number of times @@ -121,7 +121,7 @@ def propagate_2body( t1_ = np.tile(t1, n_orbits) # Process in chunks - orbits_propagated_chunks = [] + orbits_propagated: np.ndarray = np.empty((0, 6)) for orbits_chunk, t0_chunk, t1_chunk, mu_chunk in zip( process_in_chunks(orbits_array_, chunk_size), process_in_chunks(t0_, chunk_size), @@ -131,10 +131,9 @@ def propagate_2body( orbits_propagated_chunk = _propagate_2body_vmap( orbits_chunk, t0_chunk, t1_chunk, mu_chunk, max_iter, tol ) - orbits_propagated_chunks.append(orbits_propagated_chunk) - - # Concatenate all chunks - orbits_propagated = jnp.concatenate(orbits_propagated_chunks, axis=0) + orbits_propagated = np.concatenate( + (orbits_propagated, np.asarray(orbits_propagated_chunk)) + ) # Remove padding orbits_propagated = orbits_propagated[: n_orbits * n_times] diff --git a/src/adam_core/observers/state.py b/src/adam_core/observers/state.py index a0e9be8b..1b36b68e 100644 --- a/src/adam_core/observers/state.py +++ b/src/adam_core/observers/state.py @@ -121,7 +121,7 @@ def get_observer_state( o_vec_ITRF93 = np.dot(R_EARTH_EQUATORIAL, o_hat_ITRF93) # Warning! Converting times to ET will incur a loss of precision. - epochs_et = times.rescale("tdb").et() + epochs_et = times.et() unique_epochs_et_tdb = epochs_et.unique() N = len(epochs_et) @@ -150,7 +150,7 @@ def get_observer_state( -OMEGA_EARTH * R_EARTH_EQUATORIAL * rotation_direction ) - return CartesianCoordinates.from_kwargs( + observer_states = CartesianCoordinates.from_kwargs( time=times, x=r_obs[:, 0], y=r_obs[:, 1], @@ -161,3 +161,5 @@ def get_observer_state( frame=frame, origin=Origin.from_kwargs(code=[origin.name for i in range(len(times))]), ) + + return observer_states diff --git a/src/adam_core/time/time.py b/src/adam_core/time/time.py index beabc283..a4ba4017 100644 --- a/src/adam_core/time/time.py +++ b/src/adam_core/time/time.py @@ -48,7 +48,6 @@ def et(self) -> pa.lib.DoubleArray: Returns the times as ET seconds in a pyarrow array. """ tdb = self.rescale("tdb") - mjd = tdb.mjd() return pc.multiply(pc.subtract(mjd, _J2000_TDB_MJD), 86400) diff --git a/src/adam_core/utils/chunking.py b/src/adam_core/utils/chunking.py index a0b79808..7a114344 100644 --- a/src/adam_core/utils/chunking.py +++ b/src/adam_core/utils/chunking.py @@ -1,4 +1,4 @@ -import jax.numpy as jnp +import numpy as np def pad_to_fixed_size(array, target_shape, pad_value=0): @@ -20,7 +20,7 @@ def pad_to_fixed_size(array, target_shape, pad_value=0): Padded array with desired shape """ pad_width = [(0, max(0, t - s)) for s, t in zip(array.shape, target_shape)] - return jnp.pad(array, pad_width, constant_values=pad_value) + return np.pad(array, pad_width, constant_values=pad_value) def process_in_chunks(array, chunk_size): diff --git a/src/adam_core/utils/spice.py b/src/adam_core/utils/spice.py index 1e608f81..e5a6b5dd 100644 --- a/src/adam_core/utils/spice.py +++ b/src/adam_core/utils/spice.py @@ -131,10 +131,9 @@ def get_perturber_state( setup_SPICE() # Convert epochs to ET in TDB - epochs_et = times.rescale("tdb").et() + epochs_et = times.et() unique_epochs_et = epochs_et.unique() N = len(times) - # Get position of the body in km and km/s in the desired frame and measured from the desired origin states = np.empty((N, 6), dtype=np.float64) for i, epoch in enumerate(unique_epochs_et): @@ -144,9 +143,9 @@ def get_perturber_state( ) states[mask, :] = state - # Convert to AU and AU per day + # Convert units (vectorized operations) states = states / KM_P_AU - states[:, 3:] = states[:, 3:] * S_P_DAY + states[:, 3:] *= S_P_DAY return CartesianCoordinates.from_kwargs( time=times, @@ -157,5 +156,5 @@ def get_perturber_state( vy=states[:, 4], vz=states[:, 5], frame=frame, - origin=Origin.from_kwargs(code=[origin.name for i in range(N)]), + origin=Origin.from_kwargs(code=[origin.name] * N), )