diff --git a/thor/orbit.py b/thor/orbit.py index 49b3969a..56732784 100644 --- a/thor/orbit.py +++ b/thor/orbit.py @@ -134,7 +134,7 @@ def _is_cache_fresh(self, observations: Observations) -> bool: return False elif pc.all( pc.is_in( - self._cached_observation_ids.sort(), observations.detections.id.sort() + observations.detections.id.sort(), self._cached_observation_ids.sort() ) ).as_py(): return True @@ -274,7 +274,8 @@ def generate_ephemeris_from_observations( def range_observations( self, - observations: qv.Linkage[PointSourceDetections, Exposures], + observations: Observations, + propagator: Propagator = PYOORB(), max_processes: Optional[int] = 1, ) -> RangedPointSourceDetections: """ @@ -283,8 +284,10 @@ def range_observations( Parameters ---------- - observations : qv.Linkage[PointSourceDetections, Exposures] + observations : `~thor.observations.observations.Observations` Observations to range. + propagator : `~adam_core.propagator.propagator.Propagator`, optional + Propagator to use to propagate the orbit. Defaults to PYOORB. max_processes : int, optional Number of processes to use to propagate the orbit. Defaults to 1. @@ -293,41 +296,44 @@ def range_observations( ranged_point_source_detections : `~thor.orbit.RangedPointSourceDetections` The ranged detections. """ - exposures = observations.right_table - ephemeris = self.generate_ephemeris( - exposures.observers(), max_processes=max_processes + # Generate an ephemeris for each unique observation time and observatory + # code combination + ephemeris = self.generate_ephemeris_from_observations( + observations, propagator=propagator, max_processes=max_processes ) - # Get the light-time corrected state vector: the state vector - # at the time where the light reflected/emitted from the object - # would have reached the observer. - # We will use this state vector to get the heliocentric distance of - # the object at the time of the exposure. - # TODO: We could use other adam_core functionality to calculate this if need - # be and we may need to if we plan on mapping covariances. - propagated_orbit = ephemeris.left_table.aberrated_coordinates - - # TODO: We could investigate using concurrent futures here to parallelize - # this loop + + # Link the ephemeris to the observations + link = qv.Linkage( + ephemeris, + observations, + left_keys=ephemeris.id, + right_keys=observations.state_id, + ) + + # Do a sorted iteration over the unique state IDs rpsds = [] - for propagated_orbit_i, exposure_i in zip(propagated_orbit, exposures): + state_ids = observations.state_id.unique().sort() + for state_id in state_ids: - # Select the detections that belong to this exposure - detections_i = observations.select_left(exposure_i.id[0]) + # Select the ephemeris and observations for this state + ephemeris_i = link.select_left(state_id) + observations_i = link.select_right(state_id) + detections_i = observations_i.detections # Get the heliocentric distance of the object at the time of the exposure - r_mag = propagated_orbit_i.r_mag[0] + r_mag = ephemeris_i.ephemeris.aberrated_coordinates.r_mag[0] # Get the observer's heliocentric coordinates - observer_i = exposure_i.observers() + observer_i = ephemeris_i.observer # Create an array of observatory codes for the detections - num_detections = len(detections_i) + num_detections = len(observations_i) observatory_codes = np.repeat( - exposure_i.observatory_code[0].as_py(), num_detections + observations_i.observatory_code[0].as_py(), num_detections ) # The following can be replaced with: - # coords = detections_i.to_spherical(observatory_codes) + # coords = observations_i.to_spherical(observatory_codes) # Start replacement: sigma_data = np.vstack( [ @@ -357,21 +363,12 @@ def range_observations( coordinates=assume_heliocentric_distance( r_mag, coords, observer_i.coordinates ), + state_id=observations_i.state_id, ) ) - # Sort ranged detections by time ranged_detections = qv.concatenate(rpsds) - table = pa.table( - [ranged_detections.coordinates.time.jd()], - names=["time_jd"], - ) - - indices = pc.sort_indices( - table, - (("time_jd", "ascending"),), - ) - return ranged_detections.take(indices) + return ranged_detections def assume_heliocentric_distance(