Skip to content

Commit

Permalink
Use chunking / padding in more jitted functions (#129)
Browse files Browse the repository at this point in the history
* Use chunking / padding in more jitted functions
  • Loading branch information
akoumjian authored Dec 6, 2024
1 parent 33c6481 commit 3059a45
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 31 deletions.
94 changes: 84 additions & 10 deletions src/adam_core/coordinates/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jax import config, jit, lax, vmap

from ..constants import Constants as c
from ..utils.chunking import process_in_chunks
from . import types
from .cartesian import CartesianCoordinates
from .cometary import CometaryCoordinates
Expand Down Expand Up @@ -161,7 +162,19 @@ def cartesian_to_spherical(
vlat : Latitudinal velocity in degrees per arbitrary unit of time.
(same unit of time as the x, y, and z velocities).
"""
coords_spherical = _cartesian_to_spherical_vmap(coords_cartesian)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_spherical_chunks = []
for cartesian_chunk in process_in_chunks(coords_cartesian, chunk_size):
coords_spherical_chunk = _cartesian_to_spherical_vmap(cartesian_chunk)
coords_spherical_chunks.append(coords_spherical_chunk)

# Concatenate chunks and remove padding
coords_spherical = jnp.concatenate(coords_spherical_chunks, axis=0)
coords_spherical = coords_spherical[: len(coords_cartesian)]

return coords_spherical


Expand Down Expand Up @@ -276,7 +289,19 @@ def spherical_to_cartesian(
vy : y-velocity in the same units of y per arbitrary unit of time.
vz : z-velocity in the same units of z per arbitrary unit of time.
"""
coords_cartesian = _spherical_to_cartesian_vmap(coords_spherical)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_cartesian_chunks = []
for spherical_chunk in process_in_chunks(coords_spherical, chunk_size):
coords_cartesian_chunk = _spherical_to_cartesian_vmap(spherical_chunk)
coords_cartesian_chunks.append(coords_cartesian_chunk)

# Concatenate chunks and remove padding
coords_cartesian = jnp.concatenate(coords_cartesian_chunks, axis=0)
coords_cartesian = coords_cartesian[: len(coords_spherical)]

return coords_cartesian


Expand Down Expand Up @@ -537,7 +562,7 @@ def cartesian_to_keplerian(
vz : z-velocity in units of au per day.
t0 : {`~numpy.ndarray`, `~jax.numpy.ndarray`} (N)
Epoch at which cometary elements are defined in MJD TDB.
mu : {`~numpy.ndarray`, `~jax.numpy.ndarray`} (N, 6)
mu : {`~numpy.ndarray`, `~jax.numpy.ndarray`} (N)
Gravitational parameter (GM) of the attracting body in units of
au**3 / d**2.
Expand All @@ -559,7 +584,25 @@ def cartesian_to_keplerian(
P : period in days.
tp : time of periapsis passage in days.
"""
coords_keplerian = _cartesian_to_keplerian_vmap(coords_cartesian, t0, mu)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_keplerian_chunks = []
for cartesian_chunk, t0_chunk, mu_chunk in zip(
process_in_chunks(coords_cartesian, chunk_size),
process_in_chunks(t0, chunk_size),
process_in_chunks(mu, chunk_size),
):
coords_keplerian_chunk = _cartesian_to_keplerian_vmap(
cartesian_chunk, t0_chunk, mu_chunk
)
coords_keplerian_chunks.append(coords_keplerian_chunk)

# Concatenate chunks and remove padding
coords_keplerian = jnp.concatenate(coords_keplerian_chunks, axis=0)
coords_keplerian = coords_keplerian[: len(coords_cartesian)]

return coords_keplerian


Expand Down Expand Up @@ -945,9 +988,24 @@ def keplerian_to_cartesian(
)
raise ValueError(err)

coords_cartesian = _keplerian_to_cartesian_a_vmap(
coords_keplerian, mu, max_iter, tol
)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_cartesian_chunks = []
for keplerian_chunk, mu_chunk in zip(
process_in_chunks(coords_keplerian, chunk_size),
process_in_chunks(mu, chunk_size),
):
coords_cartesian_chunk = _keplerian_to_cartesian_a_vmap(
keplerian_chunk, mu_chunk, max_iter, tol
)
coords_cartesian_chunks.append(coords_cartesian_chunk)

# Concatenate chunks and remove padding
coords_cartesian = jnp.concatenate(coords_cartesian_chunks, axis=0)
coords_cartesian = coords_cartesian[: len(coords_keplerian)]

return coords_cartesian


Expand Down Expand Up @@ -1188,9 +1246,25 @@ def cometary_to_cartesian(
vy : y-velocity in units of au per day.
vz : z-velocity in units of au per day.
"""
coords_cartesian = _cometary_to_cartesian_vmap(
coords_cometary, t0, mu, max_iter, tol
)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_cartesian_chunks = []
for cometary_chunk, t0_chunk, mu_chunk in zip(
process_in_chunks(coords_cometary, chunk_size),
process_in_chunks(t0, chunk_size),
process_in_chunks(mu, chunk_size),
):
coords_cartesian_chunk = _cometary_to_cartesian_vmap(
cometary_chunk, t0_chunk, mu_chunk, max_iter, tol
)
coords_cartesian_chunks.append(coords_cartesian_chunk)

# Concatenate chunks and remove padding
coords_cartesian = jnp.concatenate(coords_cartesian_chunks, axis=0)
coords_cartesian = coords_cartesian[: len(coords_cometary)]

return coords_cartesian


Expand Down
2 changes: 1 addition & 1 deletion src/adam_core/dynamics/ephemeris.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from ..observers.observers import Observers
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from ..utils.chunking import process_in_chunks
from .aberrations import _add_light_time, add_stellar_aberration
from .propagation import process_in_chunks


@jit
Expand Down
21 changes: 1 addition & 20 deletions src/adam_core/dynamics/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..coordinates.origin import Origin
from ..orbits.orbits import Orbits
from ..time import Timestamp
from ..utils.chunking import process_in_chunks
from .lagrange import apply_lagrange_coefficients, calc_lagrange_coefficients

config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -69,26 +70,6 @@ def _propagate_2body(
)


def pad_to_fixed_size(array, target_shape, pad_value=0):
"""
Pad an array to a fixed shape with a specified pad value.
"""
pad_width = [(0, max(0, t - s)) for s, t in zip(array.shape, target_shape)]
return jnp.pad(array, pad_width, constant_values=pad_value)


def process_in_chunks(array, chunk_size):
"""
Yield chunks of the array with a fixed size, padding the last chunk if necessary.
"""
n = array.shape[0]
for i in range(0, n, chunk_size):
chunk = array[i : i + chunk_size]
if chunk.shape[0] < chunk_size:
chunk = pad_to_fixed_size(chunk, (chunk_size,) + chunk.shape[1:])
yield chunk


def propagate_2body(
orbits: Orbits,
times: Timestamp,
Expand Down
47 changes: 47 additions & 0 deletions src/adam_core/utils/chunking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import jax.numpy as jnp


def pad_to_fixed_size(array, target_shape, pad_value=0):
"""
Pad an array to a fixed shape with a specified pad value.
Parameters
----------
array : array-like
Array to pad
target_shape : tuple
Desired output shape
pad_value : int or float, optional
Value to use for padding, by default 0
Returns
-------
padded_array : array-like
Padded array with desired shape
"""
pad_width = [(0, max(0, t - s)) for s, t in zip(array.shape, target_shape)]
return jnp.pad(array, pad_width, constant_values=pad_value)


def process_in_chunks(array, chunk_size):
"""
Yield chunks of the array with a fixed size, padding the last chunk if necessary.
Parameters
----------
array : array-like
Array to process in chunks
chunk_size : int
Size of each chunk
Yields
------
chunk : array-like
Array chunk of fixed size (padded if necessary)
"""
n = array.shape[0]
for i in range(0, n, chunk_size):
chunk = array[i : i + chunk_size]
if chunk.shape[0] < chunk_size:
chunk = pad_to_fixed_size(chunk, (chunk_size,) + chunk.shape[1:])
yield chunk

0 comments on commit 3059a45

Please sign in to comment.