Skip to content

Commit

Permalink
linting
Browse files (browse the repository at this point in the history)
  • Loading branch information
akoumjian committed Dec 4, 2024
1 parent eae13d3 commit 37d0091
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 34 deletions.
53 changes: 25 additions & 28 deletions precovery/frame_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,14 @@ def window_centers(
datasets: Optional[set[str]] = None,
) -> WindowCenters:
"""Return the midpoint and obscode of all time windows with data in them."""

# Build base query with minimal columns
query = (
sq.select(
self.frames.c.obscode,
self.frames.c.exposure_mjd_mid,
)
.where(
(self.frames.c.exposure_mjd_mid < end_mjd)
& (self.frames.c.exposure_mjd_mid >= start_mjd)
)
query = sq.select(
self.frames.c.obscode,
self.frames.c.exposure_mjd_mid,
).where(
(self.frames.c.exposure_mjd_mid < end_mjd)
& (self.frames.c.exposure_mjd_mid >= start_mjd)
)

if datasets is not None:
Expand All @@ -203,7 +200,7 @@ def window_centers(
chunk_size = 100000
rows = []
result = self.dbconn.execution_options(stream_results=True).execute(query)

while True:
chunk = result.fetchmany(chunk_size)
if not chunk:
Expand All @@ -215,22 +212,23 @@ def window_centers(

# Process results using PyArrow for better performance
obscodes, mjds = zip(*rows)

# Convert to PyArrow arrays for faster computation
mjds_arr = pa.array(mjds)
window_ids = pc.floor(pc.divide(
pc.subtract(mjds_arr, pa.scalar(start_mjd)),
pa.scalar(window_size_days)
))
window_ids = pc.floor(
pc.divide(
pc.subtract(mjds_arr, pa.scalar(start_mjd)), pa.scalar(window_size_days)
)
)

# Group by obscode and window_id using PyArrow
unique_pairs = set((obs, wid.as_py()) for obs, wid in zip(obscodes, window_ids))

if not unique_pairs:
return WindowCenters.empty()

final_obscodes, final_window_ids = zip(*unique_pairs)

# Calculate window centers
window_starts = [start_mjd + wid * window_size_days for wid in final_window_ids]
window_centers = [ws + window_size_days / 2 for ws in window_starts]
Expand All @@ -243,6 +241,7 @@ def window_centers(
)
window_centers = window_centers.sort_by(["time.days", "time.nanos"])
return window_centers

def propagation_targets(
self,
window: WindowCenters,
Expand Down Expand Up @@ -696,6 +695,7 @@ def __init__(

def close(self):
self.idx.close()

# for f in self.data_files.values():
# f.close()

Expand Down Expand Up @@ -831,17 +831,12 @@ def _open_data_files(self):
files = sorted(glob.glob(matcher, recursive=True))
for f in files:
abspath = os.path.abspath(f)
name = os.path.basename(f)
year_month_str = os.path.basename(os.path.dirname(abspath))
dataset_id = os.path.basename(os.path.dirname(os.path.dirname(abspath)))
data_uri = f"{dataset_id}/{year_month_str}/{name}"
if dataset_id not in self.n_data_files.keys():
self.n_data_files[dataset_id] = {}
if year_month_str not in self.n_data_files[dataset_id].keys():
self.n_data_files[dataset_id][year_month_str] = 0
# self.data_files[data_uri] = open(
# abspath, "rb" if self.mode == "r" else "a+b"
# )
self.n_data_files[dataset_id][year_month_str] += 1

def _current_data_file_name(self, dataset_id: str, year_month_str: str):
Expand Down Expand Up @@ -875,7 +870,7 @@ def get_observations(self, exp: HealpixFrame) -> ObservationsTable:
path = os.path.abspath(os.path.join(self.data_root, data_uri))

with open(path, "rb") as f:
# f = self.data_files[data_uri]
# f = self.data_files[data_uri]
f.seek(data_offset)
data_layout = struct.Struct(DATA_LAYOUT)
datagram_size = struct.calcsize(DATA_LAYOUT)
Expand All @@ -895,7 +890,9 @@ def get_observations(self, exp: HealpixFrame) -> ObservationsTable:
) = data_layout.unpack(raw)
id = f.read(id_size)
bytes_read += datagram_size + id_size
observations.append((mjd, ra, dec, ra_sigma, dec_sigma, mag, mag_sigma, id))
observations.append(
(mjd, ra, dec, ra_sigma, dec_sigma, mag, mag_sigma, id)
)
(mjds, ras, decs, ra_sigmas, dec_sigmas, mags, mag_sigmas, ids) = zip(
*observations
)
Expand Down Expand Up @@ -948,7 +945,7 @@ def store_observations(
except KeyError as ke: # NOQA: F841
self.new_data_file(dataset_id, year_month_str)
path = self._current_data_file_full(dataset_id, year_month_str)

with open(path, "a+b") as f:
if hasattr(observations, "__len__"):
logger.info(f"Writing {len(observations)} observations to {f.name}") # type: ignore
Expand Down Expand Up @@ -984,7 +981,7 @@ def new_data_file(self, dataset_id: str, year_month_str: str):
current_data_file = self._current_data_file_full(dataset_id, year_month_str)
os.makedirs(os.path.dirname(current_data_file), exist_ok=True)
# touch the file
with open(current_data_file, "a+b") as f:
with open(current_data_file, "a+b") as f: # NOQA: F841
pass

# f = open(current_data_file, "a+b")
Expand Down
1 change: 0 additions & 1 deletion precovery/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Optional, Tuple, Type

import quivr as qv
import ray
from adam_core.orbits import Orbits
from adam_core.propagator import Propagator

Expand Down
13 changes: 9 additions & 4 deletions precovery/precovery_db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
import time
from typing import Optional, Tuple, Type

import numpy as np
Expand Down Expand Up @@ -503,7 +502,9 @@ def check_window(
"""
assert len(window) == 1, "Use _check_windows for multiple windows"
assert len(orbit) == 1, "_check_window only support one orbit for now"
logger.info(f"check_window orbit: {orbit.orbit_id[0].as_py()} obscode: {window.obscode[0].as_py()} window: {window.window_start().mjd()[0].as_py()} to {window.window_end().mjd()[0].as_py()}")
logger.info(
f"check_window orbit: {orbit.orbit_id[0].as_py()} obscode: {window.obscode[0].as_py()} window: {window.window_start().mjd()[0].as_py()} to {window.window_end().mjd()[0].as_py()}"
)
db = PrecoveryDatabase.from_dir(db_dir, mode="r", allow_version_mismatch=True)
obscode = window.obscode[0].as_py()
propagation_targets = db.frames.idx.propagation_targets(
Expand Down Expand Up @@ -712,7 +713,9 @@ def precover(
# group windows by obscodes so that many windows can be searched at once
for obscode in windows.obscode.unique():
obscode_windows = windows.select("obscode", obscode)
logger.info(f"searching {len(obscode_windows)} windows for obscode {obscode}")
logger.info(
f"searching {len(obscode_windows)} windows for obscode {obscode}"
)

candidates_obscode, frame_candidates_obscode = self._check_windows(
obscode_windows,
Expand Down Expand Up @@ -743,7 +746,9 @@ def _check_windows(
Find all observations that match orbit within a list of windows
"""
assert len(orbit) == 1, "_check_windows only support one orbit for now"
windows = windows.sort_by([("time.days", "descending"), ("time.nanos", "descending")])
windows = windows.sort_by(
[("time.days", "descending"), ("time.nanos", "descending")]
)
logger.info(
f"_check_windows orbit: {orbit.orbit_id[0].as_py()} windows: {len(windows)} obscode: {windows.obscode.unique().to_pylist()}"
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ lint = { composite = [
"black --check ./precovery",
"isort --check-only ./precovery",
] }
fix = "ruff ./precovery --fix"
fix = "ruff check ./precovery --fix"
typecheck = "mypy ./precovery"

test = "pytest --benchmark-skip -m 'not profile' {args}"
Expand Down

0 comments on commit 37d0091

Please sign in to comment.