
Allow direct initialization of NBeats model without from_dataset #1837

Open · wants to merge 11 commits into base: main
174 changes: 121 additions & 53 deletions pytorch_forecasting/models/nbeats/_nbeats.py
@@ -43,54 +43,88 @@
**kwargs,
):
"""
Initialize NBeats Model.

The model can be initialized in two ways:

1. Using the :py:meth:`~from_dataset` classmethod
(recommended for standard time series forecasting)
2. Direct initialization with required parameters (for custom use cases)

Based on the article `N-BEATS: Neural basis expansion analysis for
interpretable time series forecasting <http://arxiv.org/abs/1905.10437>`_.
The network has (if used as an ensemble) outperformed all other methods,
including ensembles of traditional statistical methods, in the M4 competition.

The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently
been shown to consistently outperform N-BEATS.

Args:
stack_types: One of the following values: "generic", "seasonality" or
"trend". A list of strings of length 1 or 'num_stacks'.
Default and recommended value for generic mode: ["generic"].
Recommended value for interpretable mode: ["trend", "seasonality"]
num_blocks: The number of blocks per stack. A list of ints of length 1 or
'num_stacks'. Default and recommended value for generic mode: [1].
Recommended value for interpretable mode: [3]
num_block_layers: Number of fully connected layers with ReLU activation
per block. A list of ints of length 1 or 'num_stacks'.
Default and recommended value for generic mode: [4].
Recommended value for interpretable mode: [4]
widths: Widths of the fully connected layers with ReLU activation in the
blocks. A list of ints of length 1 or 'num_stacks'.
Default and recommended value for generic mode: [512].
Recommended value for interpretable mode: [256, 2048]
sharing: Whether weights are shared between blocks per stack.
A list of bools of length 1 or 'num_stacks'.
Default and recommended value for generic mode: [False].
Recommended value for interpretable mode: [True]
expansion_coefficient_lengths: Configures each stack type:
- "generic": length of the expansion coefficient
- "trend": degree of the polynomial
- "seasonality": minimum period allowed, e.g. 2 for changes every timestep
A list of ints of length 1 or 'num_stacks'.
Default value for generic mode: [32].
Recommended value for interpretable mode: [3]
prediction_length: Length of the prediction. Also known as 'horizon'.
context_length: Number of time units that condition the predictions.
Also known as 'lookback period'. Should be between 1-10 times the
prediction length.
dropout: Dropout rate (0.0 to 1.0).
learning_rate: Initial learning rate.
log_interval: Logging frequency (-1 = log at the end of each epoch).
log_gradient_flow: Whether to log gradient flow; this takes time and
should only be done to diagnose training failures.
log_val_interval: Log validation metrics every x batches.
weight_decay: L2 regularization factor.
backcast_loss_ratio: Weight of the backcast loss relative to the forecast
loss. A weight of 1.0 weights backcast and forecast loss equally
(regardless of backcast and forecast lengths). Defaults to 0.0,
i.e. no backcast loss.
loss: PyTorch metric to optimize. Defaults to MASE().
reduce_on_plateau_patience: Patience after which the learning rate is
reduced by a factor of 10.
logging_metrics: List of metrics logged during training. Defaults to
nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]).
**kwargs: Additional arguments for :py:class:`~BaseModel`.

Example:
Direct initialization:

>>> from pytorch_forecasting.models import NBeats
>>> model = NBeats(
... stack_types=["trend", "seasonality"],
... num_blocks=[3, 3],
... num_block_layers=[3, 3],
... widths=[32, 512],
... sharing=[True, True],
... expansion_coefficient_lengths=[3, 7],
... prediction_length=24,
... context_length=72,
... )

Initialization from dataset (recommended):

>>> from pytorch_forecasting import TimeSeriesDataSet, NBeats
>>> dataset = TimeSeriesDataSet(...)
>>> model = NBeats.from_dataset(dataset)
"""
if expansion_coefficient_lengths is None:
expansion_coefficient_lengths = [3, 7]
if sharing is None:
@@ -107,6 +141,32 @@
logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
if loss is None:
loss = MASE()

# Validate parameters
if not isinstance(prediction_length, int) or prediction_length < 1:
raise ValueError("prediction_length must be a positive integer")
if not isinstance(context_length, int) or context_length < 1:
raise ValueError("context_length must be a positive integer")

if not all(s in ["generic", "seasonality", "trend"] for s in stack_types):
raise ValueError(
"stack_types must contain only 'generic', 'seasonality', or 'trend'"
)

# Validate list lengths
n_stacks = len(stack_types)
for param_name, param_value in [
("num_blocks", num_blocks),
("num_block_layers", num_block_layers),
("widths", widths),
("sharing", sharing),
("expansion_coefficient_lengths", expansion_coefficient_lengths),
]:
if len(param_value) != n_stacks:
raise ValueError(
f"Length of {param_name} ({len(param_value)}) must match "
f"length of stack_types ({n_stacks})"
)

self.save_hyperparameters()
super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
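
Taken together, the hunk above fills in defaults and then validates parameters before ``save_hyperparameters()``. A minimal sketch of the new failure mode, mirroring ``test_direct_initialization`` further down (names and values are taken from the tests):

from pytorch_forecasting.models import NBeats

try:
    NBeats(
        stack_types=["trend", "seasonality"],
        num_blocks=[3],  # length 1, but stack_types has length 2
        prediction_length=24,
        context_length=72,
    )
except ValueError as err:
    print(err)  # Length of num_blocks (1) must match length of stack_types (2)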

@@ -223,15 +283,22 @@
@classmethod
def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
"""
Create an NBeats model from a TimeSeriesDataSet.

This is the recommended way to create an NBeats model for standard
time series forecasting. For custom use cases where the dataset
constraints don't fit, initialize the model directly via the constructor.

Args:
dataset (TimeSeriesDataSet): dataset where sole predictor is the target.
**kwargs: additional arguments to be passed to ``__init__`` method.

Returns:
NBeats: initialized model

Raises:
AssertionError: if dataset constraints are not met
"""
new_kwargs = {
"prediction_length": dataset.max_prediction_length,
"context_length": dataset.max_encoder_length,
@@ -352,17 +419,18 @@
"""
Plot interpretation.

Plot two panels: prediction and backcast vs. actuals, and decomposition of the
prediction into trend, seasonality and generic forecast.

Args:
x (Dict[str, torch.Tensor]): network input
output (Dict[str, torch.Tensor]): network output
idx (int): index of sample for which to plot the interpretation.
ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to
plot the interpretation. Defaults to None.
plot_seasonality_and_generic_on_secondary_axis (bool, optional): whether to plot
seasonality and generic forecast on a secondary axis in the second panel.
Defaults to False.

Returns:
plt.Figure: matplotlib figure
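For completeness, a short sketch of how this method is exercised by ``test_interpretation`` below. The ``val_dataloader`` name is hypothetical; ``mode="raw"`` and ``return_x=True`` follow the ``predict()`` usage in the test suite:

# Hedged sketch: obtain raw predictions (with inputs) and plot sample 0.
raw_predictions = model.predict(val_dataloader, mode="raw", return_x=True)
fig = model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0)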
40 changes: 38 additions & 2 deletions tests/test_models/test_nbeats.py
@@ -36,7 +36,7 @@ def test_integration(dataloaders_fixed_window_without_covariates, tmp_path):
train_dataloader.dataset,
learning_rate=0.15,
log_gradient_flow=True,
widths=[4, 4],
log_interval=1000,
backcast_loss_ratio=1.0,
)
@@ -77,7 +77,7 @@ def model(dataloaders_fixed_window_without_covariates):
dataset,
learning_rate=0.15,
log_gradient_flow=True,
widths=[4, 4],
log_interval=1000,
backcast_loss_ratio=1.0,
)
@@ -101,3 +101,39 @@ def test_interpretation(model, dataloaders_fixed_window_without_covariates):
fast_dev_run=True,
)
model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0)


def test_direct_initialization():
# Test that the model can be initialized directly without from_dataset
net = NBeats(
stack_types=["trend", "seasonality"],
num_blocks=[3, 3],
num_block_layers=[3, 3],
widths=[32, 512],
sharing=[True, True],
expansion_coefficient_lengths=[3, 7],
prediction_length=24,
context_length=72,
)
assert len(net.net_blocks) == 6 # 2 stacks * 3 blocks each
assert net.hparams.prediction_length == 24
assert net.hparams.context_length == 72

# Test validation of parameters
with pytest.raises(ValueError, match="stack_types must contain only"):
NBeats(stack_types=["invalid_type"])

with pytest.raises(ValueError, match="Length of num_blocks"):
NBeats(
stack_types=["trend", "seasonality"],
num_blocks=[3], # Should be length 2
prediction_length=24,
context_length=72,
)

with pytest.raises(ValueError, match="prediction_length must be"):
NBeats(
stack_types=["trend", "seasonality"],
prediction_length=0, # Invalid
context_length=72,
)
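
Finally, a hedged sketch of the workflow this PR enables: training a directly initialized NBeats with a Lightning trainer. The ``train_dataloader`` is hypothetical and must yield windows matching ``context_length``/``prediction_length``; the ``lightning.pytorch`` import path assumes Lightning 2.x.

import lightning.pytorch as pl
from pytorch_forecasting.models import NBeats

model = NBeats(
    stack_types=["generic"],  # generic-mode defaults from the docstring
    num_blocks=[1],
    num_block_layers=[4],
    widths=[512],
    sharing=[False],
    expansion_coefficient_lengths=[32],
    prediction_length=24,
    context_length=72,
)
trainer = pl.Trainer(max_epochs=3, gradient_clip_val=0.1)
# trainer.fit(model, train_dataloaders=train_dataloader)  # hypothetical dataloader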