
Implement Omegaconfig PR1: basic functionality #97

Status: Open. Wants to merge 29 commits into base branch main.

Conversation

@gqcpm (Contributor) commented on Sep 27, 2024

This is the first PR of #75. Here we want to:

  • Migrate over the attrs config classes from sleap.nn.config, starting with TrainingJobConfig and moving down the hierarchy. These get migrated to a submodule under sleap-nn/sleap_nn/config.
  • Update class definitions to new attrs API.
  • Replace cattr serialization with OmegaConf.
  • Replace the functionality of the oneof decorator with OmegaConf-based routines if possible.
  • Implement cross-field validation for linked attributes (a rough sketch of the intended attrs + OmegaConf pattern follows this list).
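
As a rough illustration of the attrs + OmegaConf pattern described above, here is a minimal sketch (not code from this PR; TrainerConfig and its fields are placeholder names):

import attrs
from omegaconf import OmegaConf

@attrs.define
class TrainerConfig:
    """Hypothetical training options; the real schema lives in sleap_nn/config."""

    max_epochs: int = 100
    min_lr: float = 1e-8
    max_lr: float = attrs.field(default=1e-3)

    @max_lr.validator
    def _check_lr_range(self, attribute, value):
        # Cross-field validation for linked attributes: attrs runs validators
        # in field order, so min_lr is already set when max_lr is checked.
        if value < self.min_lr:
            raise ValueError("max_lr must be >= min_lr")

# attrs -> DictConfig -> YAML replaces the old cattr-based serialization.
cfg = OmegaConf.structured(TrainerConfig())
yaml_text = OmegaConf.to_yaml(cfg)

# YAML -> DictConfig -> attrs instance; validators run again on instantiation.
merged = OmegaConf.merge(OmegaConf.structured(TrainerConfig), OmegaConf.create(yaml_text))
restored = OmegaConf.to_object(merged)
assert isinstance(restored, TrainerConfig)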

Summary by CodeRabbit

  • New Features

    • Introduced new configuration classes for data management, model architectures, and training parameters.
    • Added support for various data preprocessing and augmentation settings.
    • Enhanced model configuration with new backbone types and head configurations.
    • Implemented a comprehensive trainer configuration for improved training and validation processes.
    • Added unit tests for configuration classes to ensure correct default and custom settings.
  • Bug Fixes

    • Improved validation logic for model configurations to prevent invalid setups.
  • Documentation

    • Updated documentation to reflect new configuration options and usage guidelines.

coderabbitai bot (Contributor) commented on Sep 27, 2024

Walkthrough

This pull request introduces significant changes to the configuration management of the sleap_nn project. A new file, data_config.py, is added to define classes for data management parameters in machine learning workflows. Additionally, model_config.py is updated to enhance the ModelConfig class and introduce new classes for various model architectures. The trainer_config.py file is also introduced, encapsulating configurations for training jobs and data loading. These changes aim to standardize and improve configuration handling across the project.

Changes

File Path Change Summary
sleap_nn/config/data_config.py Added classes: DataConfig, PreprocessingConfig, AugmentationConfig, IntensityConfig, GeometricConfig with various attributes for data management.
sleap_nn/config/model_config.py Enhanced ModelConfig with new attributes and methods; added classes: BackboneConfig, UNetConfig, ConvNextConfig, SwinTConfig, HeadsConfig, SingleInstanceConfig, CentroidConfig, CenteredInstanceConfig, BottomUpConfig, SingleInstanceConfMapsConfig, CentroidConfMapsConfig, CenteredInstanceConfMapsConfig, BottomUpConfMapsConfig, PAFConfig.
sleap_nn/config/trainer_config.py Introduced TrainerConfig, DataLoaderConfig, ModelCkptConfig, WandBConfig, OptimizerConfig, LRSchedulerConfig, and EarlyStoppingConfig classes for training job configurations.
tests/config/test_trainer_config.py Added unit tests for configuration classes related to training setup using pytest and OmegaConf.

Possibly related issues

Possibly related PRs

  • Refactor model pipeline #51: modifies the model configuration, which may relate to how data configurations are used in this PR's new data_config.py.
  • Refactor Augmentation config #67: directly addresses the augmentation configuration, relevant to the new AugmentationConfig class introduced here.
  • Fix augmentation in TopdownConfmaps pipeline #78: enhances the augmentation logic, closely related to the new augmentation classes in this PR.
  • Implement tracker module #87: introduces a tracking mechanism that may interact with the new data management classes, particularly in how instances are processed.
  • Fix sizematcher in Inference data pipline #102: addresses issues in the inference pipeline that may involve the new data-processing configurations.
  • Convert Tensor images to PIL #105: changes how images are processed and stored, which could relate to the new data handling.

Suggested reviewers

  • talmo: Suggested due to their involvement in the project and familiarity with configuration management.


codecov bot commented on Sep 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.34%. Comparing base (f093ce2) to head (16c8f26).
Report is 20 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #97      +/-   ##
==========================================
+ Coverage   96.64%   97.34%   +0.70%     
==========================================
  Files          23       38      +15     
  Lines        1818     3694    +1876     
==========================================
+ Hits         1757     3596    +1839     
- Misses         61       98      +37     

☔ View full report in Codecov by Sentry.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/model_config.py (2)

1-2: Consider adding explicit attribute definitions.

The ModelConfig class is correctly using the @attr.s decorator with auto_attribs=True. However, the class body is empty, which is unusual. Consider adding explicit attribute definitions to improve code clarity and enable IDE autocompletion.

Here's an example of how you could define the attributes:

import attr

@attr.s(auto_attribs=True)
class ModelConfig:
    backbone: dict
    heads: dict
    base_checkpoint: str

This will make the class structure more explicit and easier to understand at a glance.

🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)


3-9: Docstring looks good, minor formatting suggestion.

The docstring provides clear and informative descriptions of the class and its attributes. Well done!

Consider adding a period at the end of the last line for consistency:

-        base_checkpoint: Path to model folder for loading a checkpoint. Should contain the .h5 file
+        base_checkpoint: Path to model folder for loading a checkpoint. Should contain the .h5 file.
sleap_nn/config/training_job.py (1)

1-42: Overall assessment and recommendations

The TrainingJobConfig class provides a well-structured and documented foundation for managing training job configurations. However, there are a few improvements needed to make it complete and error-free:

  1. Add the missing imports for attr, DataConfig, and other required classes.
  2. Implement the remaining attributes mentioned in the class docstring.
  3. Consider adding type hints for all attributes to improve code readability and catch potential type-related issues early.
  4. If not already present in your project, consider adding a requirements.txt or setup.py file to specify the attrs library as a dependency.

Once these changes are made, the TrainingJobConfig class will be a robust and well-documented configuration management solution for your training jobs.

🧰 Tools
🪛 Ruff

27-27: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)


42-42: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between c3be90b and 72b9d7b.

📒 Files selected for processing (3)
  • sleap_nn/config/data_config.py (1 hunks)
  • sleap_nn/config/model_config.py (1 hunks)
  • sleap_nn/config/training_job.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

1-1: Undefined name attr

(F821)


11-11: Undefined name LabelsConfig

(F821)


11-11: Undefined name attr

(F821)


11-11: Undefined name LabelsConfig

(F821)


12-12: Undefined name PreprocessingConfig

(F821)


12-12: Undefined name attr

(F821)


12-12: Undefined name PreprocessingConfig

(F821)


13-13: Undefined name InstanceCroppingConfig

(F821)


13-13: Undefined name attr

(F821)


13-13: Undefined name InstanceCroppingConfig

(F821)

sleap_nn/config/model_config.py

1-1: Undefined name attr

(F821)

sleap_nn/config/training_job.py

27-27: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)


42-42: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (2)

3-9: Well-structured and informative docstring

The class docstring is well-written, providing a clear description of the class purpose and each of its attributes. It follows good documentation practices, which will help other developers understand the purpose and usage of this configuration class.


1-13: Overall well-structured and maintainable code

The DataConfig class is well-designed and follows good coding practices. It uses the attr library effectively for attribute management, has clear and descriptive naming, and includes informative docstrings. The suggested improvements (adding missing imports) are minor and easy to implement. Once these changes are made, the code will be more robust and easier to maintain.


sleap_nn/config/training_job.py (1)

1-25: Excellent documentation!

The file-level docstring provides a clear and comprehensive explanation of the purpose and design principles behind the configuration classes. It effectively communicates the separation of configuration from implementation and highlights the benefits of this approach, including easy serialization and human-readability.

@@ -0,0 +1,9 @@
@attr.s(auto_attribs=True)
@coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Add import statement for attr.

The static analysis tool has flagged an undefined name attr. This is because the import statement for the attr module is missing.

Add the following import statement at the beginning of the file:

import attr

This will resolve the undefined name issue and make the usage of @attr.s decorator valid.

🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)

sleap_nn/config/data_config.py: 2 outdated review comments (resolved)
Comment on lines 27 to 40
@attr.s(auto_attribs=True)
class TrainingJobConfig:
    """Configuration of a training job.

    Attributes:
        data: Configuration options related to the training data.
        model: Configuration options related to the model architecture.
        optimization: Configuration options related to the training.
        outputs: Configuration options related to outputs during training.
        name: Optional name for this configuration profile.
        description: Optional description of the configuration.
        sleap_version: Version of SLEAP that generated this configuration.
        filename: Path to this config file if it was loaded from disk.
    """
@coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Add missing import for attr

The attr module is used to define the TrainingJobConfig class, but it's not imported. This is causing the static analysis tool to flag an error.

Please add the following import at the beginning of the file:

import attr

This will resolve the undefined name error for attr.

🧰 Tools
🪛 Ruff

27-27: Undefined name attr

(F821)

        filename: Path to this config file if it was loaded from disk.
    """

    data: DataConfig = attr.ib(factory=DataConfig)
@coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Add missing import for DataConfig

The DataConfig class is used as the type for the data attribute, but it's not imported. This is causing the static analysis tool to flag an error.

Please add the following import at the beginning of the file:

from sleap_nn.config.data import DataConfig

This will resolve the undefined name error for DataConfig.

🧰 Tools
🪛 Ruff

42-42: Undefined name DataConfig

(F821)


42-42: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)


⚠️ Potential issue

Add missing attributes to TrainingJobConfig

The class docstring mentions several attributes that are not present in the provided code snippet. To ensure consistency with the documentation and completeness of the class, please add the following attributes:

  • model
  • optimization
  • outputs
  • name
  • description
  • sleap_version
  • filename

Here's a suggested implementation for the missing attributes:

@attr.s(auto_attribs=True)
class TrainingJobConfig:
    # ... (existing docstring)

    data: DataConfig = attr.ib(factory=DataConfig)
    model: ModelConfig = attr.ib(factory=ModelConfig)
    optimization: OptimizationConfig = attr.ib(factory=OptimizationConfig)
    outputs: OutputConfig = attr.ib(factory=OutputConfig)
    name: Optional[str] = attr.ib(default=None)
    description: Optional[str] = attr.ib(default=None)
    sleap_version: str = attr.ib(default=attr.Factory(lambda: sleap.__version__))
    filename: Optional[str] = attr.ib(default=None)

Note: You'll need to import the necessary classes (ModelConfig, OptimizationConfig, OutputConfig) and Optional from typing. Also, import sleap to get the version.
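
A sketch of the import block that note implies (the exact module paths are assumptions, chosen to mirror the DataConfig import suggested above):

import attr
from typing import Optional

import sleap  # assumed available; provides sleap.__version__

from sleap_nn.config.data import DataConfig
from sleap_nn.config.model import ModelConfig
from sleap_nn.config.optimization import OptimizationConfig
from sleap_nn.config.outputs import OutputConfig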

🧰 Tools
🪛 Ruff

42-42: Undefined name DataConfig

(F821)


42-42: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/data_config.py (2)

1-3: Remove unused import

The List type from typing is imported but not used in the file. Consider removing it to keep the imports clean.

Apply this diff to remove the unused import:

-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


38-51: Approve IntensityConfig class and suggest comment update

The IntensityConfig class is well-structured with appropriate attributes for intensity-related augmentation options. However, there's a comment indicating that this part is still under review.

Once the review of this section is complete, consider removing or updating the following comment:

# still looking over this part and below (function to check if is greater than 0)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 72b9d7b and 80e5915.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


18-18: Undefined name PreprocessingConfig

(F821)


18-18: Undefined name PreprocessingConfig

(F821)


20-20: Undefined name AugmentationConfig

(F821)


28-28: Undefined name Union

(F821)


35-35: Undefined name IntensityConfig

(F821)


36-36: Undefined name GeometricConfig

(F821)

🔇 Additional comments (2)
sleap_nn/config/data_config.py (2)

53-67: Approve GeometricConfig class

The GeometricConfig class is well-structured with appropriate attributes for geometric augmentation options. The default values seem reasonable, and the use of Optional types where appropriate is good practice.


1-67: Summary of review

Overall, the configuration classes in this file are well-structured and provide a comprehensive set of options for data handling, preprocessing, and augmentation in the SLEAP-NN project. The use of attr.s with auto_attribs=True is a good practice for creating classes with less boilerplate code.

Main points to address:

  1. Remove the unused List import.
  2. Fix circular import issues by using string annotations for types within the same file.
  3. Import the Union type from typing.
  4. Update or remove the comment in the IntensityConfig class once the review is complete.

After addressing these points, the file will be in excellent shape for use in the project.


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/data_config.py (1)

1-3: Remove unused import

The List type is imported but not used in the file. Consider removing it to keep the imports clean.

Apply this diff to remove the unused import:

-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 80e5915 and 186be74.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


18-18: Undefined name PreprocessingConfig

(F821)


18-18: Undefined name PreprocessingConfig

(F821)


20-20: Undefined name AugmentationConfig

(F821)


28-28: Undefined name Union

(F821)


35-35: Undefined name IntensityConfig

(F821)


36-36: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

38-61: Approve IntensityConfig class

The IntensityConfig class is well-structured with appropriate attributes for intensity augmentation options. The inclusion of validator methods for uniform_noise_min and uniform_noise_max is a good practice to ensure valid values.

The class implementation looks good and provides a comprehensive set of options for intensity-based augmentations.
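
For readers unfamiliar with attrs validators, a minimal sketch of the pattern being praised here (field names and bounds are illustrative, not the PR's exact code):

import attrs

@attrs.define
class IntensityConfig:
    """Illustrative subset of the intensity augmentation options."""

    # Built-in bound validators reject bad values at construction time.
    uniform_noise_min: float = attrs.field(
        default=0.0, validator=attrs.validators.ge(0.0)
    )
    uniform_noise_max: float = attrs.field(
        default=0.04, validator=attrs.validators.le(1.0)
    )

# Invalid settings fail fast instead of surfacing mid-training:
try:
    IntensityConfig(uniform_noise_min=-0.1)
except ValueError as err:
    print(err)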


63-77: Approve GeometricConfig class

The GeometricConfig class is well-structured with appropriate attributes for geometric augmentation options. The attribute names are clear and descriptive, and the types are suitable for their intended purposes.

The class implementation provides a comprehensive set of options for geometric augmentations, which should cover most common use cases in image augmentation pipelines.


1-77: Overall assessment: Well-structured configuration classes with minor improvements needed

The sleap_nn/config/data_config.py file introduces a set of well-structured configuration classes for data handling, preprocessing, and augmentation. The use of the attr library and type hints enhances code readability and maintainability.

Key points:

  1. The classes provide comprehensive options for various aspects of data processing and augmentation.
  2. The code follows good practices in terms of class structure and attribute definitions.
  3. Validator methods in the IntensityConfig class add an extra layer of safety.

Minor improvements suggested:

  1. Remove unused import (List).
  2. Add missing import (Union).
  3. Address potential circular import issues with string annotations.

Overall, the file provides a solid foundation for configuring data-related operations in the project. After addressing the minor issues, it will be ready for use.


Comment on lines 23 to 30
@attr.s(auto_attribs=True)
class PreprocessingConfig:
    is_rgb: bool = True
    max_height: Optional[int] = None
    max_width: Optional[int] = None
    scale: Union[float, Tuple[float, float]] = 1.0
    crop_hw: Optional[Tuple[int, int]] = None
    min_crop_size: int = 32 #to help app work incase of error
@coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Approve PreprocessingConfig class and suggest import fix

The PreprocessingConfig class is well-structured with appropriate attributes for preprocessing options. However, the Union type is used but not imported.

Add the following import at the top of the file:

from typing import Union
🧰 Tools
🪛 Ruff

28-28: Undefined name Union

(F821)

sleap_nn/config/data_config.py: outdated review comment (resolved)
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/data_config.py (1)

1-3: Remove unused import

The List type is imported from typing but not used in the file. Consider removing it to keep the imports clean.

Apply this diff to remove the unused import:

-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 186be74 and 796788d.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


25-25: Undefined name PreprocessingConfig

(F821)


25-25: Undefined name PreprocessingConfig

(F821)


27-27: Undefined name AugmentationConfig

(F821)


46-46: Undefined name Union

(F821)


61-61: Undefined name IntensityConfig

(F821)


62-62: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

64-103: LGTM: IntensityConfig class is well-structured

The IntensityConfig class is well-defined with appropriate attributes and validators for uniform noise parameters. The use of attr.s decorator with auto_attribs=True is a good practice for creating classes with less boilerplate code.


105-138: LGTM: GeometricConfig class is well-structured

The GeometricConfig class is well-defined with appropriate attributes for geometric augmentations. The use of attr.s decorator with auto_attribs=True is consistent with the other classes in this file.


1-138: Overall: Well-structured configuration classes for data handling

This file introduces a set of well-organized configuration classes for managing various aspects of data handling in machine learning workflows. The use of attr.s decorators with auto_attribs=True promotes clean and maintainable code. The classes cover essential aspects such as data sources, preprocessing, and augmentation options.

A few minor improvements have been suggested regarding imports and forward references. Once these are addressed, the file will provide a robust foundation for configuring data-related parameters in the project.


sleap_nn/config/data_config.py: 3 outdated review comments (resolved)
@gqcpm changed the title from "Greg/omegaconf basic func" to "Implement Omegaconfig PR1: basic functionality" on Oct 9, 2024
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/data_config.py (2)

1-3: Optimize imports

Please make the following changes to the import statements:

  1. Remove the unused import of List.
  2. Add Union to the imports from typing.

Apply this diff to optimize the imports:

 import attrs
 from omegaconf import MISSING
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict, Union

This change will remove the unused import and add the necessary Union type.

🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


94-127: Approve GeometricConfig class and fix typo

The GeometricConfig class is well-structured with appropriate attributes and default values. The docstring provides clear explanations for each attribute.

There's a small typo in the docstring for the erase_ratio_min attribute. Please fix it as follows:

-        erase_ration_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
+        erase_ratio_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.

This change will correct the attribute name in the docstring.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 796788d and 34548b8.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


25-25: Undefined name PreprocessingConfig

(F821)


25-25: Undefined name PreprocessingConfig

(F821)


27-27: Undefined name AugmentationConfig

(F821)


46-46: Undefined name Union

(F821)


61-61: Undefined name IntensityConfig

(F821)


62-62: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

30-48: Approve PreprocessingConfig class and remind about Union import

The PreprocessingConfig class is well-structured with appropriate attributes and default values. The docstring provides clear explanations for each attribute.

Remember to add the Union import as suggested in the previous comment about optimizing imports.

🧰 Tools
🪛 Ruff

46-46: Undefined name Union

(F821)


64-92: Approve IntensityConfig class

The IntensityConfig class is well-structured with appropriate attributes, default values, and validators. The docstring provides clear explanations for each attribute. The use of attrs.field with validators for uniform_noise_min, uniform_noise_max, contrast_min, and contrast_max is a good practice to ensure valid input.


1-127: Overall assessment: Well-structured configuration classes with minor improvements needed

The data_config.py file introduces a set of well-structured configuration classes for managing data parameters in machine learning workflows. The use of the attrs library and detailed docstrings contributes to clean and self-documenting code.

Main points of improvement:

  1. Optimize imports by removing unused ones and adding missing ones.
  2. Use forward references to resolve potential circular import issues.
  3. Fix a minor typo in the GeometricConfig class docstring.

Once these changes are implemented, the file will be in excellent shape, providing a robust foundation for configuring data-related parameters in your project.


Comment on lines 12 to 27
@attrs.define
class DataConfig:
    """Data configuration.

    labels: Configuration options related to user labels for training or testing.
    preprocessing: Configuration options related to data preprocessing.
    instance_cropping: Configuration options related to instance cropping for centroid
        and topdown models.
    """

    provider: str="LabelsReader"
    train_labels_path: str=MISSING
    val_labels_path: str=MISSING
    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
    use_augmentations_train: bool=False
    augmentation_config: Optional[AugmentationConfig] = None
@coderabbitai bot (Contributor) commented:

🛠️ Refactor suggestion

Use forward references for configuration classes

The DataConfig class looks well-structured, but there are undefined names for PreprocessingConfig and AugmentationConfig. To avoid potential circular imports, use forward references for these classes.

Apply this diff to use forward references:

+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from sleap_nn.config.data_config import PreprocessingConfig, AugmentationConfig

 @attrs.define
 class DataConfig:
     # ... (existing code)
-    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
-    augmentation_config: Optional[AugmentationConfig] = None
+    preprocessing: 'PreprocessingConfig' = attrs.field(factory=lambda: PreprocessingConfig())
+    augmentation_config: Optional['AugmentationConfig'] = None

This change will resolve the undefined names issue and prevent potential circular imports.

Committable suggestion was skipped due to low confidence.

🧰 Tools
🪛 Ruff

25-25: Undefined name PreprocessingConfig

(F821)


25-25: Undefined name PreprocessingConfig

(F821)


27-27: Undefined name AugmentationConfig

(F821)

Comment on lines 50 to 62
@attrs.define
class AugmentationConfig:
    """ Configuration of Augmentation

    Attributes:
        random crop: (Optional) (Dict[float]) {"random_crop_p": None, "crop_height": None. "crop_width": None}, where random_crop_p is the probability of applying random crop and crop_height and crop_width are the desired output size (out_h, out_w) of the crop.
        intensity: (Optional)
        geometric: (Optional)
    """

    random_crop: Optional[Dict[str, Optional[float]]] = None
    intensity: Optional[IntensityConfig] = attrs.field(default=None)
    geometric: Optional[GeometricConfig] = attrs.field(default=None)
@coderabbitai bot (Contributor) commented:

🛠️ Refactor suggestion

Use forward references for IntensityConfig and GeometricConfig

The AugmentationConfig class is well-structured, but there are undefined names for IntensityConfig and GeometricConfig. To avoid potential circular imports, use forward references for these classes.

Apply this diff to use forward references:

 @attrs.define
 class AugmentationConfig:
     # ... (existing code)
     random_crop: Optional[Dict[str, Optional[float]]] = None
-    intensity: Optional[IntensityConfig] = attrs.field(default=None)
-    geometric: Optional[GeometricConfig] = attrs.field(default=None)
+    intensity: Optional['IntensityConfig'] = attrs.field(default=None)
+    geometric: Optional['GeometricConfig'] = attrs.field(default=None)

This change will resolve the undefined names issue and prevent potential circular imports.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (original, then suggested):

@attrs.define
class AugmentationConfig:
    """ Configuration of Augmentation

    Attributes:
        random crop: (Optional) (Dict[float]) {"random_crop_p": None, "crop_height": None. "crop_width": None}, where random_crop_p is the probability of applying random crop and crop_height and crop_width are the desired output size (out_h, out_w) of the crop.
        intensity: (Optional)
        geometric: (Optional)
    """

    random_crop: Optional[Dict[str, Optional[float]]] = None
    intensity: Optional[IntensityConfig] = attrs.field(default=None)
    geometric: Optional[GeometricConfig] = attrs.field(default=None)

becomes:

@attrs.define
class AugmentationConfig:
    """ Configuration of Augmentation

    Attributes:
        random crop: (Optional) (Dict[float]) {"random_crop_p": None, "crop_height": None. "crop_width": None}, where random_crop_p is the probability of applying random crop and crop_height and crop_width are the desired output size (out_h, out_w) of the crop.
        intensity: (Optional)
        geometric: (Optional)
    """

    random_crop: Optional[Dict[str, Optional[float]]] = None
    intensity: Optional['IntensityConfig'] = attrs.field(default=None)
    geometric: Optional['GeometricConfig'] = attrs.field(default=None)
🧰 Tools
🪛 Ruff

61-61: Undefined name IntensityConfig

(F821)


62-62: Undefined name GeometricConfig

(F821)

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (4)
sleap_nn/config/data_config.py (4)

1-3: Remove unused import

The List type is imported from typing but not used in the file. To keep the imports clean, you can remove it.

Apply this change:

- from typing import Optional, Tuple, List, Dict
+ from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


12-35: Approve DataConfig class and suggest forward references

The DataConfig class is well-structured with clear attribute definitions and a comprehensive docstring. However, there are undefined names for PreprocessingConfig and AugmentationConfig.

To resolve the undefined names and prevent potential circular imports, use forward references for these classes. Apply this diff:

+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from sleap_nn.config.data_config import PreprocessingConfig, AugmentationConfig

 @attrs.define
 class DataConfig:
     # ... (existing code)
-    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
-    augmentation_config: Optional[AugmentationConfig] = None
+    preprocessing: 'PreprocessingConfig' = attrs.field(factory=lambda: PreprocessingConfig())
+    augmentation_config: Optional['AugmentationConfig'] = None

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

33-33: Undefined name PreprocessingConfig

(F821)


33-33: Undefined name PreprocessingConfig

(F821)


35-35: Undefined name AugmentationConfig

(F821)


38-56: Approve PreprocessingConfig class and add missing import

The PreprocessingConfig class is well-structured with appropriate attributes for preprocessing options. However, the Union type is used but not imported.

Add the following import at the top of the file:

- from typing import Optional, Tuple, Dict
+ from typing import Optional, Tuple, Dict, Union

This will resolve the undefined Union type issue.

🧰 Tools
🪛 Ruff

54-54: Undefined name Union

(F821)


58-70: Approve AugmentationConfig class and use forward references

The AugmentationConfig class is well-structured, but there are undefined names for IntensityConfig and GeometricConfig.

To resolve the undefined names and prevent potential circular imports, use forward references for these classes. Apply this diff:

 @attrs.define
 class AugmentationConfig:
     # ... (existing code)
     random_crop: Optional[Dict[str, Optional[float]]] = None
-    intensity: Optional[IntensityConfig] = attrs.field(default=None)
-    geometric: Optional[GeometricConfig] = attrs.field(default=None)
+    intensity: Optional['IntensityConfig'] = attrs.field(default=None)
+    geometric: Optional['GeometricConfig'] = attrs.field(default=None)

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

69-69: Undefined name IntensityConfig

(F821)


70-70: Undefined name GeometricConfig

(F821)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 34548b8 and 7ea94b6.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


33-33: Undefined name PreprocessingConfig

(F821)


33-33: Undefined name PreprocessingConfig

(F821)


35-35: Undefined name AugmentationConfig

(F821)


54-54: Undefined name Union

(F821)


69-69: Undefined name IntensityConfig

(F821)


70-70: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

72-100: Approve IntensityConfig class

The IntensityConfig class is well-structured with clear attribute definitions, comprehensive docstring, and appropriate use of validators for certain attributes. This ensures data integrity and provides clear documentation for users.

The implementation looks good and follows best practices.


102-135: Approve GeometricConfig class

The GeometricConfig class is well-structured with clear attribute definitions, comprehensive docstring, and appropriate default values for all attributes. This provides a good starting point for users and clear documentation of the available options.

The implementation looks good and follows best practices.


1-135: Final approval: Well-structured configuration classes with minor improvements needed

Overall, the implementation of the configuration classes in sleap_nn/config/data_config.py is well-done. The classes are clearly structured, with comprehensive docstrings and appropriate use of the attrs library. The suggested improvements mainly revolve around import statements and using forward references to prevent circular imports.

Key points:

  1. Remove the unused List import.
  2. Add the missing Union import.
  3. Use forward references for PreprocessingConfig, AugmentationConfig, IntensityConfig, and GeometricConfig to resolve undefined names and prevent potential circular imports.

After implementing these minor changes, the code will be more robust and maintainable.


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/data_config.py (3)

1-3: Approve imports with minor suggestion

The imports look good overall. However, there's an unused import that can be removed.

Remove the unused List import:

-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


13-36: Approve DataConfig class and suggest forward references

The DataConfig class is well-structured with clear documentation. However, there are undefined names for PreprocessingConfig and AugmentationConfig.

To avoid potential circular imports, use forward references for these classes. Add the following import at the top of the file and modify the class attributes:

from __future__ import annotations

Then update the attributes:

-    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
-    augmentation_config: Optional[AugmentationConfig] = None
+    preprocessing: 'PreprocessingConfig' = attrs.field(factory=lambda: PreprocessingConfig())
+    augmentation_config: Optional['AugmentationConfig'] = None

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

34-34: Undefined name PreprocessingConfig

(F821)


34-34: Undefined name PreprocessingConfig

(F821)


36-36: Undefined name AugmentationConfig

(F821)


60-72: Approve AugmentationConfig class and suggest forward references

The AugmentationConfig class is well-structured with clear documentation. However, there are undefined names for IntensityConfig and GeometricConfig.

To avoid potential circular imports, use forward references for these classes. Modify the class attributes as follows:

-    intensity: Optional[IntensityConfig] = attrs.field(default=None)
-    geometric: Optional[GeometricConfig] = attrs.field(default=None)
+    intensity: Optional['IntensityConfig'] = attrs.field(default=None)
+    geometric: Optional['GeometricConfig'] = attrs.field(default=None)

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

71-71: Undefined name IntensityConfig

(F821)


72-72: Undefined name GeometricConfig

(F821)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 7ea94b6 and 192399d.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


34-34: Undefined name PreprocessingConfig

(F821)


34-34: Undefined name PreprocessingConfig

(F821)


36-36: Undefined name AugmentationConfig

(F821)


55-55: Undefined name Union

(F821)


71-71: Undefined name IntensityConfig

(F821)


72-72: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

75-107: Approve IntensityConfig class

The IntensityConfig class is well-structured with clear documentation and appropriate use of attrs validators. All attributes have sensible default values, and the validators help ensure valid input.

The implementation looks good, and no changes are necessary.


110-143: Approve GeometricConfig class

The GeometricConfig class is well-structured with clear and detailed documentation for each attribute. The use of Optional types and default values is appropriate, providing flexibility in configuration.

The implementation looks good, and no changes are necessary.


1-143: Summary and final recommendation

Overall, the data_config.py file is well-structured and provides a comprehensive set of configuration classes for data handling in a machine learning pipeline. The use of the attrs library and detailed documentation for each class and attribute is commendable.

To improve the code quality and prevent potential issues, please implement the following changes:

  1. Remove the unused List import.
  2. Add from __future__ import annotations at the top of the file.
  3. Use forward references for PreprocessingConfig, AugmentationConfig, IntensityConfig, and GeometricConfig in the respective class attributes.
  4. Add from typing import Union to resolve the undefined Union type.

After implementing these changes, the file will be more robust and less prone to circular import issues.

The overall implementation is good, and with these minor adjustments, it will be excellent.


Comment on lines +39 to +57
@attrs.define
class PreprocessingConfig:
    """Configuration of Preprocessing.

    Attributes:
        is_rgb: (bool) True if the image has 3 channels (RGB image). If input has only one channel when this is set to True, then the images from single-channel is replicated along the channel axis. If input has three channels and this is set to False, then we convert the image to grayscale (single-channel) image.
        max_height: (int) Maximum height the image should be padded to. If not provided, the original image size will be retained. Default: None.
        max_width: (int) Maximum width the image should be padded to. If not provided, the original image size will be retained. Default: None.
        scale: (float or List[float]) Factor to resize the image dimensions by, specified as either a float scalar or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions are resized by the same factor.
        crop_hw: (Tuple[int]) Crop height and width of each instance (h, w) for centered-instance model. If None, this would be automatically computed based on the largest instance in the sio.Labels file.
        min_crop_size: (int) Minimum crop size to be used if crop_hw is None.
    """

    is_rgb: bool = True
    max_height: Optional[int] = None
    max_width: Optional[int] = None
    scale: Union[float, Tuple[float, float]] = 1.0
    crop_hw: Optional[Tuple[int, int]] = None
    min_crop_size: int = 32 # to help app work incase of error
@coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Approve PreprocessingConfig class and add missing import

The PreprocessingConfig class is well-structured with appropriate attributes and documentation. However, there's a missing import for the Union type.

Add the following import at the top of the file:

from typing import Union

This will resolve the undefined name issue for the Union type used in the scale attribute.

🧰 Tools
🪛 Ruff

55-55: Undefined name Union

(F821)

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/model_config.py (1)

1-9: Consider adding type hints to class attributes.

While the attributes are well-described in the docstring, adding type hints would improve code clarity and enable better static type checking. Consider updating the class definition to include type information for each attribute. For example:

import attr

@attr.s(auto_attribs=True)
class ModelConfig:
    backbone: dict
    heads: dict
    base_checkpoint: str

Replace dict and str with more specific types if applicable. This change will enhance code readability and make it easier for developers to understand the expected types of these configuration parameters.

🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 192399d and 184d5ea.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

1-1: Undefined name attr

(F821)

🔇 Additional comments (3)
sleap_nn/config/model_config.py (3)

1-9: LGTM: Well-structured class definition and docstring.

The ModelConfig class is well-defined using the @attr.s(auto_attribs=True) decorator, which is appropriate for automatic attribute management. The docstring provides a clear description of the class purpose and its attributes, following the Google style guide format. This promotes good code documentation and maintainability.

🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)


1-9: Overall, good implementation of the ModelConfig class.

The ModelConfig class aligns well with the PR objectives of migrating configuration classes and updating them to use the new attrs API. The implementation is clean and well-documented. With the suggested improvements (adding the import statement and considering type hints), this class will provide a solid foundation for model configuration in the project.

🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)


1-1: ⚠️ Potential issue

Add missing import statement for attr.

The attr module is used in the code but not imported. To resolve the undefined name issue and make the usage of the @attr.s decorator valid, please add the following import statement at the beginning of the file:

import attr
🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/model_config.py (1)

23-23: Consider using consistent decorators across the file.

The UNetConfig class uses @attrs.define, while the main ModelConfig class uses @attr.s. For consistency, consider using the same decorator style throughout the file. If you're using a newer version of attrs, you might want to update all decorators to @attrs.define. Otherwise, change this to @attr.s(auto_attribs=True) to match the main class.
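
For reference, a minimal sketch of the two styles side by side (attribute names are illustrative). Note that @attrs.define also enables __slots__ by default, so the two decorators are close but not strictly identical:

import attr
import attrs

# Modern API: type annotations define the fields automatically.
@attrs.define
class UNetConfig:
    filters: int = 32
    filters_rate: float = 1.5

# Legacy API: the same fields via attr.s with auto_attribs enabled.
@attr.s(auto_attribs=True)
class UNetConfigLegacy:
    filters: int = 32
    filters_rate: float = 1.5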

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 184d5ea and c79ea42.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

13-14: SyntaxError: Expected an expression


15-16: SyntaxError: Expected an expression

🔇 Additional comments (2)
sleap_nn/config/model_config.py (2)

1-59: Overall structure looks good, with some minor improvements needed.

The ModelConfig class and its nested configuration classes provide a comprehensive and flexible structure for configuring different model architectures. The use of attrs for class definitions is a good choice for reducing boilerplate code.

To improve the file:

  1. Add the missing imports for attr and Enum.
  2. Complete the attribute definitions for pre_trained_weights and backbone_config.
  3. Consider using consistent decorators across all class definitions.

These changes will enhance the code's correctness and consistency.

🧰 Tools
🪛 Ruff

13-14: SyntaxError: Expected an expression


15-16: SyntaxError: Expected an expression


1-1: ⚠️ Potential issue

Add missing import for attr module.

The attr module is used in this file but not imported. Add the following import at the beginning of the file:

import attr

This will resolve the undefined name issue for attr.

Comment on lines 13 to 15
pre_trained_weights: str =
backbone_type: BackboneType = BackboneType.UNET
backbone_config:
@coderabbitai bot (Contributor) commented:

⚠️ Potential issue

Complete the attribute definitions for pre_trained_weights and backbone_config.

The pre_trained_weights and backbone_config attributes are incomplete. Please provide appropriate default values or type annotations for these attributes. For example:

pre_trained_weights: str = ""
backbone_config: Union[UNetConfig, ConvNextConfig, SwinTConfig] = attr.field(default=None)

Make sure to import Union from the typing module if you use it.

🧰 Tools
🪛 Ruff

13-14: SyntaxError: Expected an expression

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/model_config.py (3)

20-46: LGTM! Consider refactoring error messages.

The methods in the ModelConfig class are well-implemented and provide good initialization and validation logic.

To reduce code duplication in error messages, consider creating a helper method for generating error messages:

def _get_weight_error_message(self, backbone_type, valid_weights):
    return f"Invalid pre-trained weights for {backbone_type}. Must be one of {valid_weights}"

# Then use it in validate_pre_trained_weights:
if self.backbone_type == BackboneType.CONVNEXT:
    if self.pre_trained_weights not in convnext_weights:
        raise ValueError(self._get_weight_error_message("ConvNext", convnext_weights))
elif self.backbone_type == BackboneType.SWINT:
    if self.pre_trained_weights not in swint_weights:
        raise ValueError(self._get_weight_error_message("SwinT", swint_weights))

This refactoring will make the code more maintainable and reduce the risk of inconsistencies in error messages.



53-65: LGTM! Consider adding validation for None values.

The UNetConfig class is well-defined with appropriate attributes for UNet configuration.

Consider adding validation for attributes that have None as default value, such as max_stride and stem_stride. You could do this in a post-init method:

@attrs.define
class UNetConfig:
    # ... existing attributes ...

    def __attrs_post_init__(self):
        if self.max_stride is None:
            # Set a default value or raise an error if it's required
            self.max_stride = 16  # Example default value
        if self.stem_stride is None:
            # Set a default value or raise an error if it's required
            self.stem_stride = 1  # Example default value

This ensures that these critical parameters always have valid values.


66-89: LGTM! Consider reordering attributes for consistency.

The ConvNextConfig and SwinTConfig classes are well-defined with appropriate attributes for their respective architectures.

For better consistency between the two classes, consider reordering the attributes in SwinTConfig to match the order in ConvNextConfig as closely as possible. For example:

@attrs.define
class SwinTConfig:
    model_type: str = "tiny"
    arch: dict = attrs.field(factory=lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]})
    in_channels: int = 1
    patch_size: list = attrs.field(factory=lambda: [4, 4])
    stem_patch_stride: int = 2
    window_size: list = attrs.field(factory=lambda: [7, 7])
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True

This reordering makes it easier to compare the two configurations at a glance.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between c79ea42 and d866e36.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType (F821)
17-17: Undefined name BackboneType (F821)
18-18: Undefined name Union (F821)
18-18: Undefined name UNetConfig (F821)
18-18: Undefined name ConvNextConfig (F821)
18-18: Undefined name SwinTConfig (F821)
25-25: Undefined name BackboneType (F821)
26-26: Undefined name UNetConfig (F821)
27-27: Undefined name BackboneType (F821)
28-28: Undefined name ConvNextConfig (F821)
29-29: Undefined name BackboneType (F821)
30-30: Undefined name SwinTConfig (F821)
38-38: Undefined name BackboneType (F821)
41-41: Undefined name BackboneType (F821)
44-44: Undefined name BackboneType (F821)

🔇 Additional comments (3)
sleap_nn/config/model_config.py (3)

47-50: LGTM! Well-defined enum for backbone types.

The BackboneType enum is correctly defined and provides clear options for the backbone types. This approach enhances type safety and code readability.


1-89: Overall, excellent implementation with minor suggestions for improvement.

The ModelConfig class and its associated nested classes provide a comprehensive and well-structured configuration system for model architectures. The use of attrs for class definitions and Enum for backbone types demonstrates good coding practices.

Key strengths:

  1. Clear separation of concerns for different backbone types.
  2. Proper use of post-initialization and validation methods.
  3. Effective use of default values and factory functions for complex defaults.

Suggestions for improvement:

  1. Add import for Union from typing.
  2. Refactor error message generation in validate_pre_trained_weights.
  3. Add validation for None values in UNetConfig.
  4. Reorder attributes in SwinTConfig for consistency with ConvNextConfig.

These minor improvements will enhance the overall quality and maintainability of the code.



1-18: LGTM! Consider adding import for Union.

The imports and main class definition look good. The Enum import addresses a previous comment. However, there's a minor improvement we can make:

Consider adding the following import at the beginning of the file:

from typing import Union

This will resolve the undefined name issue for Union on line 18.

There might be a circular import issue causing some undefined names. Let's verify this:
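For example, a hypothetical check along these lines (the paths and scan logic below are illustrative assumptions) could be:

import pathlib

# List modules under sleap_nn that mention model_config, which would
# hint at a cycle back into sleap_nn/config/model_config.py.
for path in pathlib.Path("sleap_nn").rglob("*.py"):
    if path.name != "model_config.py" and "model_config" in path.read_text():
        print(path)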

If this script returns results, it indicates a potential circular import that needs to be addressed.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/model_config.py (2)

55-67: Consider adding type hints to UNetConfig attributes.

To improve code clarity and maintainability, consider adding type hints to the attributes in the UNetConfig class. For example:

@attrs.define
class UNetConfig:
    in_channels: int = 1
    kernel_size: int = 3
    filters: int = 32
    filters_rate: float = 1.5
    max_stride: Optional[int] = None
    stem_stride: Optional[int] = None
    middle_block: bool = True
    up_interpolate: bool = True
    stacks: int = 3
    convs_per_block: int = 2

Don't forget to import Optional from typing if you use it for max_stride and stem_stride.


68-91: Enhance ConvNextConfig and SwinTConfig with type hints and consider using attrs.Factory.

  1. Add type hints to improve code clarity:
@attrs.define
class ConvNextConfig:
    model_type: str = "tiny"
    arch: Dict[str, List[int]] = attrs.field(factory=lambda: {'depths': [3, 3, 9, 3], 'channels': [96, 192, 384, 768]})
    stem_patch_kernel: int = 4
    stem_patch_stride: int = 2
    in_channels: int = 1
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True

@attrs.define
class SwinTConfig:
    model_type: str = "tiny"
    arch: Dict[str, Union[int, List[int]]] = attrs.field(factory=lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]})
    patch_size: List[int] = attrs.field(factory=lambda: [4, 4])
    stem_patch_stride: int = 2
    window_size: List[int] = attrs.field(factory=lambda: [7, 7])
    in_channels: int = 1
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True
  2. Consider using attrs.Factory as an explicit default instead of passing a lambda to factory=:
from attrs import Factory

# In ConvNextConfig
arch: Dict[str, List[int]] = attrs.field(default=Factory(lambda: {'depths': [3, 3, 9, 3], 'channels': [96, 192, 384, 768]}))

# In SwinTConfig
arch: Dict[str, Union[int, List[int]]] = attrs.field(default=Factory(lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]}))
patch_size: List[int] = attrs.field(default=Factory(lambda: [4, 4]))
window_size: List[int] = attrs.field(default=Factory(lambda: [7, 7]))

These changes will improve type checking and make the code more explicit about the expected types for each attribute.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between d866e36 and d572730.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType (F821)
17-17: Undefined name BackboneType (F821)
18-18: Undefined name Union (F821)
18-18: Undefined name UNetConfig (F821)
18-18: Undefined name ConvNextConfig (F821)
18-18: Undefined name SwinTConfig (F821)
27-27: Undefined name BackboneType (F821)
28-28: Undefined name UNetConfig (F821)
29-29: Undefined name BackboneType (F821)
30-30: Undefined name ConvNextConfig (F821)
31-31: Undefined name BackboneType (F821)
32-32: Undefined name SwinTConfig (F821)
40-40: Undefined name BackboneType (F821)
43-43: Undefined name BackboneType (F821)
46-46: Undefined name BackboneType (F821)

🔇 Additional comments (1)
sleap_nn/config/model_config.py (1)

1-91: Overall assessment: Good implementation with room for minor improvements.

The ModelConfig class and its nested configuration classes provide a well-structured and comprehensive approach to managing model architecture configurations. The use of attrs for class definitions is a good choice, making the code more concise and less error-prone.

Key strengths:

  1. Clear separation of concerns for different backbone types.
  2. Thorough validation of pre-trained weights.
  3. Use of enums for backbone types, improving type safety.

Suggested improvements:

  1. Resolve import and undefined name issues.
  2. Add type hints throughout the file for better code clarity and maintainability.
  3. Minor refactoring of the validate_pre_trained_weights method to reduce code duplication.
  4. Consider using attrs.Factory for mutable default values.

These changes will enhance the overall quality of the code, making it more robust and easier to maintain in the future.


Comment on lines 1 to 18
import attrs
from enum import Enum

@attrs.define
class ModelConfig:
    """Configurations related to model architecture.

    Attributes:
        init_weight: (str) model weights initialization method. "default" uses kaiming uniform initialization and "xavier" uses Xavier initialization method.
        pre_trained_weights: (str) Pretrained weights file name supported only for ConvNext and SwinT backbones. For ConvNext, one of ["ConvNeXt_Base_Weights","ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]. For SwinT, one of ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"].
        backbone_type: (str) Backbone architecture for the model to be trained. One of "unet", "convnext" or "swint".
    """

    init_weight: str = "default"
    pre_trained_weights: str = None
    backbone_type: BackboneType = BackboneType.UNET
    backbone_config: Union[UNetConfig, ConvNextConfig, SwinTConfig] = attrs.field(init=False)  # backbone_config can be any of these 3 configurations. init=False lets you set the parameters later (not in initialization)

⚠️ Potential issue

Add missing imports and address undefined names.

To resolve the undefined names and improve the overall structure of the file, please make the following changes:

  1. Add the missing import for Union:
from typing import Union
  2. Move the BackboneType enum definition to the top of the file, just after the imports.

  3. Add forward references for the nested configuration classes:

UNetConfig = "ModelConfig.UNetConfig"
ConvNextConfig = "ModelConfig.ConvNextConfig"
SwinTConfig = "ModelConfig.SwinTConfig"

These changes will resolve the undefined name issues and improve the overall structure of the file.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/model_config.py (3)

1-2: Add missing imports for type hints.

To improve type checking and code clarity, please add the following imports at the beginning of the file:

from typing import Dict, List, Optional, Union

These imports are necessary for the type hints used throughout the file.


56-92: Enhance documentation and type hinting for nested configuration classes.

For the UNetConfig, ConvNextConfig, and SwinTConfig classes:

  1. Add docstrings to explain the purpose of each class and its attributes.
  2. Consider adding type hints to all attributes for better code clarity and type checking.

For example, for the UNetConfig class:

@attrs.define
class UNetConfig:
    """Configuration for UNet backbone.

    Attributes:
        in_channels (int): Number of input channels.
        kernel_size (int): Size of the convolutional kernel.
        filters (int): Number of filters in the first layer.
        filters_rate (float): Rate at which the number of filters increases.
        max_stride (Optional[int]): Maximum stride in the network.
        stem_stride (Optional[int]): Stride in the stem of the network.
        middle_block (bool): Whether to include a middle block.
        up_interpolate (bool): Whether to use interpolation for upsampling.
        stacks (int): Number of encoder/decoder stacks.
        convs_per_block (int): Number of convolutions per block.
    """
    in_channels: int = 1
    kernel_size: int = 3
    filters: int = 32
    filters_rate: float = 1.5
    max_stride: Optional[int] = None
    stem_stride: Optional[int] = None
    middle_block: bool = True
    up_interpolate: bool = True
    stacks: int = 3
    convs_per_block: int = 2

Apply similar improvements to ConvNextConfig and SwinTConfig classes.


94-120: Improve consistency and documentation for additional configuration classes.

For the HeadConfig, SingleInstanceConfig, and ConfMapsConfig classes:

  1. Add docstrings to HeadConfig and SingleInstanceConfig explaining their purpose and attributes, similar to ConfMapsConfig.
  2. Consider using a consistent style for optional attributes. For example, in ConfMapsConfig, you could use attrs.field(default=None) instead of Optional[Type] = None for consistency with other classes.

Example for HeadConfig:

@attrs.define
class HeadConfig:
    """Configuration for model heads.

    Attributes:
        head_configs (Dict[str, Optional[Dict]]): A dictionary of head configurations.
            Keys represent head types, and values are their respective configurations.
    """
    head_configs: Dict[str, Optional[Dict]] = attrs.field(
        factory=lambda: {
            "single_instance": None,
            "centroid": None,
            "centered_instance": None,
            "bottomup": None
        }
    )

Apply similar improvements to SingleInstanceConfig.


📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between d572730 and 35b29c4.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType (F821)
17-17: Undefined name BackboneType (F821)
18-18: Undefined name Union (F821)
18-18: Undefined name UNetConfig (F821)
18-18: Undefined name ConvNextConfig (F821)
18-18: Undefined name SwinTConfig (F821)
19-19: Undefined name HeadConfig (F821)
19-19: Undefined name HeadConfig (F821)
28-28: Undefined name BackboneType (F821)
29-29: Undefined name UNetConfig (F821)
30-30: Undefined name BackboneType (F821)
31-31: Undefined name ConvNextConfig (F821)
32-32: Undefined name BackboneType (F821)
33-33: Undefined name SwinTConfig (F821)
41-41: Undefined name BackboneType (F821)
44-44: Undefined name BackboneType (F821)
47-47: Undefined name BackboneType (F821)
96-96: Undefined name Dict (F821)
96-96: Undefined name Optional (F821)
96-96: Undefined name Dict (F821)
107-107: Undefined name Optional (F821)
107-107: Undefined name ConfMapsConfig (F821)
118-118: Undefined name Optional (F821)
118-118: Undefined name List (F821)
119-119: Undefined name Optional (F821)
120-120: Undefined name Optional (F821)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 35b29c4 and 5a5dc7b.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType (F821)
17-17: Undefined name BackboneType (F821)
18-18: Undefined name Union (F821)
18-18: Undefined name UNetConfig (F821)
18-18: Undefined name ConvNextConfig (F821)
18-18: Undefined name SwinTConfig (F821)
19-19: Undefined name HeadConfig (F821)
19-19: Undefined name HeadConfig (F821)
28-28: Undefined name BackboneType (F821)
29-29: Undefined name UNetConfig (F821)
30-30: Undefined name BackboneType (F821)
31-31: Undefined name ConvNextConfig (F821)
32-32: Undefined name BackboneType (F821)
33-33: Undefined name SwinTConfig (F821)
41-41: Undefined name BackboneType (F821)
44-44: Undefined name BackboneType (F821)
47-47: Undefined name BackboneType (F821)
96-96: Undefined name Dict (F821)
96-96: Undefined name Optional (F821)
96-96: Undefined name Dict (F821)
108-108: Undefined name Optional (F821)
108-108: Undefined name SingleInstanceConfMapsConfig (F821)
119-119: Undefined name Optional (F821)
119-119: Undefined name List (F821)
120-120: Undefined name Optional (F821)
121-121: Undefined name Optional (F821)
126-126: Undefined name Optional (F821)
126-126: Undefined name CentroidConfMapsConfig (F821)
137-137: Undefined name Optional (F821)
138-138: Undefined name Optional (F821)
139-139: Undefined name Optional (F821)
144-144: Undefined name Optional (F821)
144-144: Undefined name CenteredInstanceConfMapsConfig (F821)
156-156: Undefined name Optional (F821)
156-156: Undefined name List (F821)
157-157: Undefined name Optional (F821)
158-158: Undefined name Optional (F821)
159-159: Undefined name Optional (F821)

🔇 Additional comments (2)
sleap_nn/config/model_config.py (2)

50-92: LGTM: Nested configuration classes.

The nested configuration classes (BackboneType, UNetConfig, ConvNextConfig, and SwinTConfig) are well-structured and use appropriate attrs decorators. The use of attrs.field(factory=lambda: ...) for default dictionaries is a good practice to avoid mutable default arguments.


1-160: Overall assessment: Well-structured configuration system with minor improvements needed.

The model_config.py file introduces a comprehensive and well-structured configuration system for a machine learning model. It effectively uses the attrs library for class definitions and provides clear docstrings for attributes. The nested configuration classes for different backbone types and head configurations are well-organized.

To further improve the code:

  1. Add the missing imports for type hints.
  2. Update the backbone_type attribute in ModelConfig to use attrs.field.
  3. Consider simplifying the set_backbone_config and validate_pre_trained_weights methods as suggested.
  4. Fix the naming inconsistency in the centroid class.

These changes will enhance the code's clarity, maintainability, and consistency.


Comment on lines 22 to 35
    def __attrs_post_init__(self):
        self.backbone_config = self.set_backbone_config()
        self.validate_pre_trained_weights()

    # configures back_bone config to one of these types
    def set_backbone_config(self):
        if self.backbone_type == BackboneType.UNET:
            return UNetConfig()
        elif self.backbone_type == BackboneType.CONVNEXT:
            return ConvNextConfig()
        elif self.backbone_type == BackboneType.SWINT:
            return SwinTConfig()
        else:
            raise ValueError(f"Invalid backbone_type: {self.backbone_type}")

🛠️ Refactor suggestion

Simplify set_backbone_config method.

Consider simplifying the set_backbone_config method using a dictionary mapping:

def set_backbone_config(self):
    config_map = {
        BackboneType.UNET: UNetConfig,
        BackboneType.CONVNEXT: ConvNextConfig,
        BackboneType.SWINT: SwinTConfig,
    }
    config_class = config_map.get(self.backbone_type)
    if config_class is None:
        raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
    return config_class()

This approach reduces repetitive code and makes it easier to add new backbone types in the future.
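A minimal usage sketch, assuming the classes above are importable from sleap_nn.config.model_config:

from sleap_nn.config.model_config import ModelConfig, BackboneType, UNetConfig

# backbone_config is filled in by __attrs_post_init__ based on backbone_type.
cfg = ModelConfig()  # defaults to BackboneType.UNET
assert isinstance(cfg.backbone_config, UNetConfig)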


Comment on lines 94 to 159
@attrs.define
class HeadConfig:
    head_configs: Dict[str, Optional[Dict]] = attrs.field(
        factory = lambda:{
            "single_instance": None,
            "centroid": None,
            "centered_instance": None,
            "bottomup": None
        }
    )

# Head_config single instance
@attrs.define
class SingleInstanceConfig:
    confmaps: Optional[SingleInstanceConfMapsConfig] = None

@attrs.define
class SingleInstanceConfMapsConfig:
    '''

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    '''
    part_names: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None

# Head_config centroid
@attrs.define
class centroid:
    confmaps: Optional[CentroidConfMapsConfig] = None

@attrs.define
class CentroidConfMapsConfig:
    '''

    Attributes:
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    '''
    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None

# Head_config centered_instance
@attrs.define
class centered_instance:
    confmaps: Optional[CenteredInstanceConfMapsConfig] = None

@attrs.define
class CenteredInstanceConfMapsConfig:
    '''

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    '''
    part_names: Optional[List[str]] = None
    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None

⚠️ Potential issue

LGTM: HeadConfig and related classes.

The HeadConfig and related configuration classes (SingleInstanceConfig, SingleInstanceConfMapsConfig, CentroidConfMapsConfig, CenteredInstanceConfMapsConfig) are well-structured and use appropriate attrs decorators. The docstrings provide clear explanations of the attributes.

However, there's a naming inconsistency in the centroid class. Please update the class name to follow PascalCase convention:

@attrs.define
class Centroid:
    confmaps: Optional[CentroidConfMapsConfig] = None

This change will maintain consistency with the naming of other classes in the file.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 5a5dc7b and 3b0f48a.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType (F821)
17-17: Undefined name BackboneType (F821)
18-18: Undefined name Union (F821)
18-18: Undefined name UNetConfig (F821)
18-18: Undefined name ConvNextConfig (F821)
18-18: Undefined name SwinTConfig (F821)
19-19: Undefined name HeadConfig (F821)
19-19: Undefined name HeadConfig (F821)
28-28: Undefined name BackboneType (F821)
29-29: Undefined name UNetConfig (F821)
30-30: Undefined name BackboneType (F821)
31-31: Undefined name ConvNextConfig (F821)
32-32: Undefined name BackboneType (F821)
33-33: Undefined name SwinTConfig (F821)
41-41: Undefined name BackboneType (F821)
44-44: Undefined name BackboneType (F821)
47-47: Undefined name BackboneType (F821)
96-96: Undefined name Dict (F821)
96-96: Undefined name Optional (F821)
96-96: Undefined name Dict (F821)
108-108: Undefined name Optional (F821)
108-108: Undefined name SingleInstanceConfMapsConfig (F821)
119-119: Undefined name Optional (F821)
119-119: Undefined name List (F821)
120-120: Undefined name Optional (F821)
121-121: Undefined name Optional (F821)
126-126: Undefined name Optional (F821)
126-126: Undefined name CentroidConfMapsConfig (F821)
137-137: Undefined name Optional (F821)
138-138: Undefined name Optional (F821)
139-139: Undefined name Optional (F821)
144-144: Undefined name Optional (F821)
144-144: Undefined name CenteredInstanceConfMapsConfig (F821)
156-156: Undefined name Optional (F821)
156-156: Undefined name List (F821)
157-157: Undefined name Optional (F821)
158-158: Undefined name Optional (F821)
159-159: Undefined name Optional (F821)
164-164: Undefined name Optional (F821)
164-164: Undefined name BottomUpConfMapsConfig (F821)
165-165: Undefined name Optional (F821)
165-165: Undefined name PAFConfig (F821)
177-177: Undefined name Optional (F821)
177-177: Undefined name List (F821)
178-178: Undefined name Optional (F821)
179-179: Undefined name Optional (F821)
180-180: Undefined name Optional (F821)
192-192: Undefined name Optional (F821)
192-192: Undefined name List (F821)
193-193: Undefined name Optional (F821)
194-194: Undefined name Optional (F821)
195-195: Undefined name Optional (F821)

🔇 Additional comments (1)
sleap_nn/config/model_config.py (1)

1-196: Overall, well-structured configuration classes with room for improvement.

The file contains a comprehensive set of configuration classes for various aspects of a machine learning model. The use of attrs for class definitions and the detailed docstrings are commendable. However, there are a few areas for improvement:

  1. Add missing imports from the typing module to resolve undefined names.
  2. Consider moving nested classes (BackboneType, UNetConfig, ConvNextConfig, SwinTConfig) to the module level for better modularity.
  3. Fix the naming inconsistency in the centroid class.
  4. Simplify the validate_pre_trained_weights method in the ModelConfig class.

Addressing these points will enhance the overall quality and maintainability of the code.


Comment on lines +1 to +2
import attrs
from enum import Enum

⚠️ Potential issue

Add missing imports from typing module.

To resolve the undefined names flagged by the static analysis tool, please add the following imports at the beginning of the file:

from typing import Union, Optional, List, Dict

This will address many of the undefined name issues throughout the file.

Comment on lines 4 to 19
@attrs.define
class ModelConfig:
    """Configurations related to model architecture.

    Attributes:
        init_weight: (str) model weights initialization method. "default" uses kaiming uniform initialization and "xavier" uses Xavier initialization method.
        pre_trained_weights: (str) Pretrained weights file name supported only for ConvNext and SwinT backbones. For ConvNext, one of ["ConvNeXt_Base_Weights","ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]. For SwinT, one of ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"].
        backbone_type: (str) Backbone architecture for the model to be trained. One of "unet", "convnext" or "swint".
    """

    init_weight: str = "default"
    pre_trained_weights: str = None
    backbone_type: BackboneType = BackboneType.UNET
    backbone_config: Union[UNetConfig, ConvNextConfig, SwinTConfig] = attrs.field(init=False)  # backbone_config can be any of these 3 configurations. init=False lets you set the parameters later (not in initialization)
    head_configs: HeadConfig = attrs.field(factory=HeadConfig)

🛠️ Refactor suggestion

Consider moving BackboneType enum outside ModelConfig class.

The BackboneType enum is currently defined inside the ModelConfig class. It's generally more common and easier to use if enums are defined at the module level. Consider moving it outside and above the ModelConfig class definition.
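A sketch of the module-level placement (the member values are assumptions based on the docstring, which lists "unet", "convnext" and "swint"):

import attrs
from enum import Enum

# Defined at module level, the enum can be imported without touching ModelConfig.
class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"

@attrs.define
class ModelConfig:
    backbone_type: BackboneType = BackboneType.UNET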


Comment on lines 56 to 92
@attrs.define
class UNetConfig:
    in_channels: int = 1
    kernel_size: int = 3
    filters: int = 32
    filters_rate: float = 1.5
    max_stride: int = None
    stem_stride: int = None
    middle_block: bool = True
    up_interpolate: bool = True
    stacks: int = 3
    convs_per_block: int = 2

@attrs.define
class ConvNextConfig:
    model_type: str = "tiny"  # Options: tiny, small, base, large
    arch: dict = attrs.field(factory=lambda: {'depths': [3, 3, 9, 3], 'channels': [96, 192, 384, 768]})
    stem_patch_kernel: int = 4
    stem_patch_stride: int = 2
    in_channels: int = 1
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True

@attrs.define
class SwinTConfig:
    model_type: str = "tiny"  # Options: tiny, small, base
    arch: dict = attrs.field(factory=lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]})
    patch_size: list = attrs.field(factory=lambda: [4, 4])
    stem_patch_stride: int = 2
    window_size: list = attrs.field(factory=lambda: [7, 7])
    in_channels: int = 1
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True

🛠️ Refactor suggestion

Consider moving backbone config classes outside ModelConfig.

The UNetConfig, ConvNextConfig, and SwinTConfig classes are currently nested within the ModelConfig class. This structure might make it harder to use these configs independently or to import them in other parts of the codebase. Consider moving these classes to the module level, just after the ModelConfig class definition. This would improve modularity and make the configs more accessible.

Example:

@attrs.define
class ModelConfig:
    # ... (existing ModelConfig code) ...

@attrs.define
class UNetConfig:
    # ... (existing UNetConfig code) ...

@attrs.define
class ConvNextConfig:
    # ... (existing ConvNextConfig code) ...

@attrs.define
class SwinTConfig:
    # ... (existing SwinTConfig code) ...

This change would make the code structure more flat and easier to navigate.

Comment on lines 124 to 126
@attrs.define
class centroid:
    confmaps: Optional[CentroidConfMapsConfig] = None

⚠️ Potential issue

Fix naming inconsistency in centroid class.

The centroid class is using lowercase naming, which is inconsistent with the PascalCase naming convention used for other classes in this file. Please update the class name to follow the PascalCase convention:

@attrs.define
class Centroid:
    confmaps: Optional[CentroidConfMapsConfig] = None

This change will maintain consistency with the naming of other classes in the file.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 3b0f48a and a074821.

📒 Files selected for processing (1)
  • sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py

1-1: omegaconf.OmegaConf imported but unused; remove unused import (F401)
30-30: Undefined name attrs (F821)

Comment on lines 31 to 49
class TrainerConfig:
    """Configuration of Trainer.

    Attributes:
        train_data_loader: (Note: Any parameters from Torch's DataLoader could be used.)
        val_data_loader: (Similar to train_data_loader)
        model_ckpt: (Note: Any parameters from Lightning's ModelCheckpoint could be used.)
        trainer_devices: (int) Number of devices to train on (int), which devices to train on (list or str), or "auto" to select automatically.
        trainer_accelerator: (str) One of ("cpu", "gpu", "tpu", "ipu", "auto"). "auto" recognises the machine the model is running on and chooses the appropriate accelerator for the Trainer to be connected to.
        enable_progress_bar: (bool) When True, enables printing the logs during training.
        steps_per_epoch: (int) Minimum number of iterations in a single epoch. (Useful if model is trained with very few data points.) Refer to the limit_train_batches parameter of Torch Trainer. If None, the number of iterations depends on the number of samples in the train dataset.
        max_epochs: (int) Maximum number of epochs to run.
        seed: (int) Seed value for the current experiment.
        use_wandb: (bool) True to enable wandb logging.
        save_ckpt: (bool) True to enable checkpointing.
        save_ckpt_path: (str) Directory path to save the training config and checkpoint files. Default: "./"
        resume_ckpt_path: (str) Path to .ckpt file from which training is resumed. Default: None.
        wandb: (Only if use_wandb is True, else skip this)
    """

⚠️ Potential issue

Define the attributes within the TrainerConfig class.

The TrainerConfig class currently lacks attribute definitions. You need to define the attributes using attrs fields to hold the configuration parameters.

Add the attribute definitions and necessary imports:

+from typing import Any, Union, List, Optional

 class TrainerConfig:
     """Configuration of Trainer.

     Attributes:
         train_data_loader: (Note: Any parameters from Torch's DataLoader could be used.) 
         val_data_loader: (Similar to train_data_loader)
         model_ckpt: (Note: Any parameters from Lightning's ModelCheckpoint could be used.) 
         trainer_devices: (int) Number of devices to train on (int), which devices to train on (list or str), or "auto" to select automatically.
         trainer_accelerator: (str) One of ("cpu", "gpu", "tpu", "ipu", "auto"). "auto" recognises the machine the model is running on and chooses the appropriate accelerator for the Trainer to be connected to.
         enable_progress_bar: (bool) When True, enables printing the logs during training.
         steps_per_epoch: (int) Minimum number of iterations in a single epoch. (Useful if model is trained with very few data points.) Refer to the limit_train_batches parameter of Torch Trainer. If None, the number of iterations depends on the number of samples in the train dataset.
         max_epochs: (int) Maximum number of epochs to run.
         seed: (int) Seed value for the current experiment.
         use_wandb: (bool) True to enable wandb logging.
         save_ckpt: (bool) True to enable checkpointing.
         save_ckpt_path: (str) Directory path to save the training config and checkpoint files. Default: "./"
         resume_ckpt_path: (str) Path to .ckpt file from which training is resumed. Default: None.
         wandb: (Only if use_wandb is True, else skip this)
     """
+    train_data_loader: Any = attrs.field()
+    val_data_loader: Any = attrs.field()
+    model_ckpt: Any = attrs.field()
+    trainer_devices: Union[int, List[int], str] = attrs.field(default="auto")
+    trainer_accelerator: str = attrs.field(default="auto")
+    enable_progress_bar: bool = attrs.field(default=True)
+    steps_per_epoch: Optional[int] = attrs.field(default=None)
+    max_epochs: int = attrs.field(default=100)
+    seed: int = attrs.field(default=42)
+    use_wandb: bool = attrs.field(default=False)
+    save_ckpt: bool = attrs.field(default=True)
+    save_ckpt_path: str = attrs.field(default="./")
+    resume_ckpt_path: Optional[str] = attrs.field(default=None)
+    wandb: Any = attrs.field(default=None)

Committable suggestion was skipped due to low confidence.
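A minimal sketch of how these attrs-based configs could be consumed through OmegaConf once the attributes exist (the class and values here are illustrative, not the final API):

import attrs
from omegaconf import OmegaConf

@attrs.define
class DataLoaderConfig:
    batch_size: int = 1
    shuffle: bool = False
    num_workers: int = 0

# OmegaConf.structured builds a typed DictConfig from an attrs instance;
# merging an override dict preserves the declared types.
cfg = OmegaConf.structured(DataLoaderConfig())
cfg = OmegaConf.merge(cfg, {"batch_size": 8})
assert cfg.batch_size == 8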

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/trainer_config.py (1)

55-57: Use consistent quote style for docstrings.

For consistency with the rest of the file, use double quotes for multi-line docstrings instead of single quotes.

Replace the single quotes with double quotes:

"""
train and val data_loader: (Note: Any parameters from Torch's DataLoader could be used.)

Attributes:
    ...
"""
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between a074821 and 470a2f3.

📒 Files selected for processing (1)
  • sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py

1-1: omegaconf.OmegaConf imported but unused; remove unused import (F401)
30-30: Undefined name attrs (F821)
50-50: Undefined name TrainDataLoader (F821)
50-50: Undefined name attrs (F821)
50-50: Undefined name DataLoader (F821)
51-51: Undefined name ValDataLoader (F821)
51-51: Undefined name attrs (F821)
51-51: Undefined name DataLoader (F821)
53-53: Undefined name attrs (F821)

🔇 Additional comments (3)
sleap_nn/config/trainer_config.py (3)

3-27: LGTM: Well-documented module purpose.

The module-level docstring provides a clear and comprehensive explanation of the purpose and design of the configuration classes. It effectively communicates the intended use and benefits of this approach.


32-49: LGTM: Well-documented TrainerConfig class.

The docstring for the TrainerConfig class provides clear and detailed explanations for each attribute, which is excellent for maintainability and usability.


1-66: Summary: Good implementation with minor improvements needed.

Overall, the implementation of TrainerConfig and DataLoader classes provides a solid foundation for managing training job parameters. The code is well-documented and structured. To improve it further:

  1. Address the unused import and missing attrs import.
  2. Ensure consistency in attribute definitions using attrs.field.
  3. Define or import the TrainDataLoader and ValDataLoader types.
  4. Use consistent quote style for docstrings.

These changes will enhance the code's maintainability and type safety.


"""


@attrs.define

⚠️ Potential issue

Add missing import for attrs.

The attrs module is used but not imported. Add the import at the beginning of the file.

Add this import at the top of the file:

import attrs

Comment on lines 50 to 51
    train_data_loader: TrainDataLoader = attrs.field(factory=DataLoader)
    val_data_loader: ValDataLoader = attrs.field(factory=DataLoader)

⚠️ Potential issue

Define or import TrainDataLoader and ValDataLoader types.

The TrainDataLoader and ValDataLoader types are used but not defined or imported. Ensure these types are properly defined or imported to maintain type consistency.

If these types are defined elsewhere in the project, import them at the beginning of the file:

from .data_config import TrainDataLoader, ValDataLoader  # Adjust the import path as needed

If they are aliases for the DataLoader class, consider using type aliases:

from typing import TypeAlias

TrainDataLoader: TypeAlias = DataLoader
ValDataLoader: TypeAlias = DataLoader

Comment on lines 63 to 65
    batch_size: int = 1
    shuffle: bool=False
    num_workers: int=0

🛠️ Refactor suggestion

Use attrs.field for attribute definitions in DataLoader class.

For consistency with the TrainerConfig class, use attrs.field to define the attributes in the DataLoader class.

Apply this diff to use attrs.field:

-    batch_size: int = 1
-    shuffle: bool=False
-    num_workers: int=0
+    batch_size: int = attrs.field(default=1)
+    shuffle: bool = attrs.field(default=False)
+    num_workers: int = attrs.field(default=0)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 470a2f3 and 4d5f169.

📒 Files selected for processing (1)
  • sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py

63-64: SyntaxError: Expected an expression

🔇 Additional comments (1)
sleap_nn/config/trainer_config.py (1)

75-77: 🛠️ Refactor suggestion

Use attrs.field for attribute definitions in DataLoaderConfig.

For consistency with other classes and to leverage attrs features, use attrs.field to define the attributes in the DataLoaderConfig class.

Apply this diff to use attrs.field:

-    batch_size: int = 1
-    shuffle: bool=False
-    num_workers: int=0
+    batch_size: int = attrs.field(default=1)
+    shuffle: bool = attrs.field(default=False)
+    num_workers: int = attrs.field(default=0)

Likely invalid or redundant comment.

@@ -0,0 +1,88 @@
from omegaconf import OmegaConf

⚠️ Potential issue

Add missing imports and remove unused import.

  1. Remove the unused import of OmegaConf.
  2. Add the necessary imports for attrs and typing modules.

Apply this diff to fix the imports:

-from omegaconf import OmegaConf
+import attrs
+from typing import Union, List, Optional

Also applies to: 30-30

Comment on lines 63 to 64
    wandb:


⚠️ Potential issue

Complete the 'wandb' attribute definition.

The 'wandb' attribute is incomplete. It should be properly defined with a type and default value.

Consider updating the 'wandb' attribute as follows:

wandb: Optional[Any] = attrs.field(default=None)

Also, add from typing import Any to the import statement if not already present.


Comment on lines 87 to 88
save_top_k: int = 1
save_last: Optional[bool]=None

🛠️ Refactor suggestion

Use attrs.field for attribute definitions in ModelCkptConfig.

For consistency with other classes and to leverage attrs features, use attrs.field to define the attributes in the ModelCkptConfig class.

Apply this diff to use attrs.field:

-    save_top_k: int = 1
-    save_last: Optional[bool]=None
+    save_top_k: int = attrs.field(default=1)
+    save_last: Optional[bool] = attrs.field(default=None)

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 4d5f169 and cfd87d4.

📒 Files selected for processing (1)
  • sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py

2-2: omegaconf.OmegaConf imported but unused; remove unused import (F401)
55-55: Undefined name DataLoaderConfig (F821) ×2
56-56: Undefined name DataLoaderConfig (F821) ×2
57-57: Undefined name ModelCkptConfig (F821) ×2
58-58: Undefined name Union (F821)
58-58: Undefined name List (F821)
61-61: Undefined name Optional (F821)
63-63: Undefined name Optional (F821)
67-67: Undefined name Optional (F821)
68-68: Undefined name Optional (F821)
68-68: Undefined name WandBConfig (F821)
69-69: Undefined name Optional (F821)
69-69: Undefined name OptimizerConfig (F821) ×2
70-70: Undefined name LRSchedulerConfig (F821) ×2
71-71: Undefined name EarlyStoppingConfig (F821) ×2
71-71: Undefined name attr (F821)
104-104: Undefined name Optional (F821)
119-119: Undefined name Optional (F821)
120-120: Undefined name Optional (F821)
121-121: Undefined name Optional (F821)
122-122: Undefined name Optional (F821)
123-123: Undefined name Optional (F821)
124-124: Undefined name Optional (F821)
125-125: Undefined name Optional (F821)
125-125: Undefined name List (F821)
157-157: Undefined name Union (F821)
157-157: Undefined name List (F821)

🔇 Additional comments (5)
sleap_nn/config/trainer_config.py (5)

4-28: LGTM: Well-written file-level docstring.

The file-level docstring provides a clear explanation of the purpose and design principles of the configuration classes. It effectively communicates the separation of concerns between parameter specification and implementation.


106-125: LGTM: Well-structured WandBConfig class.

The WandBConfig class is well-defined with appropriate use of attrs.define and Optional type hints for its attributes. The docstring provides clear explanations for each attribute.



1-170: Summary: Well-structured configuration classes with minor improvements needed.

Overall, this file introduces a comprehensive set of configuration classes for training job parameters. The design is solid, and the documentation is thorough. To improve the code:

  1. Ensure consistency in attribute definitions by using attrs.field across all classes.
  2. Add missing type hints and remove unused imports.
  3. Fix the early_stopping attribute in TrainerConfig to use attrs.field.

These changes will enhance the consistency and type safety of the configuration classes.
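As a sanity check on the OmegaConf side of the migration, a minimal round-trip sketch (assuming only the DataLoaderConfig defaults shown above; attrs.asdict is used here for simplicity, though OmegaConf.structured also accepts attrs classes directly):

import attrs
from omegaconf import OmegaConf

@attrs.define
class DataLoaderConfig:
    batch_size: int = 1
    shuffle: bool = False
    num_workers: int = 0

# Defaults -> DictConfig, merge user overrides, serialize back out.
base = OmegaConf.create(attrs.asdict(DataLoaderConfig()))
overrides = OmegaConf.create({"batch_size": 8, "shuffle": True})
cfg = OmegaConf.merge(base, overrides)
assert cfg.batch_size == 8 and cfg.num_workers == 0
print(OmegaConf.to_yaml(cfg))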



82-93: 🛠️ Refactor suggestion

Use attrs.field for attribute definitions in DataLoaderConfig.

For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the DataLoaderConfig class.

Apply this diff:

-    batch_size: int = 1
-    shuffle: bool=False
-    num_workers: int=0
+    batch_size: int = attrs.field(default=1)
+    shuffle: bool = attrs.field(default=False)
+    num_workers: int = attrs.field(default=0)

Likely invalid or redundant comment.


95-104: 🛠️ Refactor suggestion

Use attrs.field for attribute definitions in ModelCkptConfig.

For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the ModelCkptConfig class.

Apply this diff:

-    save_top_k: int = 1
-    save_last: Optional[bool]=None
+    save_top_k: int = attrs.field(default=1)
+    save_last: Optional[bool] = attrs.field(default=None)

Likely invalid or redundant comment.


Comment on lines 127 to 135
@attrs.define
class OptimizerConfig:
'''optimizer configuration

lr: (float) Learning rate of type float. Default: 1e-3
amsgrad: (bool) Enable AMSGrad with the optimizer. Default: False
'''
lr: float = 1e-3
amsgrad: bool = False

🛠️ Refactor suggestion

Use attrs.field for attribute definitions in OptimizerConfig.

For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the OptimizerConfig class.

Apply this diff:

-    lr: float = 1e-3
-    amsgrad: bool = False
+    lr: float = attrs.field(default=1e-3)
+    amsgrad: bool = attrs.field(default=False)

Comment on lines 159 to 170
@attrs.define
class EarlyStoppingConfig:
'''early_stopping configuration

Attributes:
stop_training_on_plateau: (bool) True if early stopping should be enabled.
min_delta: (float) Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than or equal to min_delta, will count as no improvement.
patience: (int) Number of checks with no improvement after which training will be stopped. Under the default configuration, one check happens after every training epoch.
'''
stop_training_on_plateau: bool = False
min_delta: float = 0.0
patience: int = 1

🛠️ Refactor suggestion

Use attrs.field for attribute definitions in EarlyStoppingConfig.

For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the EarlyStoppingConfig class.

Apply this diff:

-    stop_training_on_plateau: bool = False
-    min_delta: float = 0.0
-    patience: int = 1
+    stop_training_on_plateau: bool = attrs.field(default=False)
+    min_delta: float = attrs.field(default=0.0)
+    patience: int = attrs.field(default=1)

Comment on lines 31 to 71
@attrs.define
class TrainerConfig:
"""Configuration of Trainer.

Attributes:
train_data_loader: (Note: Any parameters from Torch's DataLoader could be used.)
val_data_loader: (Similar to train_data_loader)
model_ckpt: (Note: Any parameters from Lightning's ModelCheckpoint could be used.)
trainer_devices: (int) Number of devices to train on (int), which devices to train on (list or str), or "auto" to select automatically.
trainer_accelerator: (str) One of the ("cpu", "gpu", "tpu", "ipu", "auto"). "auto" recognises the machine the model is running on and chooses the appropriate accelerator for the Trainer to be connected to.
enable_progress_bar: (bool) When True, enables printing the logs during training.
steps_per_epoch: (int) Minimum number of iterations in a single epoch. (Useful if model is trained with very few data points). Refer limit_train_batches parameter of Torch Trainer. If None, the number of iterations depends on the number of samples in the train dataset.
max_epochs: (int) Maximum number of epochs to run.
seed: (int) Seed value for the current experiment.
use_wandb: (bool) True to enable wandb logging.
save_ckpt: (bool) True to enable checkpointing.
save_ckpt_path: (str) Directory path to save the training config and checkpoint files. Default: "./"
resume_ckpt_path: (str) Path to .ckpt file from which training is resumed. Default: None.
wandb: (Only if use_wandb is True, else skip this)
optimizer_name: (str) Optimizer to be used. One of ["Adam", "AdamW"].
optimizer:
lr_scheduler:
early_stopping:
"""
train_data_loader: DataLoaderConfig = attrs.field(factory=DataLoaderConfig)
val_data_loader: DataLoaderConfig = attrs.field(factory=DataLoaderConfig)
model_ckpt: ModelCkptConfig = attrs.field(factory=ModelCkptConfig)
trainer_devices: Union[int, List[int], str] = "auto"
trainer_accelerator: str="auto"
enable_progress_bar: bool = True
steps_per_epoch: Optional[int] = None
max_epochs: int = 10
seed: Optional[int] = None
use_wandb: bool = False
save_ckpt: bool = False
save_ckpt_path: str = "./"
resume_ckpt_path: Optional[str] = None
wandb: Optional[WandBConfig] = attrs.field(init=False)
optimizer: Optional[OptimizerConfig] = attrs.field(factory=OptimizerConfig)
lr_scheduler: LRSchedulerConfig = attrs.field(factory=LRSchedulerConfig)
early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)

⚠️ Potential issue

Fix early_stopping attribute and add missing type hints.

  1. The early_stopping attribute uses attr.field instead of attrs.field.
  2. Add type hints for trainer_devices, trainer_accelerator, and other attributes to resolve static analysis warnings.

Apply these changes:

-    trainer_devices: Union[int, List[int], str] = "auto"
-    trainer_accelerator: str="auto"
+    trainer_devices: Union[int, List[int], str] = attrs.field(default="auto")
+    trainer_accelerator: str = attrs.field(default="auto")
     # ... (apply similar changes to other attributes)
-    early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
+    early_stopping: EarlyStoppingConfig = attrs.field(factory=EarlyStoppingConfig)

Committable suggestion was skipped due to low confidence.


Comment on lines 137 to 157
@attrs.define
class LRSchedulerConfig:
'''lr_scheduler configuration

Attributes:
mode: (str) One of "min", "max". In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: "min".
threshold: (float) Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
threshold_mode: (str) One of "rel", "abs". In rel mode, dynamic_threshold = best * ( 1 + threshold ) in max mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: "rel".
cooldown: (int) Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0
patience: (int) Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the third epoch if the loss still hasn’t improved then. Default: 10.
factor: (float) Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
min_lr: (float or List[float]) A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
'''

mode: str = "min"
threshold: float = 1e-4
threshold_mode: str = "rel"
cooldown: int = 0
patience: int = 10
factor: float = 0.1
min_lr: Union[float, List[float]] = 0.0

🛠️ Refactor suggestion

Use attrs.field for attribute definitions in LRSchedulerConfig.

For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the LRSchedulerConfig class.

Apply this diff:

-    mode: str = "min"
-    threshold: float = 1e-4
-    threshold_mode: str = "rel"
-    cooldown: int = 0
-    patience: int = 10
-    factor: float = 0.1
-    min_lr: Union[float, List[float]] = 0.0
+    mode: str = attrs.field(default="min")
+    threshold: float = attrs.field(default=1e-4)
+    threshold_mode: str = attrs.field(default="rel")
+    cooldown: int = attrs.field(default=0)
+    patience: int = attrs.field(default=10)
+    factor: float = attrs.field(default=0.1)
+    min_lr: Union[float, List[float]] = attrs.field(default=0.0)
Note: The static analysis hints indicate that `Union` and `List` are undefined. To address this, you should add the following import at the beginning of the file:
from typing import Union, List

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between cfd87d4 and c73e546.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType (F821) ×2
18-18: Undefined name Union (F821)
18-18: Undefined name UNetConfig (F821)
18-18: Undefined name ConvNextConfig (F821)
18-18: Undefined name SwinTConfig (F821)
19-19: Undefined name HeadConfig (F821) ×2
28-28: Undefined name BackboneType (F821)
29-29: Undefined name UNetConfig (F821)
30-30: Undefined name BackboneType (F821)
31-31: Undefined name ConvNextConfig (F821)
32-32: Undefined name BackboneType (F821)
33-33: Undefined name SwinTConfig (F821)
41-41: Undefined name BackboneType (F821)
44-44: Undefined name BackboneType (F821)
47-47: Undefined name BackboneType (F821)
94-94: Undefined name oneof (F821)
95-95: Undefined name attr (F821)
110-110: Undefined name Optional (F821)
110-110: Undefined name SingleInstanceConfig (F821)
111-111: Undefined name Optional (F821)
111-111: Undefined name CentroidConfig (F821)
112-112: Undefined name Optional (F821)
112-112: Undefined name CenteredInstanceConfig (F821)
113-113: Undefined name Optional (F821)
113-113: Undefined name BottomUpConfig (F821)
118-118: Undefined name Optional (F821)
118-118: Undefined name SingleInstanceConfMapsConfig (F821)
123-123: Undefined name Optional (F821)
123-123: Undefined name CentroidConfMapsConfig (F821)
128-128: Undefined name Optional (F821)
128-128: Undefined name CenteredInstanceConfMapsConfig (F821)
133-133: Undefined name Optional (F821)
133-133: Undefined name BottomUpConfMapsConfig (F821)
134-134: Undefined name Optional (F821)
134-134: Undefined name PAFConfig (F821)
145-145: Undefined name Optional (F821)
145-145: Undefined name List (F821)
146-146: Undefined name Optional (F821)
147-147: Undefined name Optional (F821)
158-158: Undefined name Optional (F821)
159-159: Undefined name Optional (F821)
160-160: Undefined name Optional (F821)
172-172: Undefined name Optional (F821)
172-172: Undefined name List (F821)
173-173: Undefined name Optional (F821)
174-174: Undefined name Optional (F821)
175-175: Undefined name Optional (F821)
187-187: Undefined name Optional (F821)
187-187: Undefined name List (F821)
188-188: Undefined name Optional (F821)
189-189: Undefined name Optional (F821)
190-190: Undefined name Optional (F821)
202-202: Undefined name Optional (F821)
202-202: Undefined name List (F821)
203-203: Undefined name Optional (F821)
204-204: Undefined name Optional (F821)
205-205: Undefined name Optional (F821)

🔇 Additional comments (1)
sleap_nn/config/model_config.py (1)

1-3: ⚠️ Potential issue

Add missing imports from typing module.

Add the following imports to resolve undefined names and improve type hints:

 import attrs
 from enum import Enum
+from typing import Union, Optional, List

Likely invalid or redundant comment.

Comment on lines 136 to 190
@attrs.define
class SingleInstanceConfMapsConfig:
'''

Attributes:
part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
'''
part_names: Optional[List[str]] = None
sigma: Optional[float] = None
output_stride: Optional[float] = None

@attrs.define
class CentroidConfMapsConfig:
'''

Attributes:
anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
'''
anchor_part: Optional[int] = None
sigma: Optional[float] = None
output_stride: Optional[float] = None

@attrs.define
class CenteredInstanceConfMapsConfig:
'''

Attributes:
part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
'''
part_names: Optional[List[str]] = None
anchor_part: Optional[int] = None
sigma: Optional[float] = None
output_stride: Optional[float] = None

@attrs.define
class BottomUpConfMapsConfig():
'''

Attributes:
part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
loss_weight: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models.
'''
part_names: Optional[List[str]] = None
sigma: Optional[float] = None
output_stride: Optional[float] = None
loss_weight: Optional[float] = None

🛠️ Refactor suggestion

Consider creating a base configuration class for shared attributes.

The configuration classes (SingleInstanceConfMapsConfig, CentroidConfMapsConfig, etc.) share common attributes like sigma and output_stride. Consider creating a base class:

@attrs.define
class BaseConfMapsConfig:
    sigma: Optional[float] = None
    output_stride: Optional[float] = None

@attrs.define
class SingleInstanceConfMapsConfig(BaseConfMapsConfig):
    part_names: Optional[List[str]] = None

# Similar for other classes

This would reduce code duplication and make maintenance easier.


"""

init_weight: str = "default"
pre_trained_weights: str = None

⚠️ Potential issue

Fix type annotation for pre_trained_weights.

The attribute should use Optional[str] since it can be None:

-    pre_trained_weights: str = None
+    pre_trained_weights: Optional[str] = None

Committable suggestion was skipped due to low confidence.

Comment on lines 94 to 96
@oneof
@attr.s(auto_attribs=True)
class HeadsConfig:

⚠️ Potential issue

Update HeadsConfig decorator and class definition.

  1. The @oneof decorator is undefined
  2. The class uses older attr.s style instead of attrs.define
-@oneof
-@attr.s(auto_attribs=True)
+@attrs.define
 class HeadsConfig:

Committable suggestion was skipped due to low confidence.


Comment on lines 50 to 53
class BackboneType(Enum):
UNET = "unet"
CONVNEXT = 'convnext'
SWINT = 'swint'

🛠️ Refactor suggestion

Consider moving BackboneType enum outside ModelConfig.

Moving the enum outside would improve reusability and follow Python's common patterns:

 from enum import Enum
 
+class BackboneType(Enum):
+    UNET = "unet"
+    CONVNEXT = 'convnext'
+    SWINT = 'swint'
+
 @attrs.define
 class ModelConfig:
-    class BackboneType(Enum):
-        UNET = "unet"
-        CONVNEXT = 'convnext'
-        SWINT = 'swint'

Committable suggestion was skipped due to low confidence.

Comment on lines 37 to 48
def validate_pre_trained_weights(self):
convnext_weights = ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]
swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]

if self.backbone_type == BackboneType.CONVNEXT:
if self.pre_trained_weights not in convnext_weights:
raise ValueError(f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}")
elif self.backbone_type == BackboneType.SWINT:
if self.pre_trained_weights not in swint_weights:
raise ValueError(f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}")
elif self.backbone_type == BackboneType.UNET and self.pre_trained_weights is not None:
raise ValueError("UNet does not support pre-trained weights.")

🛠️ Refactor suggestion

Simplify validation logic using a dictionary mapping.

The validation logic can be made more maintainable using a dictionary:

     def validate_pre_trained_weights(self):
-        convnext_weights = ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]
-        swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]
-
-        if self.backbone_type == BackboneType.CONVNEXT:
-            if self.pre_trained_weights not in convnext_weights:
-                raise ValueError(f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}")
-        elif self.backbone_type == BackboneType.SWINT:
-            if self.pre_trained_weights not in swint_weights:
-                raise ValueError(f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}")
-        elif self.backbone_type == BackboneType.UNET and self.pre_trained_weights is not None:
-            raise ValueError("UNet does not support pre-trained weights.")
+        valid_weights = {
+            BackboneType.CONVNEXT: ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"],
+            BackboneType.SWINT: ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"],
+            BackboneType.UNET: None
+        }
+        allowed_weights = valid_weights[self.backbone_type]
+        if allowed_weights is None and self.pre_trained_weights is not None:
+            raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
+        elif allowed_weights and self.pre_trained_weights not in allowed_weights:
+            raise ValueError(f"Invalid pre-trained weights for {self.backbone_type.value}. Must be one of {allowed_weights}")
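A quick usage sketch of the refactored validator (hypothetical values; assumes a ModelConfig instance with backbone_type and pre_trained_weights attributes can be constructed with defaults):

cfg = ModelConfig()
cfg.backbone_type = BackboneType.SWINT
cfg.pre_trained_weights = "Swin_T_Weights"
cfg.validate_pre_trained_weights()  # passes

cfg.pre_trained_weights = "ConvNeXt_Tiny_Weights"
cfg.validate_pre_trained_weights()  # raises ValueError for invalid SwinT weights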

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 11

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/trainer_config.py (1)

31-54: Complete the documentation for optimizer, lr_scheduler, and early_stopping attributes.

The class docstring is missing descriptions for some attributes.

Add these descriptions to the docstring:

         wandb: (Only if use_wandb is True, else skip this)
         optimizer_name: (str) Optimizer to be used. One of ["Adam", "AdamW"].
-        optimizer:
-        lr_scheduler:
-        early_stopping:
+        optimizer: (OptimizerConfig) Configuration for the optimizer.
+        lr_scheduler: (LRSchedulerConfig) Configuration for the learning rate scheduler.
+        early_stopping: (EarlyStoppingConfig) Configuration for early stopping criteria.
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between c73e546 and d7ba20c.

📒 Files selected for processing (2)
  • sleap_nn/config/model_config.py (1 hunks)
  • sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

19-19: Undefined name BackboneType (F821) ×2
20-20: Undefined name Union (F821)
20-20: Undefined name UNetConfig (F821)
20-20: Undefined name ConvNextConfig (F821)
20-20: Undefined name SwinTConfig (F821)
23-23: Undefined name HeadConfig (F821) ×2
32-32: Undefined name BackboneType (F821)
33-33: Undefined name UNetConfig (F821)
34-34: Undefined name BackboneType (F821)
35-35: Undefined name ConvNextConfig (F821)
36-36: Undefined name BackboneType (F821)
37-37: Undefined name SwinTConfig (F821)
51-51: Undefined name BackboneType (F821)
56-56: Undefined name BackboneType (F821)
62-62: Undefined name BackboneType (F821)
166-166: Undefined name oneof (F821)
167-167: Undefined name attr (F821)
182-182: Undefined name Optional (F821)
182-182: Undefined name SingleInstanceConfig (F821)
183-183: Undefined name Optional (F821)
183-183: Undefined name CentroidConfig (F821)
184-184: Undefined name Optional (F821)
184-184: Undefined name CenteredInstanceConfig (F821)
185-185: Undefined name Optional (F821)
185-185: Undefined name BottomUpConfig (F821)
191-191: Undefined name Optional (F821)
191-191: Undefined name SingleInstanceConfMapsConfig (F821)
197-197: Undefined name Optional (F821)
197-197: Undefined name CentroidConfMapsConfig (F821)
203-203: Undefined name Optional (F821)
203-203: Undefined name CenteredInstanceConfMapsConfig (F821)
209-209: Undefined name Optional (F821)
209-209: Undefined name BottomUpConfMapsConfig (F821)
210-210: Undefined name Optional (F821)
210-210: Undefined name PAFConfig (F821)
223-223: Undefined name Optional (F821)
223-223: Undefined name List (F821)
224-224: Undefined name Optional (F821)
225-225: Undefined name Optional (F821)
238-238: Undefined name Optional (F821)
239-239: Undefined name Optional (F821)
240-240: Undefined name Optional (F821)
254-254: Undefined name Optional (F821)
254-254: Undefined name List (F821)
255-255: Undefined name Optional (F821)
256-256: Undefined name Optional (F821)
257-257: Undefined name Optional (F821)
271-271: Undefined name Optional (F821)
271-271: Undefined name List (F821)
272-272: Undefined name Optional (F821)
273-273: Undefined name Optional (F821)
274-274: Undefined name Optional (F821)
288-288: Undefined name Optional (F821)
288-288: Undefined name List (F821)
289-289: Undefined name Optional (F821)
290-290: Undefined name Optional (F821)
291-291: Undefined name Optional (F821)

sleap_nn/config/trainer_config.py

2-2: omegaconf.OmegaConf imported but unused; remove unused import (F401)
56-56: Undefined name DataLoaderConfig (F821) ×2
57-57: Undefined name DataLoaderConfig (F821) ×2
58-58: Undefined name ModelCkptConfig (F821) ×2
59-59: Undefined name Union (F821)
59-59: Undefined name List (F821)
62-62: Undefined name Optional (F821)
64-64: Undefined name Optional (F821)
68-68: Undefined name Optional (F821)
69-69: Undefined name Optional (F821)
69-69: Undefined name WandBConfig (F821)
70-70: Undefined name Optional (F821)
70-70: Undefined name OptimizerConfig (F821) ×2
71-71: Undefined name LRSchedulerConfig (F821) ×2
72-72: Undefined name EarlyStoppingConfig (F821) ×2
72-72: Undefined name attr (F821)
110-110: Undefined name Optional (F821)
127-127: Undefined name Optional (F821)
128-128: Undefined name Optional (F821)
129-129: Undefined name Optional (F821)
130-130: Undefined name Optional (F821)
131-131: Undefined name Optional (F821)
132-132: Undefined name Optional (F821)
133-133: Undefined name Optional (F821)
133-133: Undefined name List (F821)
169-169: Undefined name Union (F821)
169-169: Undefined name List (F821)

🔇 Additional comments (5)
sleap_nn/config/trainer_config.py (2)

4-28: LGTM! Well-documented module purpose.

The module docstring clearly explains the purpose and design principles of these configuration classes.


74-83: LGTM! Clean post-initialization logic.

The __attrs_post_init__ method correctly handles the conditional initialization of the wandb configuration.
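For readers skimming the thread, the pattern being praised looks roughly like the sketch below (illustrative, not the exact PR code):

import attrs
from typing import Optional

@attrs.define
class WandBConfig:
    project: Optional[str] = None  # illustrative field

@attrs.define
class TrainerConfig:
    use_wandb: bool = False
    wandb: Optional[WandBConfig] = attrs.field(init=False, default=None)

    def __attrs_post_init__(self):
        # Materialize the wandb sub-config only when logging is requested.
        self.wandb = WandBConfig() if self.use_wandb else None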

sleap_nn/config/model_config.py (3)

18-18: ⚠️ Potential issue

Use Optional type for pre_trained_weights.

The attribute allows None but is typed as str.

-    pre_trained_weights: str = None
+    pre_trained_weights: Optional[str] = None

Likely invalid or redundant comment.


1-4: ⚠️ Potential issue

Add missing imports from typing module.

The code uses type annotations but lacks the required imports. Add the following imports:

 import attrs
 from enum import Enum
+from typing import Dict, List, Optional, Union

Likely invalid or redundant comment.


166-168: ⚠️ Potential issue

Update HeadsConfig decorator.

The class uses deprecated attrs style and an undefined oneof decorator.

-@oneof
-@attr.s(auto_attribs=True)
+@attrs.define
 class HeadsConfig:

Note: If the oneof decorator is required for validation, please ensure it's properly imported or implemented.
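If the oneof behavior does need to be reproduced without the decorator, one option is a mutual-exclusivity check in __attrs_post_init__. A sketch under the assumption that at most one head may be configured (field types are stand-ins):

import attrs
from typing import Any, Optional

@attrs.define
class HeadsConfig:
    single_instance: Optional[Any] = None
    centroid: Optional[Any] = None
    centered_instance: Optional[Any] = None
    bottomup: Optional[Any] = None

    def __attrs_post_init__(self):
        # Emulate the old oneof contract: reject configs with multiple heads set.
        set_heads = [
            f.name for f in attrs.fields(type(self))
            if getattr(self, f.name) is not None
        ]
        if len(set_heads) > 1:
            raise ValueError(f"Only one head type may be set, got: {set_heads}")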


Comment on lines +145 to +146
lr: float = 1e-3
amsgrad: bool = False

🛠️ Refactor suggestion

Use attrs.field for OptimizerConfig attributes.

For consistency with other classes, use attrs.field for attribute definitions.

-    lr: float = 1e-3
-    amsgrad: bool = False
+    lr: float = attrs.field(default=1e-3)
+    amsgrad: bool = attrs.field(default=False)

Comment on lines +182 to +184
stop_training_on_plateau: bool = False
min_delta: float = 0.0
patience: int = 1

🛠️ Refactor suggestion

Use attrs.field for EarlyStoppingConfig attributes.

For consistency with other classes, use attrs.field for attribute definitions.

-    stop_training_on_plateau: bool = False
-    min_delta: float = 0.0
-    patience: int = 1
+    stop_training_on_plateau: bool = attrs.field(default=False)
+    min_delta: float = attrs.field(default=0.0)
+    patience: int = attrs.field(default=1)

Comment on lines +95 to +97
batch_size: int = 1
shuffle: bool = False
num_workers: int = 0

🛠️ Refactor suggestion

Use attrs.field for DataLoaderConfig attributes.

For consistency with other classes, use attrs.field for attribute definitions.

-    batch_size: int = 1
-    shuffle: bool = False
-    num_workers: int = 0
+    batch_size: int = attrs.field(default=1)
+    shuffle: bool = attrs.field(default=False)
+    num_workers: int = attrs.field(default=0)

Comment on lines +109 to +110
save_top_k: int = 1
save_last: Optional[bool] = None

🛠️ Refactor suggestion

Use attrs.field for ModelCkptConfig attributes.

For consistency with other classes, use attrs.field for attribute definitions.

-    save_top_k: int = 1
-    save_last: Optional[bool] = None
+    save_top_k: int = attrs.field(default=1)
+    save_last: Optional[bool] = attrs.field(default=None)

Committable suggestion was skipped due to low confidence.

🧰 Tools
🪛 Ruff

110-110: Undefined name Optional

(F821)

backbone_config: Union[UNetConfig, ConvNextConfig, SwinTConfig] = attrs.field(
init=False
) # backbone_config can be any of these 3 configurations. init=False lets you set the parameters later (not in initialization)
head_configs: HeadConfig = attrs.field(factory=HeadConfig)

⚠️ Potential issue

Fix inconsistent class reference.

The attribute references HeadConfig but the class is named HeadsConfig.

-    head_configs: HeadConfig = attrs.field(factory=HeadConfig)
+    head_configs: HeadsConfig = attrs.field(factory=HeadsConfig)

kernel_size: int = 3
filters: int = 32
filters_rate: float = 1.5
max_stride: int = None

⚠️ Potential issue

Use Optional type for max_stride.

The attribute allows None but is typed as int.

-        max_stride: int = None
+        max_stride: Optional[int] = None

Committable suggestion was skipped due to low confidence.

Comment on lines 67 to 71
class BackboneType(Enum):
UNET = "unet"
CONVNEXT = "convnext"
SWINT = "swint"


🛠️ Refactor suggestion

Move BackboneType enum to module level.

The enum should be at module level for better reusability and to follow Python conventions.

-    class BackboneType(Enum):
-        UNET = "unet"
-        CONVNEXT = "convnext"
-        SWINT = "swint"
+class BackboneType(Enum):
+    UNET = "unet"
+    CONVNEXT = "convnext"
+    SWINT = "swint"

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between d7ba20c and 19f5256.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

2-2: enum.Enum imported but unused; remove unused import (F401)
18-18: Undefined name BackboneConfig (F821) ×2
19-19: Undefined name HeadConfig (F821) ×2
35-35: Undefined name BackboneType (F821)
40-40: Undefined name BackboneType (F821)
46-46: Undefined name BackboneType (F821)
52-52: Undefined name oneof (F821)
63-63: Undefined name Optional (F821)
63-63: Undefined name UNetConfig (F821)
64-64: Undefined name Optional (F821)
64-64: Undefined name ConvNextConfig (F821)
65-65: Undefined name Optional (F821)
65-65: Undefined name SwinTConfig (F821)
164-164: Undefined name oneof (F821)
180-180: Undefined name Optional (F821)
180-180: Undefined name SingleInstanceConfig (F821)
181-181: Undefined name Optional (F821)
181-181: Undefined name CentroidConfig (F821)
182-182: Undefined name Optional (F821)
182-182: Undefined name CenteredInstanceConfig (F821)
183-183: Undefined name Optional (F821)
183-183: Undefined name BottomUpConfig (F821)
189-189: Undefined name Optional (F821)
189-189: Undefined name SingleInstanceConfMapsConfig (F821)
195-195: Undefined name Optional (F821)
195-195: Undefined name CentroidConfMapsConfig (F821)
201-201: Undefined name Optional (F821)
201-201: Undefined name CenteredInstanceConfMapsConfig (F821)
207-207: Undefined name Optional (F821)
207-207: Undefined name BottomUpConfMapsConfig (F821)
208-208: Undefined name Optional (F821)
208-208: Undefined name PAFConfig (F821)
221-221: Undefined name Optional (F821)
221-221: Undefined name List (F821)
222-222: Undefined name Optional (F821)
223-223: Undefined name Optional (F821)
236-236: Undefined name Optional (F821)
237-237: Undefined name Optional (F821)
238-238: Undefined name Optional (F821)
252-252: Undefined name Optional (F821)
252-252: Undefined name List (F821)
253-253: Undefined name Optional (F821)
254-254: Undefined name Optional (F821)
255-255: Undefined name Optional (F821)
269-269: Undefined name Optional (F821)
269-269: Undefined name List (F821)
270-270: Undefined name Optional (F821)
271-271: Undefined name Optional (F821)
272-272: Undefined name Optional (F821)
286-286: Undefined name Optional (F821)
286-286: Undefined name List (F821)
287-287: Undefined name Optional (F821)
288-288: Undefined name Optional (F821)
289-289: Undefined name Optional (F821)

🔇 Additional comments (2)
sleap_nn/config/model_config.py (2)

1-4: ⚠️ Potential issue

Add missing imports for type hints and decorators.

Add the following imports to resolve undefined names and improve type safety:

 import attrs
 from enum import Enum
+from typing import Optional, List, Dict, Union
+from omegaconf import MISSING

Let's verify if omegaconf is a project dependency:

✅ Verification successful


The verification confirms:

  1. omegaconf is a project dependency through hydra-core in pyproject.toml
  2. omegaconf is used across the codebase, including other config files
  3. Specifically, MISSING from omegaconf is already used in sleap_nn/config/data_config.py
  4. The project follows a consistent pattern of using omegaconf for configuration management

The suggested imports are appropriate and align with the project's existing practices.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if omegaconf is listed in requirements
rg -l "omegaconf" "requirements.*"

Length of output: 89


Script:

#!/bin/bash
# Let's check for setup.py, pyproject.toml, or any other dependency files
fd -e toml -e txt -e py "setup|requirements|pyproject"

# Also search for omegaconf usage in the codebase
rg -l "omegaconf"

Length of output: 1002


Script:

#!/bin/bash
# Let's check the content of pyproject.toml for dependencies
cat pyproject.toml

# Also check environment files for omegaconf
grep -h "omegaconf" environment*.yml

Length of output: 1667


Script:

#!/bin/bash
# Let's check if hydra-core (which includes omegaconf) is actually used in the config files
rg "from hydra|import hydra" sleap_nn/config/
rg "from omegaconf|import omegaconf" sleap_nn/config/

# Also check how MISSING is used in the config files
rg "MISSING" sleap_nn/config/

Length of output: 455
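For reference, a minimal sketch of the MISSING pattern the codebase already follows in sleap_nn/config/data_config.py. The class and field below are illustrative stand-ins, not the actual sleap-nn definitions:

import attrs
from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue

@attrs.define
class DataConfig:  # illustrative stand-in, not the real class
    train_labels_path: str = MISSING  # mandatory: unreadable until set

cfg = OmegaConf.structured(DataConfig)
try:
    _ = cfg.train_labels_path
except MissingMandatoryValue:
    print("train_labels_path must be provided")
cfg.train_labels_path = "labels.slp"  # hypothetical filename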



52-66: ⚠️ Potential issue

Fix oneof decorator usage.

The @oneof decorator is undefined. This appears to be a custom decorator that needs to be imported or defined.

Let's check if this decorator exists in the codebase:

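Pending that check, here is a minimal sketch (not the original sleap.nn.config implementation, whose semantics may differ) of how the same constraint could be expressed as an attrs post-init check, so that at most one backbone sub-config is set:

from typing import Optional
import attrs

def check_oneof(instance, *names: str) -> None:
    # Raise if more than one of the named attributes is set (non-None).
    set_names = [n for n in names if getattr(instance, n) is not None]
    if len(set_names) > 1:
        raise ValueError(f"Only one of {names} may be set; got {set_names}.")

@attrs.define
class BackboneConfig:
    # Plain dicts as stand-ins for UNetConfig / ConvNextConfig / SwinTConfig.
    unet: Optional[dict] = None
    convnext: Optional[dict] = None
    swint: Optional[dict] = None

    def __attrs_post_init__(self):
        check_oneof(self, "unet", "convnext", "swint")

BackboneConfig(unet={"filters": 32})        # ok
# BackboneConfig(unet={}, convnext={})      # raises ValueError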

    kernel_size: int = 3
    filters: int = 32
    filters_rate: float = 1.5
    max_stride: int = None

⚠️ Potential issue

Fix type hint for max_stride in UNetConfig.

The max_stride attribute allows None but is typed as int.

-    max_stride: int = None
+    max_stride: Optional[int] = None

Comment on lines +5 to +19
@attrs.define
class ModelConfig:
    """Configurations related to model architecture.

    Attributes:
        init_weight: (str) model weights initialization method. "default" uses kaiming uniform initialization and "xavier" uses Xavier initialization method.
        pre_trained_weights: (str) Pretrained weights file name supported only for ConvNext and SwinT backbones. For ConvNext, one of ["ConvNeXt_Base_Weights","ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]. For SwinT, one of ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"].
        backbone_config: initialize either UNetConfig, ConvNextConfig, or SwinTConfig based on input from backbone_type
        head_configs: (Dict) Dictionary with the head configs for the model to be trained. Note: Configs should be provided only for the model to train and others should be None
    """

    init_weight: str = "default"
    pre_trained_weights: str = None
    backbone_config: BackboneConfig = attrs.field(factory=BackboneConfig)
    head_configs: HeadConfig = attrs.field(factory=HeadConfig)

⚠️ Potential issue

Fix type hints and class references in ModelConfig.

  1. The pre_trained_weights should be Optional[str] since it accepts None
  2. The head_configs references HeadConfig but the class is named HeadsConfig
  3. The BackboneType enum is missing

Apply these fixes:

+class BackboneType(Enum):
+    UNET = "unet"
+    CONVNEXT = "convnext"
+    SWINT = "swint"
+
 @attrs.define
 class ModelConfig:
     """Configurations related to model architecture.
     ...
     """
     init_weight: str = "default"
-    pre_trained_weights: str = None
+    pre_trained_weights: Optional[str] = None
     backbone_config: BackboneConfig = attrs.field(factory=BackboneConfig)
-    head_configs: HeadConfig = attrs.field(factory=HeadConfig)
+    head_configs: HeadsConfig = attrs.field(factory=HeadsConfig)

Committable suggestion was skipped due to low confidence.


Comment on lines +211 to +289
@attrs.define
class SingleInstanceConfMapsConfig:
    """Single Instance configuration map

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    """

    part_names: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None


@attrs.define
class CentroidConfMapsConfig:
    """Centroid configuration map

    Attributes:
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    """

    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None


@attrs.define
class CenteredInstanceConfMapsConfig:
    """Centered Instance configuration map

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    """

    part_names: Optional[List[str]] = None
    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None


@attrs.define
class BottomUpConfMapsConfig:
    """Bottomup configuration map

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
        loss_weight: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models.
    """

    part_names: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None
    loss_weight: Optional[float] = None


@attrs.define
class PAFConfig:
    """PAF configuration map

    Attributes:
        edges: (List[str]) None if edges from sio.Labels file can be used directly. Note: Only for 'PartAffinityFieldsHead'. List of indices (src, dest) that form an edge.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
        loss_weight: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models.
    """

    edges: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None
    loss_weight: Optional[float] = None

🛠️ Refactor suggestion

Consider creating a base configuration class for common attributes.

The configuration classes share common attributes (sigma, output_stride, loss_weight). Consider creating a base class to reduce code duplication.

+@attrs.define
+class BaseConfMapsConfig:
+    """Base configuration for confidence maps.
+
+    Attributes:
+        sigma: Spread of the Gaussian distribution of the confidence maps.
+        output_stride: The stride of the output confidence maps relative to the input image.
+        loss_weight: Weight of the loss term during training.
+    """
+    sigma: Optional[float] = None
+    output_stride: Optional[float] = None
+    loss_weight: Optional[float] = None
+
 @attrs.define
-class SingleInstanceConfMapsConfig:
+class SingleInstanceConfMapsConfig(BaseConfMapsConfig):
     """Single Instance configuration map"""
     part_names: Optional[List[str]] = None
-    sigma: Optional[float] = None
-    output_stride: Optional[float] = None

# Similar changes for other configuration classes

Committable suggestion was skipped due to low confidence.


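If the base-class refactor is adopted, here is a minimal usage sketch (field names assumed from the suggestion above). With attrs, inherited fields precede subclass fields in the generated __init__, which is safe here because every field has a default:

import attrs
from typing import List, Optional

@attrs.define
class BaseConfMapsConfig:
    # Shared confidence-map settings (names taken from the suggestion above).
    sigma: Optional[float] = None
    output_stride: Optional[float] = None
    loss_weight: Optional[float] = None

@attrs.define
class SingleInstanceConfMapsConfig(BaseConfMapsConfig):
    part_names: Optional[List[str]] = None

cfg = SingleInstanceConfMapsConfig(sigma=1.5, part_names=["head", "thorax"])
assert cfg.output_stride is None and cfg.loss_weight is None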
Comment on lines +26 to +49
    def validate_pre_trained_weights(self):
        convnext_weights = [
            "ConvNeXt_Base_Weights",
            "ConvNeXt_Tiny_Weights",
            "ConvNeXt_Small_Weights",
            "ConvNeXt_Large_Weights",
        ]
        swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]

        if self.backbone_type == BackboneType.CONVNEXT:
            if self.pre_trained_weights not in convnext_weights:
                raise ValueError(
                    f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
                )
        elif self.backbone_type == BackboneType.SWINT:
            if self.pre_trained_weights not in swint_weights:
                raise ValueError(
                    f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
                )
        elif (
            self.backbone_type == BackboneType.UNET
            and self.pre_trained_weights is not None
        ):
            raise ValueError("UNet does not support pre-trained weights.")

🛠️ Refactor suggestion

Simplify pre-trained weights validation using a mapping.

The validation logic can be simplified and made more maintainable using a dictionary mapping.

     def validate_pre_trained_weights(self):
-        convnext_weights = [
-            "ConvNeXt_Base_Weights",
-            "ConvNeXt_Tiny_Weights",
-            "ConvNeXt_Small_Weights",
-            "ConvNeXt_Large_Weights",
-        ]
-        swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]
-
-        if self.backbone_type == BackboneType.CONVNEXT:
-            if self.pre_trained_weights not in convnext_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
-                )
-        elif self.backbone_type == BackboneType.SWINT:
-            if self.pre_trained_weights not in swint_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
-                )
-        elif (
-            self.backbone_type == BackboneType.UNET
-            and self.pre_trained_weights is not None
-        ):
-            raise ValueError("UNet does not support pre-trained weights.")
+        VALID_WEIGHTS = {
+            BackboneType.CONVNEXT: [
+                "ConvNeXt_Base_Weights",
+                "ConvNeXt_Tiny_Weights",
+                "ConvNeXt_Small_Weights",
+                "ConvNeXt_Large_Weights",
+            ],
+            BackboneType.SWINT: [
+                "Swin_T_Weights",
+                "Swin_S_Weights",
+                "Swin_B_Weights"
+            ],
+            BackboneType.UNET: None
+        }
+        
+        allowed_weights = VALID_WEIGHTS.get(self.backbone_type)
+        if allowed_weights is None and self.pre_trained_weights is not None:
+            raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
+        elif allowed_weights and self.pre_trained_weights not in allowed_weights:
+            raise ValueError(
+                f"Invalid pre-trained weights for {self.backbone_type.value}. "
+                f"Must be one of {allowed_weights}"
+            )

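For illustration, an end-to-end demo of the mapping-based validation as a hypothetical, self-contained sketch; ModelConfig here is trimmed to just the two fields the validator reads, and the enum values are taken from this PR:

from enum import Enum
from typing import Optional
import attrs

class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"

VALID_WEIGHTS = {
    BackboneType.CONVNEXT: [
        "ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights",
        "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights",
    ],
    BackboneType.SWINT: ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"],
    BackboneType.UNET: None,
}

@attrs.define
class ModelConfig:  # trimmed stand-in for the class in this PR
    backbone_type: BackboneType = BackboneType.UNET
    pre_trained_weights: Optional[str] = None

    def validate_pre_trained_weights(self):
        allowed = VALID_WEIGHTS.get(self.backbone_type)
        if allowed is None and self.pre_trained_weights is not None:
            raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
        if allowed and self.pre_trained_weights not in allowed:
            raise ValueError(f"Invalid pre-trained weights for {self.backbone_type.value}.")

ModelConfig(BackboneType.SWINT, "Swin_T_Weights").validate_pre_trained_weights()  # passes
ModelConfig(BackboneType.UNET).validate_pre_trained_weights()                     # passes
# ModelConfig(BackboneType.UNET, "Swin_T_Weights").validate_pre_trained_weights()  # raises ValueError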
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/data_config.py (2)

21-21: Address TODO comment in docstring

The TODO comment about changing chunk_size in inference checkpoints should be tracked properly.

Would you like me to create a GitHub issue to track this TODO item?


65-75: Enhance class documentation

The documentation could be more detailed by describing what types of augmentations are supported and how they interact.

Consider expanding the docstring to include:

  • Purpose of intensity vs. geometric augmentations
  • How these configurations are used in the training pipeline
  • Examples of common configurations

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 19f5256 and 16c8f26.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused; remove unused import (F401)
3-3: typing.Dict imported but unused; remove unused import (F401)
34-34: Undefined name PreprocessingConfig (F821, x2)
36-36: Undefined name AugmentationConfig (F821)
60-60: Undefined name Union (F821)
74-74: Undefined name IntensityConfig (F821)
75-75: Undefined name GeometricConfig (F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

39-41: LGTM!

The validation function is well-implemented with a clear error message and is used consistently throughout the configuration classes.

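For context, a minimal sketch of what a proportion validator like this could look like. This is an assumed implementation for illustration; the actual sleap-nn code may differ:

def validate_proportion(instance, attribute, value):
    # attrs-style validator: probability-like values must lie in [0.0, 1.0].
    if not (0.0 <= value <= 1.0):
        raise ValueError(f"{attribute.name} must be between 0.0 and 1.0, got {value}")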

44-62: LGTM!

The PreprocessingConfig class is well-structured with:

  • Clear documentation for each attribute
  • Appropriate default values
  • Type hints for optional fields


78-110: LGTM!

The IntensityConfig class is well-implemented with:

  • Comprehensive parameter validation
  • Appropriate default values
  • Clear documentation

Comment on lines +1 to +10
import attrs
from omegaconf import MISSING
from typing import Optional, Tuple, List, Dict


"""Serializable configuration classes for specifying all data configuration parameters.

These configuration classes are intended to specify all
the parameters required to initialize the data config.
"""

⚠️ Potential issue

Optimize imports

The imports need the following adjustments:

  • Remove unused imports: List and Dict
  • Add missing import: Union from typing

Apply this diff to fix the imports:

 import attrs
 from omegaconf import MISSING
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Union
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import

(F401)


3-3: typing.Dict imported but unused

Remove unused import

(F401)

Comment on lines +34 to +36
    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
    use_augmentations_train: bool = False
    augmentation_config: Optional[AugmentationConfig] = None

⚠️ Potential issue

Use forward references for configuration classes

To avoid circular imports and undefined names, use string literals for type hints of configuration classes.

Apply this diff:

-    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
-    augmentation_config: Optional[AugmentationConfig] = None
+    preprocessing: 'PreprocessingConfig' = attrs.field(
+        factory=lambda: PreprocessingConfig()
+    )
+    augmentation_config: Optional['AugmentationConfig'] = None

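As an alternative to string forward references, the nested configs can simply be defined above the classes that reference them, so the names are already bound when the class body is evaluated. A minimal sketch with assumed stand-in fields:

import attrs
from typing import Optional

@attrs.define
class PreprocessingConfig:  # defined first, so no forward reference is needed
    scale: float = 1.0  # stand-in field

@attrs.define
class AugmentationConfig:
    intensity_p: float = 0.0  # stand-in field

@attrs.define
class DataConfig:
    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
    use_augmentations_train: bool = False
    augmentation_config: Optional[AugmentationConfig] = None

Bottom-up ordering also keeps OmegaConf's structured-config type resolution straightforward, since every annotation resolves to a real class.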
Comment on lines +113 to +152
@attrs.define
class GeometricConfig:
    """
    Configuration of Geometric (Optional)

    Attributes:
        rotation: (float) Angles in degrees as a scalar float of the amount of rotation. A random angle in (-rotation, rotation) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation.
        scale: (float) scaling factor interval. If (a, b) represents isotropic scaling, the scale is randomly sampled from the range a <= scale <= b. If (a, b, c, d), the scale is randomly sampled from the range a <= scale_x <= b, c <= scale_y <= d Default: None.
        translate_width: (float) Maximum absolute fraction for horizontal translation. For example, if translate_width=a, then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a. Will not translate by default.
        translate_height: (float) Maximum absolute fraction for vertical translation. For example, if translate_height=a, then vertical shift is randomly sampled in the range -img_height * a < dy < img_height * a. Will not translate by default.
        affine_p: (float) Probability of applying random affine transformations. Default=0.0
        erase_scale_min: (float) Minimum value of range of proportion of erased area against input image. Default: 0.0001.
        erase_scale_max: (float) Maximum value of range of proportion of erased area against input image. Default: 0.01.
        erase_ration_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
        erase_ratio_max: (float) Maximum value of range of aspect ratio of erased area. Default: 1.
        erase_p: (float) Probability of applying random erase. Default=0.0
        mixup_lambda: (float) min-max value of mixup strength. Default is 0-1. Default: None.
        mixup_p: (float) Probability of applying random mixup v2. Default=0.0
        input_key: (str) Can be image or instance. The input_key instance expects the KorniaAugmenter to follow the InstanceCropper else image otherwise for default.
        random_crop_p: (float) Probability of applying random crop.
        random_crop_height: (int) Desired output height of the random crop.
        random_crop_width: (int) Desired output height of the random crop.
    """

    rotation: float = 0.0
    scale: Optional[Tuple[float, float, float, float]] = None
    translate_width: float = 0.0
    translate_height: float = 0.0
    affine_p: float = attrs.field(default=0.0, validator=validate_proportion)
    erase_scale_min: float = 0.0001
    erase_scale_max: float = 0.01
    erase_ratio_min: float = 1.0
    erase_ratio_max: float = 1.0
    erase_p: float = attrs.field(default=0.0, validator=validate_proportion)
    mixup_lambda: Optional[float] = None
    mixup_p: float = attrs.field(default=0.0, validator=validate_proportion)
    input_key: str = "image"
    random_crop_p: Optional[float] = None
    random_crop_height: Optional[int] = None
    random_crop_width: Optional[int] = None

⚠️ Potential issue

Fix documentation and attribute name issues

There are two issues in the GeometricConfig class:

  1. Typo in attribute documentation: "erase_ration_min" should be "erase_ratio_min"
  2. Documentation error: random_crop_width's description says "output height" instead of "output width"

Apply this diff to fix the documentation:

-        erase_ration_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
+        erase_ratio_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
-        random_crop_width: (int) Desired output height of the random crop.
+        random_crop_width: (int) Desired output width of the random crop.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 16c8f26 and 43c777b.

📒 Files selected for processing (1)
  • sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py

56-56: Undefined name DataLoaderConfig (F821, x2)
57-57: Undefined name DataLoaderConfig (F821, x2)
58-58: Undefined name ModelCkptConfig (F821, x2)
59-59: Undefined names Union, List (F821)
62-62: Undefined name Optional (F821)
64-64: Undefined name Optional (F821)
68-68: Undefined name Optional (F821)
69-69: Undefined names Optional, WandBConfig (F821)
70-70: Undefined names Optional, OptimizerConfig x2 (F821)
71-71: Undefined name LRSchedulerConfig (F821, x2)
72-72: Undefined names EarlyStoppingConfig x2, attr (F821)
85-85: Undefined names Dict, Text, Any, TrainerJobConfig (F821)
96-96: Undefined name Omega (F821)
99-99: Undefined names Text, TrainingJobConfig (F821)
108-108: Undefined name json (F821)
112-112: Undefined names Text, TrainingJobConfig (F821)
125-125: Undefined names Text, TrainingJobConfig (F821)
140-140: Undefined names Text, TrainingJobConfig (F821)
153-153: Undefined name DictConfig (F821)
168-168: Undefined name json (F821)
179-179: Undefined name Text (F821)
188-188: Undefined name Text (F821)
222-222: Undefined name Optional (F821)
239-239: Undefined name Optional (F821)
240-240: Undefined name Optional (F821)
241-241: Undefined name Optional (F821)
242-242: Undefined name Optional (F821)
243-243: Undefined name Optional (F821)
244-244: Undefined name Optional (F821)
245-245: Undefined names Optional, List (F821)
281-281: Undefined names Union, List (F821)

🔇 Additional comments (2)
sleap_nn/config/trainer_config.py (2)

1-2: ⚠️ Potential issue

Add missing imports for typing and json modules.

The code uses types from the typing module (e.g., Union, List, Optional, Dict, Any, Text) and the json module, but these are not imported. Add the necessary imports at the top of the file to resolve the undefined names.

Apply this diff to add the missing imports:

 import attrs
+from typing import Any, Dict, List, Optional, Text, Union
+import json
 from omegaconf import OmegaConf

Likely invalid or redundant comment.


85-112: ⚠️ Potential issue

Correct return type annotations to match the class name.

The methods from_dict, from_json, and from_yaml have return type annotations as "TrainingJobConfig", but the class is named TrainerConfig. Update the return type annotations to "TrainerConfig" for consistency.

Apply this diff to correct the return type annotations:

     @classmethod
-    def from_dict(cls, config_dict: Dict[Text, Any]) -> "TrainingJobConfig":
+    def from_dict(cls, config_dict: Dict[Text, Any]) -> "TrainerConfig":

     @classmethod
-    def from_json(cls, json_data: Text) -> "TrainingJobConfig":
+    def from_json(cls, json_data: Text) -> "TrainerConfig":

     @classmethod
-    def from_yaml(cls, yaml_data: Text) -> "TrainingJobConfig":
+    def from_yaml(cls, yaml_data: Text) -> "TrainerConfig":

Likely invalid or redundant comment.

        """
        # Convert dictionary to an OmegaConf config, then instantiate from it.
        config = OmegaConf.create(config_dict)
        return Omega.to_object(config, cls)

⚠️ Potential issue

Replace Omega with OmegaConf to fix undefined name.

In the from_dict method, Omega.to_object is used, but Omega is undefined. It should be OmegaConf.to_object to correctly utilize the omegaconf module.

Apply this diff to correct the method call:

-        return Omega.to_object(config, cls)
+        return OmegaConf.to_object(config, cls)

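As a sanity check on the intended pattern here, a minimal round-trip sketch; the class below is a trimmed stand-in, not the full TrainerConfig from this PR. Note that in omegaconf 2.x, OmegaConf.to_object takes a single config argument, so a plain dict is best merged into a structured config before converting back to an attrs instance:

import attrs
from omegaconf import OmegaConf

@attrs.define
class TrainerConfig:  # trimmed stand-in for the real class
    max_epochs: int = 10
    enable_progress_bar: bool = True

config_dict = {"max_epochs": 100}
schema = OmegaConf.structured(TrainerConfig)   # typed structured config
merged = OmegaConf.merge(schema, config_dict)  # validates keys/types against the schema
trainer = OmegaConf.to_object(merged)          # materializes a TrainerConfig instance
assert isinstance(trainer, TrainerConfig) and trainer.max_epochs == 100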
    wandb: Optional[WandBConfig] = attrs.field(init=False)
    optimizer: Optional[OptimizerConfig] = attrs.field(factory=OptimizerConfig)
    lr_scheduler: LRSchedulerConfig = attrs.field(factory=LRSchedulerConfig)
    early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)

⚠️ Potential issue

Correct typo in attrs.field usage.

The attribute early_stopping uses attr.field instead of attrs.field. Replace attr.field with attrs.field to fix the typo and ensure consistency.

Apply this diff to fix the typo:

-    early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
+    early_stopping: EarlyStoppingConfig = attrs.field(factory=EarlyStoppingConfig)

Comment on lines 59 to 60
    trainer_devices: Union[int, List[int], str] = "auto"
    trainer_accelerator: str = "auto"

🛠️ Refactor suggestion

Use attrs.field for attribute definitions for consistency.

In TrainerConfig, the attributes trainer_devices and trainer_accelerator are assigned default values directly instead of using attrs.field. For consistency and to leverage attrs features, use attrs.field to define these attributes.

Apply this diff to use attrs.field:

-    trainer_devices: Union[int, List[int], str] = "auto"
-    trainer_accelerator: str = "auto"
+    trainer_devices: Union[int, List[int], str] = attrs.field(default="auto")
+    trainer_accelerator: str = attrs.field(default="auto")

        config = OmegaConf.load(filename)
        return OmegaConf.to_object(config, cls)

    def to_dict(self) -> DictConfig:

⚠️ Potential issue

Import DictConfig from omegaconf module.

The method to_dict returns DictConfig, but DictConfig is not imported. Add the import at the top of the file to resolve the undefined name.

Apply this diff to import DictConfig:

 from omegaconf import OmegaConf
+from omegaconf import DictConfig

Or combine the imports:

-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, DictConfig

Comment on lines 125 to 140
    def load_json(cls, filename: Text) -> "TrainingJobConfig":
        """Load a training job configuration from a json file.

        Arguments:
            filename: Path to a training job configuration JSON file or a directory
                containing `"training_job.json"`.

        Returns:
            A TrainingJobConfig instance parsed from the json file.
        """
        with open(filename, "r") as f:
            json_data = f.read()
        return cls.from_json(json_data)

    @classmethod
    def load_yaml(cls, filename: Text) -> "TrainingJobConfig":

⚠️ Potential issue

Update return type annotations in loader methods.

The methods load_json and load_yaml also have incorrect return type annotations referencing "TrainingJobConfig". Update them to "TrainerConfig" to match the class name.

Apply this diff to correct the return type annotations:

     @classmethod
-    def load_json(cls, filename: Text) -> "TrainingJobConfig":
+    def load_json(cls, filename: Text) -> "TrainerConfig":

     @classmethod
-    def load_yaml(cls, filename: Text) -> "TrainingJobConfig":
+    def load_yaml(cls, filename: Text) -> "TrainerConfig":

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 43c777b and 7b71a8b.

📒 Files selected for processing (1)
  • sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py

56-56: Undefined name DataLoaderConfig (F821, x2)
57-57: Undefined name DataLoaderConfig (F821, x2)
58-58: Undefined name ModelCkptConfig (F821, x2)
59-59: Undefined names Union, List (F821)
62-62: Undefined name Optional (F821)
64-64: Undefined name Optional (F821)
68-68: Undefined name Optional (F821)
69-69: Undefined names Optional, WandBConfig (F821)
70-70: Undefined names Optional, OptimizerConfig x2 (F821)
71-71: Undefined name LRSchedulerConfig (F821, x2)
72-72: Undefined names EarlyStoppingConfig x2, attr (F821)
85-85: Undefined names Dict, Text, Any (F821)
96-96: Undefined name Omega (F821)
99-99: Undefined name Text (F821)
108-108: Undefined name json (F821)
112-112: Undefined name Text (F821)
125-125: Undefined name Text (F821)
140-140: Undefined name Text (F821)
153-153: Undefined name DictConfig (F821)
168-168: Undefined name json (F821)
179-179: Undefined name Text (F821)
188-188: Undefined name Text (F821)
223-223: Undefined name Optional (F821)
240-240: Undefined name Optional (F821)
241-241: Undefined name Optional (F821)
242-242: Undefined name Optional (F821)
243-243: Undefined name Optional (F821)
244-244: Undefined name Optional (F821)
245-245: Undefined name Optional (F821)
246-246: Undefined names Optional, List (F821)
282-282: Undefined names Union, List (F821)

🔇 Additional comments (6)
sleap_nn/config/trainer_config.py (6)

96-96: ⚠️ Potential issue

Fix incorrect OmegaConf usage.

The Omega.to_object call is incorrect. It should be OmegaConf.to_object.

-        return Omega.to_object(config, cls)
+        return OmegaConf.to_object(config, cls)

Likely invalid or redundant comment.



295-297: 🛠️ Refactor suggestion

Use attrs.field in EarlyStoppingConfig for consistency.

-    stop_training_on_plateau: bool = False
-    min_delta: float = 0.0
-    patience: int = 1
+    stop_training_on_plateau: bool = attrs.field(default=False)
+    min_delta: float = attrs.field(default=0.0)
+    patience: int = attrs.field(default=1)

Likely invalid or redundant comment.


258-259: 🛠️ Refactor suggestion

Use attrs.field in OptimizerConfig for consistency.

-    lr: float = 1e-3
-    amsgrad: bool = False
+    lr: float = attrs.field(default=1e-3)
+    amsgrad: bool = attrs.field(default=False)

Likely invalid or redundant comment.


208-210: 🛠️ Refactor suggestion

Use attrs.field in DataLoaderConfig for consistency.

For consistency with TrainerConfig and to leverage attrs features, use attrs.field for attribute definitions.

-    batch_size: int = 1
-    shuffle: bool = False
-    num_workers: int = 0
+    batch_size: int = attrs.field(default=1)
+    shuffle: bool = attrs.field(default=False)
+    num_workers: int = attrs.field(default=0)

Likely invalid or redundant comment.


222-223: 🛠️ Refactor suggestion

Use attrs.field in ModelCkptConfig for consistency.

-    save_top_k: int = 1
-    save_last: Optional[bool] = None
+    save_top_k: int = attrs.field(default=1)
+    save_last: Optional[bool] = attrs.field(default=None)

Likely invalid or redundant comment.



72-72: ⚠️ Potential issue

Fix incorrect attrs import usage.

The attr.field usage is incorrect and inconsistent with other field definitions in the class.

-    early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
+    early_stopping: EarlyStoppingConfig = attrs.field(factory=EarlyStoppingConfig)

Likely invalid or redundant comment.


        return OmegaConf.to_object(config, cls)

    @classmethod
    def load_json(cls, filename: Text) -> "TrainerConfig":

⚠️ Potential issue

Fix incorrect return type hints.

The return type hints in load_json and load_yaml methods incorrectly reference "TrainingJobConfig" instead of "TrainerConfig".

-    def load_json(cls, filename: Text) -> "TrainingJobConfig":
+    def load_json(cls, filename: Text) -> "TrainerConfig":

-    def load_yaml(cls, filename: Text) -> "TrainingJobConfig":
+    def load_yaml(cls, filename: Text) -> "TrainerConfig":

Also applies to: 140-140


Comment on lines 50 to 54
        optimizer_name: (str) Optimizer to be used. One of ["Adam", "AdamW"].
        optimizer:
        lr_scheduler:
        early_stopping:
    """

⚠️ Potential issue

Add missing optimizer_name attribute.

The optimizer_name attribute is documented but not implemented in the class.

     optimizer: Optional[OptimizerConfig] = attrs.field(factory=OptimizerConfig)
+    optimizer_name: str = attrs.field(default="Adam")
     lr_scheduler: LRSchedulerConfig = attrs.field(factory=LRSchedulerConfig)

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +1 to +2
import attrs
from omegaconf import OmegaConf

⚠️ Potential issue

Add missing imports.

Several required imports are missing. Add these imports at the top of the file:

 import attrs
 from omegaconf import OmegaConf
+from omegaconf import DictConfig
+from typing import Any, Dict, List, Optional, Text, Union
+import json

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 7b71a8b and 46df99a.

📒 Files selected for processing (1)
  • sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py

56-56: Undefined name DataLoaderConfig (F821, x2)
57-57: Undefined name DataLoaderConfig (F821, x2)
58-58: Undefined name ModelCkptConfig (F821, x2)
59-59: Undefined names Union, List (F821)
62-62: Undefined name Optional (F821)
64-64: Undefined name Optional (F821)
68-68: Undefined name Optional (F821)
69-69: Undefined names Optional, WandBConfig (F821)
70-70: Undefined names Optional, OptimizerConfig x2 (F821)
71-71: Undefined name LRSchedulerConfig (F821, x2)
72-72: Undefined names EarlyStoppingConfig x2, attr (F821)
85-85: Undefined name Text (F821)
98-98: Undefined name Text (F821)
120-120: Undefined name Text (F821)
155-155: Undefined name Optional (F821)
172-172: Undefined name Optional (F821)
173-173: Undefined name Optional (F821)
174-174: Undefined name Optional (F821)
175-175: Undefined name Optional (F821)
176-176: Undefined name Optional (F821)
177-177: Undefined name Optional (F821)
178-178: Undefined names Optional, List (F821)
214-214: Undefined names Union, List (F821)

🔇 Additional comments (7)
sleap_nn/config/trainer_config.py (7)

154-155: 🛠️ Refactor suggestion

Use attrs.field in ModelCkptConfig for consistency.

For consistency with other classes and to leverage attrs features, use attrs.field for attribute definitions.

-    save_top_k: int = 1
-    save_last: Optional[bool] = None
+    save_top_k: int = attrs.field(default=1)
+    save_last: Optional[bool] = attrs.field(default=None)

Likely invalid or redundant comment.

🧰 Tools
🪛 Ruff

155-155: Undefined name Optional (F821)


1-2: ⚠️ Potential issue

Add missing imports.

Several required imports are missing. Add these imports at the top of the file:

 import attrs
 from omegaconf import OmegaConf
+from omegaconf import DictConfig
+from typing import Any, Dict, List, Optional, Text, Union
+import json

Likely invalid or redundant comment.


72-72: ⚠️ Potential issue

Fix incorrect attrs import usage.

The early_stopping attribute uses attr.field instead of attrs.field. Since only attrs is imported, this raises a NameError when the module is imported.

-    early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
+    early_stopping: EarlyStoppingConfig = attrs.field(factory=EarlyStoppingConfig)

Likely invalid or redundant comment.

🧰 Tools
🪛 Ruff

72-72: Undefined name EarlyStoppingConfig (F821)
72-72: Undefined name attr (F821)
72-72: Undefined name EarlyStoppingConfig (F821)


227-229: 🛠️ Refactor suggestion

Use attrs.field in EarlyStoppingConfig for consistency.

For consistency with other classes and to leverage attrs features, use attrs.field for attribute definitions.

-    stop_training_on_plateau: bool = False
-    min_delta: float = 0.0
-    patience: int = 1
+    stop_training_on_plateau: bool = attrs.field(default=False)
+    min_delta: float = attrs.field(default=0.0)
+    patience: int = attrs.field(default=1)

Likely invalid or redundant comment.


50-54: ⚠️ Potential issue

Implement documented optimizer_name attribute.

The optimizer_name attribute is documented in the class docstring but not implemented in the class definition.

+    optimizer_name: str = attrs.field(default="Adam")
     optimizer: Optional[OptimizerConfig] = attrs.field(factory=OptimizerConfig)

Likely invalid or redundant comment.


140-142: 🛠️ Refactor suggestion

Use attrs.field in DataLoaderConfig for consistency.

For consistency with other classes and to leverage attrs features, use attrs.field for attribute definitions.

-    batch_size: int = 1
-    shuffle: bool = False
-    num_workers: int = 0
+    batch_size: int = attrs.field(default=1)
+    shuffle: bool = attrs.field(default=False)
+    num_workers: int = attrs.field(default=0)

Likely invalid or redundant comment.


59-60: 🛠️ Refactor suggestion

Use attrs.field for consistency.

For consistency with other attributes and to leverage attrs features, use attrs.field for these attributes.

-    trainer_devices: Union[int, List[int], str] = "auto"
-    trainer_accelerator: str = "auto"
+    trainer_devices: Union[int, List[int], str] = attrs.field(default="auto")
+    trainer_accelerator: str = attrs.field(default="auto")

Likely invalid or redundant comment.

🧰 Tools
🪛 Ruff

59-59: Undefined name Union (F821)
59-59: Undefined name List (F821)

Comment on lines 208 to 214
mode: str = "min"
threshold: float = 1e-4
threshold_mode: str = "rel"
cooldown: int = 0
patience: int = 10
factor: float = 0.1
min_lr: Union[float, List[float]] = 0.0

🛠️ Refactor suggestion

Use attrs.field in LRSchedulerConfig for consistency.

For consistency with other classes and to leverage attrs features, use attrs.field for attribute definitions.

-    mode: str = "min"
-    threshold: float = 1e-4
-    threshold_mode: str = "rel"
-    cooldown: int = 0
-    patience: int = 10
-    factor: float = 0.1
-    min_lr: Union[float, List[float]] = 0.0
+    mode: str = attrs.field(default="min")
+    threshold: float = attrs.field(default=1e-4)
+    threshold_mode: str = attrs.field(default="rel")
+    cooldown: int = attrs.field(default=0)
+    patience: int = attrs.field(default=10)
+    factor: float = attrs.field(default=0.1)
+    min_lr: Union[float, List[float]] = attrs.field(default=0.0)

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff

214-214: Undefined name Union (F821)
214-214: Undefined name List (F821)


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (5)
tests/config/test_trainer_config.py (1)

13-27: Consider adding edge case tests for DataLoaderConfig.

While the basic functionality is well tested, consider adding tests for:

  • Negative batch sizes
  • Very large batch sizes
  • Negative number of workers

Example addition:

def test_dataloader_config_edge_cases():
    with pytest.raises(ValueError):
        OmegaConf.structured(DataLoaderConfig(batch_size=-1))
    with pytest.raises(ValueError):
        OmegaConf.structured(DataLoaderConfig(num_workers=-1))
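
These edge-case tests presuppose value validation on DataLoaderConfig, which the PR does not currently implement; a minimal sketch of validators that would make them pass (assuming attrs >= 21.3 for attrs.validators.ge):

import attrs

@attrs.define
class DataLoaderConfig:
    # ge-validators raise ValueError for out-of-range values at construction.
    batch_size: int = attrs.field(default=1, validator=attrs.validators.ge(1))
    shuffle: bool = attrs.field(default=False)
    num_workers: int = attrs.field(default=0, validator=attrs.validators.ge(0))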
sleap_nn/config/trainer_config.py (4)

43-45: Use attrs.field() for attribute definitions.

For consistency with other classes and to leverage attrs features like validation, use attrs.field().

-    batch_size: int = 1
-    shuffle: bool = False
-    num_workers: int = 0
+    batch_size: int = attrs.field(default=1)
+    shuffle: bool = attrs.field(default=False)
+    num_workers: int = attrs.field(default=0)

117-117: Improve type hint for min_lr attribute.

The min_lr attribute is validated to be either float or List[float], but typed as Any. Update the type hint to match the validation.

-    min_lr: Any = 0.0 
+    min_lr: Union[float, List[float]] = attrs.field(default=0.0)

173-173: Improve type hint for trainer_devices attribute.

The trainer_devices attribute has specific validation rules but is typed as Any. Update the type hint to match the validation.

-    trainer_devices: Any = "auto"
+    trainer_devices: Union[int, List[int], str] = attrs.field(
+        default="auto",
+        validator=attrs.validators.instance_of((int, list, str))
+    )

253-254: Add error handling for file operations.

The file write operation should handle potential IOErrors.

-        with open(filename, "w") as f:
-            f.write(self.to_yaml())
+        try:
+            with open(filename, "w") as f:
+                f.write(self.to_yaml())
+        except IOError as e:
+            raise IOError(f"Failed to save configuration to {filename}: {e}")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 46df99a and 110ca64.

📒 Files selected for processing (2)
  • sleap_nn/config/trainer_config.py (1 hunks)
  • tests/config/test_trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.7.0)
sleap_nn/config/trainer_config.py

3-3: typing.Union imported but unused; remove unused import: typing.Union (F401)

tests/config/test_trainer_config.py

1-1: pytest imported but unused; remove unused import: pytest (F401)

🔇 Additional comments (2)
tests/config/test_trainer_config.py (1)

1-11: LGTM! Ignore the unused import warning.

The pytest import is needed once the suggested pytest.raises-based tests are added, so the static analysis warning can be safely ignored.

🧰 Tools
🪛 Ruff (0.7.0)

1-1: pytest imported but unused; remove unused import: pytest (F401)

sleap_nn/config/trainer_config.py (1)

1-30: LGTM! Well-documented module with clear design philosophy.

The imports are appropriate and the module documentation clearly explains the purpose, design decisions, and benefits of the configuration system.

🧰 Tools
🪛 Ruff (0.7.0)

3-3: typing.Union imported but unused; remove unused import: typing.Union (F401)

Comment on lines +62 to +74
def test_lr_scheduler_config():
# Check default values
conf = OmegaConf.structured(LRSchedulerConfig)
assert conf.mode == "min"
assert conf.threshold == 1e-4
assert conf.patience == 10

# Test customization
custom_conf = OmegaConf.structured(LRSchedulerConfig(mode="max", patience=5, factor=0.5))
assert custom_conf.mode == "max"
assert custom_conf.patience == 5
assert custom_conf.factor == 0.5


🛠️ Refactor suggestion

Enhance LRSchedulerConfig test coverage.

The current tests don't validate the mode field constraints. Consider adding tests for:

  • Invalid mode values
  • Boundary conditions for patience and threshold
  • Edge cases for the factor value (0.0, 1.0, >1.0)

Example addition:

def test_lr_scheduler_config_validation():
    with pytest.raises(ValueError):
        OmegaConf.structured(LRSchedulerConfig(mode="invalid"))
    with pytest.raises(ValueError):
        OmegaConf.structured(LRSchedulerConfig(factor=0.0))
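
For these tests to raise, LRSchedulerConfig would need validators along these lines (a sketch under that assumption; the bounds chosen here are illustrative):

import attrs

@attrs.define
class LRSchedulerConfig:
    mode: str = attrs.field(
        default="min", validator=attrs.validators.in_(["min", "max"])
    )
    factor: float = attrs.field(default=0.1)

    @factor.validator
    def _check_factor(self, attribute, value):
        # Keep factor strictly between 0 and 1 so each step reduces the LR.
        if not 0.0 < value < 1.0:
            raise ValueError("factor must be in the open interval (0, 1)")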

Comment on lines +89 to +113
def test_trainer_config():
# Check default values
conf = OmegaConf.structured(TrainerConfig)
assert conf.train_data_loader.batch_size == 1
assert conf.val_data_loader.shuffle is False
assert conf.model_ckpt.save_top_k == 1
assert conf.optimizer.lr == 1e-3
assert conf.lr_scheduler.mode == "min"
assert conf.early_stopping.patience == 1
assert conf.use_wandb is False
assert conf.save_ckpt_path == "./"

# Test customization
custom_conf = OmegaConf.structured(
TrainerConfig(
max_epochs=20,
train_data_loader=DataLoaderConfig(batch_size=32),
optimizer=OptimizerConfig(lr=0.01),
use_wandb=True,
)
)
assert custom_conf.max_epochs == 20
assert custom_conf.train_data_loader.batch_size == 32
assert custom_conf.optimizer.lr == 0.01
assert custom_conf.use_wandb is True

🛠️ Refactor suggestion

Enhance TrainerConfig integration testing.

While the basic integration is tested, consider adding:

  1. Tests for config relationships (e.g., early stopping patience vs lr scheduler patience)
  2. Validation of config serialization/deserialization
  3. Tests for invalid configuration combinations

Example additions:

def test_trainer_config_relationships():
    # Test that early stopping patience is less than lr scheduler patience
    conf = OmegaConf.structured(
        TrainerConfig(
            early_stopping=EarlyStoppingConfig(patience=10),
            lr_scheduler=LRSchedulerConfig(patience=5)
        )
    )
    with pytest.raises(ValueError):
        conf.validate()  # Assuming validate() method exists

def test_trainer_config_serialization():
    conf = TrainerConfig(max_epochs=20)
    yaml_str = OmegaConf.to_yaml(conf)
    loaded_conf = OmegaConf.create(yaml_str)
    assert OmegaConf.to_container(conf) == OmegaConf.to_container(loaded_conf)

entity: Optional[str] = None
project: Optional[str] = None
name: Optional[str] = None
api_key: Optional[str] = None

⚠️ Potential issue

Secure sensitive data in serialization.

The api_key in WandBConfig should be masked when serialized to prevent accidental exposure in configuration files.

Consider implementing custom serialization for the api_key field:

@attrs.define
class WandBConfig:
    # ... other attributes ...
    api_key: Optional[str] = attrs.field(
        default=None,
        repr=lambda value: '****' if value else 'None',  # the repr hook should return a str
        metadata={'sensitive': True}
    )

Then update the to_yaml method in TrainerConfig to handle sensitive fields:

def to_yaml(self) -> str:
    config = self.to_dict()
    if self.wandb and self.wandb.api_key:
        config['wandb']['api_key'] = '****'
    return OmegaConf.to_yaml(config)
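
A quick self-contained check of the masked repr (hypothetical values; note the repr hook returns the string "None" when no key is set):

import attrs
from typing import Optional

@attrs.define
class WandBConfig:
    entity: Optional[str] = None
    api_key: Optional[str] = attrs.field(
        default=None, repr=lambda v: "****" if v else "None"
    )

print(WandBConfig(entity="team", api_key="secret"))
# WandBConfig(entity='team', api_key=****)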
