Skip to content

Allow direct initialization of NBeats model without from_dataset #1837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

MeherBhaskar
Copy link

@MeherBhaskar MeherBhaskar commented May 19, 2025

Description

This PR fixes issue #1548 where the NBeats model fails when initialized without using the from_dataset method. The changes enable direct initialization while maintaining backward compatibility and adding proper validation.

Changes

  • Enhanced __init__ method with comprehensive parameter validation
  • Added clear documentation for both initialization methods (direct and from_dataset)
  • Added validation for parameter list lengths to match stack types
  • Updated from_dataset method documentation to clarify its role
  • Added unit tests for direct initialization

Now users can:

  1. Use from_dataset (recommended) for standard time series forecasting
  2. Initialize directly for custom use cases
  3. Get clear error messages when parameters are invalid

Example of direct initialization:

model = NBeats(
    stack_types=["trend", "seasonality"],
    num_blocks=[3, 3],
    num_block_layers=[3, 3],
    widths=[32, 512],
    sharing=[True, True],
    expansion_coefficient_lengths=[3, 7],
    prediction_length=24,
    context_length=72,
)

Checklist

  • Linked issues (if existing)
  • Amended changelog for large changes (and added myself there as contributor)
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

Make sure to have fun coding!

@MeherBhaskar MeherBhaskar changed the title low direct initialization of NBeats model without from_dataset Allow direct initialization of NBeats model without from_dataset May 19, 2025
@MeherBhaskar MeherBhaskar marked this pull request as draft May 19, 2025 04:39
@MeherBhaskar MeherBhaskar marked this pull request as ready for review May 19, 2025 04:42
@MeherBhaskar MeherBhaskar marked this pull request as draft May 19, 2025 04:43
@MeherBhaskar MeherBhaskar marked this pull request as ready for review May 19, 2025 04:49
@MeherBhaskar MeherBhaskar marked this pull request as draft May 19, 2025 04:51
Copy link

codecov bot commented May 19, 2025

Codecov Report

Attention: Patch coverage is 90.00000% with 1 line in your changes missing coverage. Please review.

Please upload report for BASE (main@4c30351). Learn more about missing BASE report.

Files with missing lines Patch % Lines
pytorch_forecasting/models/nbeats/_nbeats.py 90.00% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1837   +/-   ##
=======================================
  Coverage        ?   86.86%           
=======================================
  Files           ?       51           
  Lines           ?     5673           
  Branches        ?        0           
=======================================
  Hits            ?     4928           
  Misses          ?      745           
  Partials        ?        0           
Flag Coverage Δ
cpu 86.86% <90.00%> (?)
pytest 86.86% <90.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@fkiraly fkiraly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Quick question, why are we changing the old test? Does this mean existing code breaks?

We should avoid to break working code of users.

@MeherBhaskar
Copy link
Author

MeherBhaskar commented May 21, 2025

Hi @fkiraly,
The change to update the widths from [4, 4, 4] to [4, 4] was made to match the number of stacks used by the N-BEATS model (["trend", "seasonality"])

Passing a list of length 3 was causing the following error:

ValueError: Length of widths (3) must match length of stack_types (2)

Let me know if you'd like to handle this differently!

@MeherBhaskar MeherBhaskar requested a review from fkiraly May 21, 2025 04:19
@fkiraly
Copy link
Collaborator

fkiraly commented May 21, 2025

Can you explain why the [4, 4, 4] was previously passing them in the tests? If it worked previously - and seems to have been standard expectation, an dnow we change it, we might break someone's code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants