Skip to content

Commit

Permalink
Parameterize simulator (#15)
Browse files Browse the repository at this point in the history
* Parameterize rebound settings

* Parameterize rebound settings

* typing
  • Loading branch information
akoumjian authored Dec 16, 2024
1 parent 8450a00 commit 96de5a7
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 4 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,20 @@ propagator = ASSISTPropagator()
ephemerides = propagator.generate_ephemeris(sbdb_orbits, observers)
```

## Configuration

When initializing the `ASSISTPropagator`, you can configure several parameters that control the integration.
These parameters are passed directly to REBOUND's IAS15 integrator. The IAS15 integrator is a high accuracy integrator that uses adaptive timestepping to maintain precision while optimizing performance.

- `min_dt`: Minimum timestep for the integrator (default: 1e-15 days)
- `initial_dt`: Initial timestep for the integrator (default: 0.001 days)
- `adaptive_mode`: Controls the adaptive timestep behavior (default: 2)

These parameters are passed directly to REBOUND's IAS15 integrator. The IAS15 integrator is a high accuracy integrator that uses adaptive timestepping to maintain precision while optimizing performance.

Example:

```python
propagator = ASSISTPropagator(min_dt=1e-12, initial_dt=0.0001, adaptive_mode=2)
```

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ dev = [
check = {composite = ["lint", "typecheck"]}
format = { composite = ["black ./src/adam_assist", "isort ./src/adam_assist"]}
lint = { composite = ["ruff check ./src/adam_assist", "black --check ./src/adam_assist", "isort --check-only ./src/adam_assist"] }
fix = "ruff ./src/adam_assist --fix"
fix = "ruff check ./src/adam_assist --fix"
typecheck = "mypy --strict ./src/adam_assist"

test = "pytest --benchmark-disable {args}"
Expand Down
26 changes: 23 additions & 3 deletions src/adam_assist/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
C = c.C

try:
from adam_core.propagator.adam_assist_version import __version__
from adam_assist.version import __version__
except ImportError:
__version__ = "0.0.0"

Expand Down Expand Up @@ -60,6 +60,25 @@ def hash_orbit_ids_to_uint32(

class ASSISTPropagator(Propagator, ImpactMixin): # type: ignore

def __init__(
self,
*args: object, # Generic type for arbitrary positional arguments
min_dt: float = 1e-15,
initial_dt: float = 0.001,
adaptive_mode: int = 2,
**kwargs: object, # Generic type for arbitrary keyword arguments
) -> None:
super().__init__(*args, **kwargs)
if min_dt <= 0:
raise ValueError("min_dt must be positive")
if initial_dt <= 0:
raise ValueError("initial_dt must be positive")
if min_dt > initial_dt:
raise ValueError("min_dt must be smaller than initial_dt")
self.min_dt = min_dt
self.initial_dt = initial_dt
self.adaptive_mode = adaptive_mode

def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitType:
"""
Propagate the orbits to the specified times.
Expand Down Expand Up @@ -109,8 +128,9 @@ def _propagate_orbits_inner(
)
sim = None
sim = rebound.Simulation()
sim.ri_ias15.min_dt = 1e-15
sim.ri_ias15.adaptive_mode = 2
sim.dt = self.initial_dt
sim.ri_ias15.min_dt = self.min_dt
sim.ri_ias15.adaptive_mode = self.adaptive_mode

# Set the simulation time, relative to the jd_ref
start_tdb_time = orbits.coordinates.time.jd().to_numpy()[0]
Expand Down
95 changes: 95 additions & 0 deletions tests/test_propagator_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pyarrow as pa
import pytest
from adam_core.coordinates import CartesianCoordinates, Origin
from adam_core.orbits import Orbits
from adam_core.time import Timestamp

from adam_assist import ASSISTPropagator


@pytest.fixture
def basic_orbit():
"""Create a basic test orbit"""
return Orbits.from_kwargs(
orbit_id=["test"],
coordinates=CartesianCoordinates.from_kwargs(
x=[1.0],
y=[0.0],
z=[0.0],
vx=[0.0],
vy=[1.0],
vz=[0.0],
time=Timestamp.from_mjd([60000], scale="tdb"),
origin=Origin.from_kwargs(code=["SUN"]),
frame="ecliptic",
),
)

def test_default_settings():
"""Test that default settings are applied correctly"""
prop = ASSISTPropagator()
assert prop.min_dt == 1e-15
assert prop.initial_dt == 0.001
assert prop.adaptive_mode == 2

def test_custom_settings():
"""Test that custom settings are applied correctly"""
prop = ASSISTPropagator(min_dt=1e-12, initial_dt=0.01, adaptive_mode=1)
assert prop.min_dt == 1e-12
assert prop.initial_dt == 0.01
assert prop.adaptive_mode == 1

def test_invalid_min_dt():
"""Test that invalid min_dt raises ValueError"""
with pytest.raises(ValueError, match="min_dt must be positive"):
ASSISTPropagator(min_dt=-1e-15)

with pytest.raises(ValueError, match="min_dt must be positive"):
ASSISTPropagator(min_dt=0)

def test_invalid_initial_dt():
"""Test that invalid initial_dt raises ValueError"""
with pytest.raises(ValueError, match="initial_dt must be positive"):
ASSISTPropagator(initial_dt=-0.001)

with pytest.raises(ValueError, match="initial_dt must be positive"):
ASSISTPropagator(initial_dt=0)

def test_min_dt_greater_than_initial():
"""Test that min_dt > initial_dt raises ValueError"""
with pytest.raises(ValueError, match="min_dt must be smaller than initial_dt"):
ASSISTPropagator(min_dt=0.1, initial_dt=0.01)

def test_propagation_with_different_settings(basic_orbit):
"""Test that propagation works with different settings"""
# Test with default settings
prop_default = ASSISTPropagator()

# Test with more restrictive settings
prop_restrictive = ASSISTPropagator(min_dt=1e-12, initial_dt=0.0001)

# Test with less restrictive settings
prop_loose = ASSISTPropagator(min_dt=1e-9, initial_dt=0.01)

# Propagate for 10 days with each propagator
target_time = Timestamp.from_mjd([60010], scale="tdb")

result_default = prop_default.propagate_orbits(basic_orbit, target_time)
result_restrictive = prop_restrictive.propagate_orbits(basic_orbit, target_time)
result_loose = prop_loose.propagate_orbits(basic_orbit, target_time)

# All should produce results
assert len(result_default) == 1
assert len(result_restrictive) == 1
assert len(result_loose) == 1

# Results should be similar but not identical due to different step sizes
# Using a relatively loose tolerance since we expect some differences
tolerance = 1e-6

default_pos = result_default.coordinates.values[0, :3]
restrictive_pos = result_restrictive.coordinates.values[0, :3]
loose_pos = result_loose.coordinates.values[0, :3]

assert abs(default_pos - restrictive_pos).max() < tolerance
assert abs(default_pos - loose_pos).max() < tolerance

0 comments on commit 96de5a7

Please sign in to comment.