-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add rolling aggregate features + tests + field to pj
- Loading branch information
1 parent
f080aa5
commit fe666a9
Showing
4 changed files
with
126 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import pandas as pd | ||
|
||
|
||
def add_rolling_aggregate_features( | ||
data: pd.DataFrame, rolling_window: str = "24h" | ||
) -> pd.DataFrame: | ||
""" | ||
Adds rolling aggregate features to the input dataframe. | ||
These features are calculated with an aggregation over a rolling window of the data. | ||
A list of requested features is used to determine whether to add the rolling features | ||
or not. | ||
Args: | ||
data: Input dataframe to which the rolling features will be added. | ||
rolling_window: Rolling window size in str format following | ||
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases | ||
Returns: | ||
DataFrame with added rolling features. | ||
""" | ||
# Ensure the index is a DatetimeIndex | ||
if not isinstance(data.index, pd.DatetimeIndex): | ||
raise ValueError("The DataFrame index must be a DatetimeIndex.") | ||
|
||
if "load" not in data.columns: | ||
raise ValueError("The DataFrame must contain a 'load' column.") | ||
rolling_window_load = data["load"].rolling(window=rolling_window) | ||
data[f"rolling_median_load_{rolling_window}"] = rolling_window_load.median() | ||
data[f"rolling_max_load_{rolling_window}"] = rolling_window_load.max() | ||
data[f"rolling_min_load_{rolling_window}"] = rolling_window_load.min() | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from openstef.feature_engineering.rolling_features import add_rolling_aggregate_features | ||
|
||
|
||
def test_add_rolling_aggregate_features(): | ||
# Generate 2 days of data at 15-minute intervals | ||
num_points = int(24 * 60 / 15 * 2) | ||
data = pd.DataFrame( | ||
index=pd.date_range( | ||
start="2023-01-01 00:00:00", freq="15min", periods=num_points | ||
) | ||
) | ||
data["load"] = list(range(num_points)) | ||
|
||
# Apply the function | ||
output_data = add_rolling_aggregate_features(data) | ||
|
||
# Verify the columns are created | ||
assert "rolling_median_load_24h" in output_data.columns | ||
assert "rolling_max_load_24h" in output_data.columns | ||
assert "rolling_min_load_24h" in output_data.columns | ||
|
||
# Validate the rolling features | ||
rolling_window = "24h" | ||
rolling_window_load = data["load"].rolling(window=rolling_window) | ||
rolling_median_expected = rolling_window_load.median() | ||
rolling_max_expected = rolling_window_load.max() | ||
rolling_min_expected = rolling_window_load.min() | ||
|
||
assert np.allclose( | ||
output_data[f"rolling_median_load_{rolling_window}"], rolling_median_expected | ||
) | ||
assert np.allclose( | ||
output_data[f"rolling_max_load_{rolling_window}"], rolling_max_expected | ||
) | ||
assert np.allclose( | ||
output_data[f"rolling_min_load_{rolling_window}"], rolling_min_expected | ||
) | ||
|
||
|
||
def test_add_rolling_aggregate_features_flatline(): | ||
# Generate 2 days of data at 15-minute intervals | ||
num_points = int(24 * 60 / 15 * 2) | ||
data = pd.DataFrame( | ||
index=pd.date_range( | ||
start="2023-01-01 00:00:00", freq="15min", periods=num_points | ||
) | ||
) | ||
all_ones = [1.0] * num_points | ||
data["load"] = all_ones | ||
|
||
# Apply the function | ||
output_data = add_rolling_aggregate_features(data) | ||
|
||
# Verify the columns are created | ||
assert "rolling_median_load_24h" in output_data.columns | ||
assert "rolling_max_load_24h" in output_data.columns | ||
assert "rolling_min_load_24h" in output_data.columns | ||
|
||
# Validate the rolling features | ||
rolling_window = "24h" | ||
assert np.all(output_data[f"rolling_median_load_{rolling_window}"] == all_ones) | ||
assert np.all(output_data[f"rolling_max_load_{rolling_window}"] == all_ones) | ||
assert np.all(output_data[f"rolling_min_load_{rolling_window}"] == all_ones) | ||
|
||
|
||
def test_add_rolling_aggregate_features_non_datetime_index(): | ||
# Test for non-datetime index | ||
data = pd.DataFrame(index=range(10)) | ||
|
||
with pytest.raises( | ||
ValueError, match="The DataFrame index must be a DatetimeIndex." | ||
): | ||
add_rolling_aggregate_features(data) | ||
|
||
|
||
def test_add_rolling_aggregate_features_no_load_column(): | ||
# Test for dataframe without load column | ||
data = pd.DataFrame( | ||
index=pd.date_range(start="2023-01-01 00:00:00", freq="15min", periods=10), | ||
columns=["not_load"], | ||
) | ||
|
||
with pytest.raises(ValueError, match="The DataFrame must contain a 'load' column."): | ||
add_rolling_aggregate_features(data) |