refactor: stage 3
dfulu committed Feb 27, 2025
1 parent 486aa35 commit ea34b7a
Showing 8 changed files with 371 additions and 277 deletions.
src/pvnet_app/app.py: 53 changes (28 additions, 25 deletions)
@@ -3,7 +3,6 @@
import logging
import os
import tempfile
-import warnings
from importlib.metadata import PackageNotFoundError, version

import dask
@@ -14,13 +13,10 @@
import typer
from nowcasting_datamodel.connection import DatabaseConnection
from nowcasting_datamodel.models.base import Base_Forecast
-from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
from pvnet.models.base_model import BaseModel as PVNetBaseModel

from pvnet_app.config import get_union_of_configs, save_yaml_config
-from pvnet_app.data.nwp import (
-    download_all_nwp_data, preprocess_nwp_data, check_model_nwp_inputs_available
-)
+from pvnet_app.data.nwp import UKVDownloader, ECMWFDownloader
from pvnet_app.data.gsp import get_gsp_and_national_capacities
from pvnet_app.data.satellite import (
check_model_satellite_inputs_available,
@@ -94,7 +90,7 @@ def save_batch_to_s3(batch, model_name, s3_directory):
fs.put(save_batch, f"{s3_directory}/{save_batch}")
logger.info(
f"Saved first batch for model {model_name} to {s3_directory}/{save_batch}",
-            )
+        )
os.remove(save_batch)
logger.info("Removed local copy of batch")
except Exception as e:
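
Note: the hunk above shows only the tail of `save_batch_to_s3` (the indentation fix on the closing paren of the `logger.info` call). For context, here is a minimal sketch of what the whole helper plausibly looks like given the calls visible in the diff (`fs.put`, `os.remove`); the filename, the `torch.save` serialisation step, and the `fsspec` setup are assumptions for illustration, not the repository's actual code:

```python
import logging
import os

import fsspec
import torch

logger = logging.getLogger(__name__)


def save_batch_to_s3(batch, model_name: str, s3_directory: str) -> None:
    """Save a batch locally, copy it to S3, then delete the local file (sketch)."""
    save_batch = f"batch_{model_name}.pt"  # hypothetical filename
    torch.save(batch, save_batch)
    try:
        # Resolve the filesystem from the destination URL (e.g. s3://bucket/dir)
        fs, _ = fsspec.core.url_to_fs(s3_directory)
        fs.put(save_batch, f"{s3_directory}/{save_batch}")
        logger.info(
            f"Saved first batch for model {model_name} to {s3_directory}/{save_batch}",
        )
        os.remove(save_batch)
        logger.info("Removed local copy of batch")
    except Exception as e:
        # Saving the sample batch is diagnostic, so failure is non-fatal
        logger.warning(f"Failed to save batch to {s3_directory}: {e}")
```
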
@@ -174,7 +170,7 @@ def app(
# Without this line the dataloader will hang if multiple workers are used
dask.config.set(scheduler="single-threaded")

-    # Unpack the environment variables
+    # --- Unpack the environment variables
use_day_ahead_model = get_boolean_env_var("DAY_AHEAD_MODEL", default=False)
use_ecmwf_only = get_boolean_env_var("USE_ECMWF_ONLY", default=False)
run_extra_models = get_boolean_env_var("RUN_EXTRA_MODELS", default=False)
@@ -183,9 +179,11 @@
save_gsp_sum = get_boolean_env_var("SAVE_GSP_SUM", default=False)

db_url = os.environ["DB_URL"] # Will raise KeyError if not set
-    batch_s3_save_dir = os.getenv("SAVE_BATCHES_DIR", None)
+    s3_batch_save_dir = os.getenv("SAVE_BATCHES_DIR", None)
+    ecmwf_source_path = os.getenv("NWP_ECMWF_ZARR_PATH", None)
+    ukv_source_path = os.getenv("NWP_UKV_ZARR_PATH", None)

-    # Log version and variables
+    # --- Log version and variables
logger.info(f"Using `pvnet` library version: {__pvnet_version__}")
logger.info(f"Using `pvnet_app` library version: {__version__}")
logger.info(f"Making forecast for init time: {t0}")
@@ -197,7 +195,7 @@
logger.info(f"Using adjuster: {use_adjuster}")
logger.info(f"Saving GSP sum: {save_gsp_sum}")

-    # Get the model configurations
+    # --- Get the model configurations
model_configs = get_all_models(
allow_use_adjuster=use_adjuster,
allow_save_gsp_sum=save_gsp_sum,
@@ -206,6 +204,8 @@
run_extra_models=run_extra_models,
use_ocf_data_sampler=use_ocf_data_sampler,
)
+    if len(model_configs)==0:
+        raise Exception("No models found after filtering")

# Open connection to the database - used for pulling GSP capacitites and writing forecasts
db_connection = DatabaseConnection(url=db_url, base=Base_Forecast, echo=False)
@@ -215,15 +215,15 @@
# ---------------------------------------------------------------------------
# 1. Prepare data sources

-    # Get capacities from the database
+    # --- Get capacities from the database
logger.info("Loading capacities from the database")
gsp_capacities, national_capacity = get_gsp_and_national_capacities(
db_connection=db_connection,
gsp_ids=gsp_ids,
t0=t0,
)

-    # Download satellite data
+    # --- Download satellite data
logger.info("Downloading satellite data")
sat_available = download_all_sat_data()

@@ -233,19 +233,21 @@
else:
sat_datetimes = preprocess_sat_data(t0, use_legacy=not use_ocf_data_sampler)

-    # Download NWP data
+    # --- Download and process NWP data
logger.info("Downloading NWP data")
-    download_all_nwp_data()
-
-    # Preprocess the NWP data
-    preprocess_nwp_data()
+
+    ecmwf_downloader = ECMWFDownloader(source_path=ecmwf_source_path)
+    ecmwf_downloader.run()
+
+    ukv_downloader = UKVDownloader(source_path=ukv_source_path)
+    ukv_downloader.run()

# ---------------------------------------------------------------------------
# 2. Set up models

# Prepare all the models which can be run
forecast_compilers = {}
-    data_config_paths = []
+    used_data_config_paths = []
for model_config in model_configs:
# First load the data config
data_config_path = PVNetBaseModel.get_data_config(
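
Note: this commit replaces the module-level `download_all_nwp_data()`/`preprocess_nwp_data()` pipeline with per-provider downloader objects, but the new `ECMWFDownloader`/`UKVDownloader` classes live in `pvnet_app/data/nwp.py` and are not shown in this diff. Below is a sketch of an interface consistent with how `app.py` uses them; all internals here (`destination_path`, the preprocessing hook, the availability logic) are assumptions for illustration:

```python
from typing import Optional

import pandas as pd
import xarray as xr


class NWPDownloader:
    """Per-provider NWP downloader, inferred from its usage in app.py (sketch)."""

    destination_path = "nwp.zarr"  # hypothetical local path the batches are built from

    def __init__(self, source_path: Optional[str]):
        self.source_path = source_path
        self.ds: Optional[xr.Dataset] = None

    def run(self) -> None:
        """Download the source zarr (if configured) and apply provider preprocessing."""
        if self.source_path is None:
            return
        ds = xr.open_zarr(self.source_path)
        ds = self.preprocess(ds)
        ds.to_zarr(self.destination_path, mode="w")
        self.ds = ds

    def preprocess(self, ds: xr.Dataset) -> xr.Dataset:
        """Provider-specific renaming/regridding would be overridden here."""
        return ds

    def check_model_inputs_available(self, data_config_path: str, t0: pd.Timestamp) -> bool:
        """Whether the data a model config needs from this provider was downloaded.

        A real implementation would parse the model's data config, and should
        pass trivially when the config does not use this source at all. Here we
        only check that something with an init time at or before t0 exists,
        assuming an `init_time` coordinate.
        """
        if self.ds is None:
            return False
        return bool((self.ds.init_time.values <= t0.to_datetime64()).any())


class ECMWFDownloader(NWPDownloader):
    destination_path = "nwp_ecmwf.zarr"  # assumed


class UKVDownloader(NWPDownloader):
    destination_path = "nwp_ukv.zarr"  # assumed
```

Keeping one downloader per provider groups source-specific preprocessing and availability checks together, which is what lets the loop below call `check_model_inputs_available` per source instead of one global `check_model_nwp_inputs_available`.
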
@@ -257,11 +259,12 @@
logger.info(f"Checking that the input data for model '{model_config.name}' exists")
model_can_run = (
check_model_satellite_inputs_available(data_config_path, t0, sat_datetimes)
-            and
-            check_model_nwp_inputs_available(data_config_path, t0)
+            and ecmwf_downloader.check_model_inputs_available(data_config_path, t0)
+            and ukv_downloader.check_model_inputs_available(data_config_path, t0)
)

if model_can_run:
logger.info(f"The input data for model '{model_config.name}' is available")
# Set up a forecast compiler for the model
forecast_compilers[model_config.name] = ForecastCompiler(
model_config=model_config,
@@ -273,15 +276,15 @@
)

# Store the config filename so we can create batches suitable for all models
-            data_config_paths.append(data_config_path)
+            used_data_config_paths.append(data_config_path)
else:
-        warnings.warn(f"The model {model_config.name} cannot be run with input data available")
+        logger.warning(f"The model {model_config.name} cannot be run with input data available")

if len(forecast_compilers) == 0:
raise Exception("No models were compatible with the available input data.")

# Find the config with values suitable for running all models
-    common_config = get_union_of_configs(data_config_paths)
+    common_config = get_union_of_configs(used_data_config_paths)

# Save the commmon config
common_config_path = f"{temp_dir.name}/common_config_path.yaml"
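
Note: `get_union_of_configs` merges the per-model data configs so a single dataloader can build batches sufficient for every model that will run. Its implementation lives in `pvnet_app/config.py`, not in this diff; the following is a simplified illustration of the idea, assuming YAML configs with per-source settings (the function name and key names here are invented):

```python
import yaml


def union_of_data_configs(config_paths: list[str]) -> dict:
    """Take the most demanding value of each requirement across configs (sketch)."""
    configs = []
    for path in config_paths:
        with open(path) as f:
            configs.append(yaml.safe_load(f))

    # app.py raises earlier if no models can run, so configs is non-empty
    union = configs[0]
    for cfg in configs[1:]:
        for source, settings in cfg.get("input_data", {}).items():
            merged = union.setdefault("input_data", {}).setdefault(source, {})
            for key, value in settings.items():
                if key.endswith("_minutes"):
                    # Keep the longest history/forecast window requested by any model
                    merged[key] = max(merged.get(key, value), value)
                else:
                    merged.setdefault(key, value)
    return union
```

Batches built from the merged config then contain a superset of what each model needs, so each `ForecastCompiler` can slice out its own inputs.
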
@@ -308,9 +311,9 @@
for i, batch in enumerate(dataloader):
logger.info(f"Predicting for batch: {i}")

-        if (batch_s3_save_dir is not None) and i == 0:
+        if (s3_batch_save_dir is not None) and i == 0:
model_name = next(iter(forecast_compilers))
-            save_batch_to_s3(batch, model_name, s3_directory)
+            save_batch_to_s3(batch, model_name, s3_batch_save_dir)

for forecast_compiler in forecast_compilers.values():
forecast_compiler.predict_batch(batch)
@@ -325,7 +328,7 @@
# ---------------------------------------------------------------------------
# Escape clause for making predictions locally
if not write_predictions:
-        return forecast_compilers[0].da_abs_all
+        return next(iter(forecast_compilers.values())).da_abs_all

# ---------------------------------------------------------------------------
# Write predictions to database
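
Note: the escape-clause change above fixes a real bug. `forecast_compilers` is a dict keyed by model name, so `forecast_compilers[0]` raises `KeyError` unless a model is literally named `0`; the fix returns the first compiler in insertion order. A tiny demonstration, with strings standing in for compiler objects:

```python
forecast_compilers = {"pvnet_v2": "compiler_a", "pvnet_ecmwf": "compiler_b"}

# forecast_compilers[0]  # KeyError: 0 -- dicts are keyed, not positional

first = next(iter(forecast_compilers.values()))  # insertion order (Python 3.7+)
print(first)  # -> compiler_a
```
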
