Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ak/more perf improvements #130

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/adam_core/coordinates/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ def _cartesian_to_spherical(
)


def cartesian_to_spherical(
coords_cartesian: Union[np.ndarray, jnp.ndarray]
) -> jnp.ndarray:
def cartesian_to_spherical(coords_cartesian: np.ndarray) -> np.ndarray:
"""
Convert Cartesian coordinates to a spherical coordinates.

Expand Down Expand Up @@ -163,16 +161,17 @@ def cartesian_to_spherical(
(same unit of time as the x, y, and z velocities).
"""
# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_spherical_chunks = []
# Process in chunk
coords_spherical: np.ndarray = np.empty((0, 6))
for cartesian_chunk in process_in_chunks(coords_cartesian, chunk_size):
coords_spherical_chunk = _cartesian_to_spherical_vmap(cartesian_chunk)
coords_spherical_chunks.append(coords_spherical_chunk)
coords_spherical = np.concatenate(
(coords_spherical, np.asarray(coords_spherical_chunk))
)

# Concatenate chunks and remove padding
coords_spherical = jnp.concatenate(coords_spherical_chunks, axis=0)
coords_spherical = coords_spherical[: len(coords_cartesian)]

return coords_spherical
Expand Down Expand Up @@ -290,16 +289,17 @@ def spherical_to_cartesian(
vz : z-velocity in the same units of z per arbitrary unit of time.
"""
# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_cartesian_chunks = []
coords_cartesian: np.ndarray = np.empty((0, 6))
for spherical_chunk in process_in_chunks(coords_spherical, chunk_size):
coords_cartesian_chunk = _spherical_to_cartesian_vmap(spherical_chunk)
coords_cartesian_chunks.append(coords_cartesian_chunk)
coords_cartesian = np.concatenate(
(coords_cartesian, np.asarray(coords_cartesian_chunk))
)

# Concatenate chunks and remove padding
coords_cartesian = jnp.concatenate(coords_cartesian_chunks, axis=0)
# Remove padding
coords_cartesian = coords_cartesian[: len(coords_spherical)]

return coords_cartesian
Expand Down Expand Up @@ -585,7 +585,7 @@ def cartesian_to_keplerian(
tp : time of periapsis passage in days.
"""
# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_keplerian_chunks = []
Expand Down Expand Up @@ -989,7 +989,7 @@ def keplerian_to_cartesian(
raise ValueError(err)

# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_cartesian_chunks = []
Expand Down Expand Up @@ -1247,7 +1247,7 @@ def cometary_to_cartesian(
vz : z-velocity in units of au per day.
"""
# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_cartesian_chunks = []
Expand Down
19 changes: 9 additions & 10 deletions src/adam_core/dynamics/ephemeris.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ def generate_ephemeris_2body(
times = propagated_orbits.coordinates.time.mjd().to_numpy(zero_copy_only=False)

# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
ephemeris_chunks = []
light_time_chunks = []
ephemeris_spherical: np.ndarray = np.empty((0, 6))
light_time: np.ndarray = np.empty((0,))

for orbits_chunk, times_chunk, observer_coords_chunk, mu_chunk in zip(
process_in_chunks(propagated_orbits_barycentric.coordinates.values, chunk_size),
Expand All @@ -230,15 +230,14 @@ def generate_ephemeris_2body(
tol,
stellar_aberration,
)
ephemeris_chunks.append(ephemeris_chunk)
light_time_chunks.append(light_time_chunk)
ephemeris_spherical = np.concatenate(
(ephemeris_spherical, np.asarray(ephemeris_chunk))
)
light_time = np.concatenate((light_time, np.asarray(light_time_chunk)))

# Concatenate chunks and remove padding
ephemeris_spherical = jnp.concatenate(ephemeris_chunks, axis=0)[:num_entries]
light_time = jnp.concatenate(light_time_chunks, axis=0)[:num_entries]

ephemeris_spherical = np.array(ephemeris_spherical)
light_time = np.array(light_time)
ephemeris_spherical = np.array(ephemeris_spherical)[:num_entries]
light_time = np.array(light_time)[:num_entries]

if not propagated_orbits.coordinates.covariance.is_all_nan():

Expand Down
11 changes: 5 additions & 6 deletions src/adam_core/dynamics/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def propagate_2body(
object_ids = orbits.object_id.to_numpy(zero_copy_only=False)

# Define chunk size
chunk_size = 50 # Example chunk size
chunk_size = 200 # Changed from 1000

# Prepare arrays for chunk processing
# This creates a n x m matrix where n is the number of orbits and m is the number of times
Expand All @@ -121,7 +121,7 @@ def propagate_2body(
t1_ = np.tile(t1, n_orbits)

# Process in chunks
orbits_propagated_chunks = []
orbits_propagated: np.ndarray = np.empty((0, 6))
for orbits_chunk, t0_chunk, t1_chunk, mu_chunk in zip(
process_in_chunks(orbits_array_, chunk_size),
process_in_chunks(t0_, chunk_size),
Expand All @@ -131,10 +131,9 @@ def propagate_2body(
orbits_propagated_chunk = _propagate_2body_vmap(
orbits_chunk, t0_chunk, t1_chunk, mu_chunk, max_iter, tol
)
orbits_propagated_chunks.append(orbits_propagated_chunk)

# Concatenate all chunks
orbits_propagated = jnp.concatenate(orbits_propagated_chunks, axis=0)
orbits_propagated = np.concatenate(
(orbits_propagated, np.asarray(orbits_propagated_chunk))
)

# Remove padding
orbits_propagated = orbits_propagated[: n_orbits * n_times]
Expand Down
6 changes: 4 additions & 2 deletions src/adam_core/observers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def get_observer_state(
o_vec_ITRF93 = np.dot(R_EARTH_EQUATORIAL, o_hat_ITRF93)

# Warning! Converting times to ET will incur a loss of precision.
epochs_et = times.rescale("tdb").et()
epochs_et = times.et()
unique_epochs_et_tdb = epochs_et.unique()

N = len(epochs_et)
Expand Down Expand Up @@ -150,7 +150,7 @@ def get_observer_state(
-OMEGA_EARTH * R_EARTH_EQUATORIAL * rotation_direction
)

return CartesianCoordinates.from_kwargs(
observer_states = CartesianCoordinates.from_kwargs(
time=times,
x=r_obs[:, 0],
y=r_obs[:, 1],
Expand All @@ -161,3 +161,5 @@ def get_observer_state(
frame=frame,
origin=Origin.from_kwargs(code=[origin.name for i in range(len(times))]),
)

return observer_states
1 change: 0 additions & 1 deletion src/adam_core/time/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def et(self) -> pa.lib.DoubleArray:
Returns the times as ET seconds in a pyarrow array.
"""
tdb = self.rescale("tdb")

mjd = tdb.mjd()
return pc.multiply(pc.subtract(mjd, _J2000_TDB_MJD), 86400)

Expand Down
4 changes: 2 additions & 2 deletions src/adam_core/utils/chunking.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import jax.numpy as jnp
import numpy as np


def pad_to_fixed_size(array, target_shape, pad_value=0):
Expand All @@ -20,7 +20,7 @@ def pad_to_fixed_size(array, target_shape, pad_value=0):
Padded array with desired shape
"""
pad_width = [(0, max(0, t - s)) for s, t in zip(array.shape, target_shape)]
return jnp.pad(array, pad_width, constant_values=pad_value)
return np.pad(array, pad_width, constant_values=pad_value)


def process_in_chunks(array, chunk_size):
Expand Down
9 changes: 4 additions & 5 deletions src/adam_core/utils/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,9 @@ def get_perturber_state(
setup_SPICE()

# Convert epochs to ET in TDB
epochs_et = times.rescale("tdb").et()
epochs_et = times.et()
unique_epochs_et = epochs_et.unique()
N = len(times)
# Get position of the body in km and km/s in the desired frame and measured from the desired origin
states = np.empty((N, 6), dtype=np.float64)

for i, epoch in enumerate(unique_epochs_et):
Expand All @@ -144,9 +143,9 @@ def get_perturber_state(
)
states[mask, :] = state

# Convert to AU and AU per day
# Convert units (vectorized operations)
states = states / KM_P_AU
states[:, 3:] = states[:, 3:] * S_P_DAY
states[:, 3:] *= S_P_DAY

return CartesianCoordinates.from_kwargs(
time=times,
Expand All @@ -157,5 +156,5 @@ def get_perturber_state(
vy=states[:, 4],
vz=states[:, 5],
frame=frame,
origin=Origin.from_kwargs(code=[origin.name for i in range(N)]),
origin=Origin.from_kwargs(code=[origin.name] * N),
)
Loading