Skip to content

Commit

Permalink
refactor stage 4
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Feb 28, 2025
1 parent 721f111 commit b7a7cff
Show file tree
Hide file tree
Showing 4 changed files with 513 additions and 447 deletions.
35 changes: 22 additions & 13 deletions src/pvnet_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,9 @@
from pvnet.models.base_model import BaseModel as PVNetBaseModel

from pvnet_app.config import get_union_of_configs, save_yaml_config
from pvnet_app.data.satellite import SatelliteDownloader
from pvnet_app.data.nwp import UKVDownloader, ECMWFDownloader
from pvnet_app.data.gsp import get_gsp_and_national_capacities
from pvnet_app.data.satellite import (
check_model_satellite_inputs_available,
download_all_sat_data,
preprocess_sat_data,
)
from pvnet_app.dataloader import get_dataloader
from pvnet_app.forecast_compiler import ForecastCompiler
from pvnet_app.model_configs.pydantic_models import get_all_models
Expand Down Expand Up @@ -111,7 +107,9 @@ def get_boolean_env_var(env_var: str, default: bool) -> bool:
The boolean value of the environment variable.
"""
if env_var in os.environ:
return os.getenv(env_var).lower() == "true"
env_var_value = os.getenv(env_var).lower()
assert env_var_value in ["true", "false"]
return env_var_value == "true"
else:
return default

Expand Down Expand Up @@ -182,6 +180,10 @@ def app(
s3_batch_save_dir = os.getenv("SAVE_BATCHES_DIR", None)
ecmwf_source_path = os.getenv("NWP_ECMWF_ZARR_PATH", None)
ukv_source_path = os.getenv("NWP_UKV_ZARR_PATH", None)
sat_source_path_5 = os.getenv("SATELLITE_ZARR_PATH", None)
sat_source_path_15 = (
None if (sat_source_path_5 is None) else sat_source_path_5.replace(".zarr", "_15.zarr")
)

# --- Log version and variables
logger.info(f"Using `pvnet` library version: {__pvnet_version__}")
Expand Down Expand Up @@ -225,13 +227,15 @@ def app(

# --- Download satellite data
logger.info("Downloading satellite data")
sat_available = download_all_sat_data()

sat_downloader = SatelliteDownloader(
t0=t0,
source_path_5=sat_source_path_5,
source_path_15=sat_source_path_15,
legacy=(not use_ocf_data_sampler),
)
sat_downloader.run()

# Preprocess the satellite data if available and store available timesteps
if not sat_available:
sat_datetimes = pd.DatetimeIndex([])
else:
sat_datetimes = preprocess_sat_data(t0, use_legacy=not use_ocf_data_sampler)

# --- Download and process NWP data
logger.info("Downloading NWP data")
Expand All @@ -258,7 +262,7 @@ def app(
# Check if the data available will allow the model to run
logger.info(f"Checking that the input data for model '{model_config.name}' exists")
model_can_run = (
check_model_satellite_inputs_available(data_config_path, t0, sat_datetimes)
sat_downloader.check_model_inputs_available(data_config_path, t0)
and ecmwf_downloader.check_model_inputs_available(data_config_path, t0)
and ukv_downloader.check_model_inputs_available(data_config_path, t0)
)
Expand Down Expand Up @@ -318,6 +322,11 @@ def app(
for forecast_compiler in forecast_compilers.values():
forecast_compiler.predict_batch(batch)

# Delete the downloaded data
sat_downloader.clean_up()
ecmwf_downloader.clean_up()
ukv_downloader.clean_up()

# ---------------------------------------------------------------------------
# Merge batch results to xarray DataArray
logger.info("Processing raw predictions to DataArray")
Expand Down
4 changes: 4 additions & 0 deletions src/pvnet_app/data/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ def run(self) -> None:
self.valid_times = get_nwp_valid_times(ds)

self.resave(ds)

def clean_up(self) -> None:
    """Delete the directory tree at ``self.destination_path``.

    Removal is best-effort: errors raised while deleting (for example
    because the path does not exist) are ignored rather than propagated.
    """
    target = self.destination_path
    shutil.rmtree(target, ignore_errors=True)


def check_model_inputs_available(
Expand Down
Loading

0 comments on commit b7a7cff

Please sign in to comment.