Implement Omegaconfig PR1: basic functionality #97
base: main
Conversation
Walkthrough

This pull request introduces significant changes to the configuration management of the sleap-nn project.
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@ Coverage Diff @@
##             main      #97      +/-   ##
==========================================
+ Coverage   96.64%   97.34%   +0.70%
==========================================
  Files          23       38      +15
  Lines        1818     3694    +1876
==========================================
+ Hits         1757     3596    +1839
- Misses         61       98      +37
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/model_config.py (2)
1-2: Consider adding explicit attribute definitions.

The `ModelConfig` class is correctly using the `@attr.s` decorator with `auto_attribs=True`. However, the class body is empty, which is unusual. Consider adding explicit attribute definitions to improve code clarity and enable IDE autocompletion. Here's an example of how you could define the attributes:

```python
@attr.s(auto_attribs=True)
class ModelConfig:
    backbone: dict
    heads: dict
    base_checkpoint: str
```

This will make the class structure more explicit and easier to understand at a glance.
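As a quick illustration of what `auto_attribs=True` buys you, here is a self-contained sketch. The `backbone`/`heads`/`base_checkpoint` fields are the hypothetical ones suggested above, with assumed defaults added so the class is instantiable:

```python
import attr


@attr.s(auto_attribs=True)
class ModelConfig:
    """Hypothetical sketch; the real attribute names and types may differ."""

    # auto_attribs=True turns these annotated fields into generated
    # __init__ parameters, __repr__ entries, and equality checks.
    backbone: dict = attr.ib(factory=dict)
    heads: dict = attr.ib(factory=dict)
    base_checkpoint: str = ""


cfg = ModelConfig(base_checkpoint="models/run1")
print(cfg.backbone)         # {}
print(cfg.base_checkpoint)  # models/run1
```

Because `attrs` generates `__init__` and `__repr__` from the annotations, adding a field later is a one-line change.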
3-9: Docstring looks good, minor formatting suggestion.

The docstring provides clear and informative descriptions of the class and its attributes. Well done!

Consider adding a period at the end of the last line for consistency:

```diff
- base_checkpoint: Path to model folder for loading a checkpoint. Should contain the .h5 file
+ base_checkpoint: Path to model folder for loading a checkpoint. Should contain the .h5 file.
```

sleap_nn/config/training_job.py (1)

1-42: Overall assessment and recommendations.

The `TrainingJobConfig` class provides a well-structured and documented foundation for managing training job configurations. However, there are a few improvements needed to make it complete and error-free:

- Add the missing imports for `attr`, `DataConfig`, and other required classes.
- Implement the remaining attributes mentioned in the class docstring.
- Consider adding type hints for all attributes to improve code readability and catch potential type-related issues early.
- If not already present in your project, consider adding a `requirements.txt` or `setup.py` file to specify the `attrs` library as a dependency.

Once these changes are made, the `TrainingJobConfig` class will be a robust and well-documented configuration management solution for your training jobs.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- sleap_nn/config/data_config.py (1 hunks)
- sleap_nn/config/model_config.py (1 hunks)
- sleap_nn/config/training_job.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py
- 1-1: Undefined name `attr` (F821)
- 11-11: Undefined name `LabelsConfig` (F821)
- 11-11: Undefined name `attr` (F821)
- 11-11: Undefined name `LabelsConfig` (F821)
- 12-12: Undefined name `PreprocessingConfig` (F821)
- 12-12: Undefined name `attr` (F821)
- 12-12: Undefined name `PreprocessingConfig` (F821)
- 13-13: Undefined name `InstanceCroppingConfig` (F821)
- 13-13: Undefined name `attr` (F821)
- 13-13: Undefined name `InstanceCroppingConfig` (F821)

sleap_nn/config/model_config.py
- 1-1: Undefined name `attr` (F821)

sleap_nn/config/training_job.py
- 27-27: Undefined name `attr` (F821)
- 42-42: Undefined name `DataConfig` (F821)
- 42-42: Undefined name `attr` (F821)
- 42-42: Undefined name `DataConfig` (F821)
🔇 Additional comments (3)
sleap_nn/config/data_config.py (2)
3-9: Well-structured and informative docstring.

The class docstring is well-written, providing a clear description of the class purpose and each of its attributes. It follows good documentation practices, which will help other developers understand the purpose and usage of this configuration class.

1-13: Overall well-structured and maintainable code.

The `DataConfig` class is well-designed and follows good coding practices. It uses the `attr` library effectively for attribute management, has clear and descriptive naming, and includes informative docstrings. The suggested improvements (adding missing imports) are minor and easy to implement. Once these changes are made, the code will be more robust and easier to maintain.
sleap_nn/config/training_job.py (1)
1-25: Excellent documentation!

The file-level docstring provides a clear and comprehensive explanation of the purpose and design principles behind the configuration classes. It effectively communicates the separation of configuration from implementation and highlights the benefits of this approach, including easy serialization and human-readability.
sleap_nn/config/model_config.py
Outdated
```python
@attr.s(auto_attribs=True)
```
Add import statement for `attr`.

The static analysis tool has flagged an undefined name `attr`. This is because the import statement for the `attr` module is missing. Add the following import statement at the beginning of the file:

```python
import attr
```

This will resolve the undefined name issue and make the usage of the `@attr.s` decorator valid.
sleap_nn/config/training_job.py
Outdated
```python
@attr.s(auto_attribs=True)
class TrainingJobConfig:
    """Configuration of a training job.

    Attributes:
        data: Configuration options related to the training data.
        model: Configuration options related to the model architecture.
        optimization: Configuration options related to the training.
        outputs: Configuration options related to outputs during training.
        name: Optional name for this configuration profile.
        description: Optional description of the configuration.
        sleap_version: Version of SLEAP that generated this configuration.
        filename: Path to this config file if it was loaded from disk.
    """
```
Add missing import for `attr`.

The `attr` module is used to define the `TrainingJobConfig` class, but it's not imported. This is causing the static analysis tool to flag an error. Please add the following import at the beginning of the file:

```python
import attr
```

This will resolve the undefined name error for `attr`.
sleap_nn/config/training_job.py
Outdated
```python
    filename: Path to this config file if it was loaded from disk.
    """

    data: DataConfig = attr.ib(factory=DataConfig)
```
Add missing import for `DataConfig`.

The `DataConfig` class is used as the type for the `data` attribute, but it's not imported. This is causing the static analysis tool to flag an error. Please add the following import at the beginning of the file (the class lives in `sleap_nn/config/data_config.py`):

```python
from sleap_nn.config.data_config import DataConfig
```

This will resolve the undefined name error for `DataConfig`.
Add missing attributes to `TrainingJobConfig`.

The class docstring mentions several attributes that are not present in the provided code snippet. To ensure consistency with the documentation and completeness of the class, please add the following attributes:

- `model`
- `optimization`
- `outputs`
- `name`
- `description`
- `sleap_version`
- `filename`

Here's a suggested implementation for the missing attributes:

```python
@attr.s(auto_attribs=True)
class TrainingJobConfig:
    # ... (existing docstring)
    data: DataConfig = attr.ib(factory=DataConfig)
    model: ModelConfig = attr.ib(factory=ModelConfig)
    optimization: OptimizationConfig = attr.ib(factory=OptimizationConfig)
    outputs: OutputConfig = attr.ib(factory=OutputConfig)
    name: Optional[str] = attr.ib(default=None)
    description: Optional[str] = attr.ib(default=None)
    sleap_version: str = attr.ib(default=attr.Factory(lambda: sleap.__version__))
    filename: Optional[str] = attr.ib(default=None)
```

Note: You'll need to import the necessary classes (`ModelConfig`, `OptimizationConfig`, `OutputConfig`) and `Optional` from `typing`. Also, import `sleap` to get the version.
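To see how the defaulted attributes behave once wired up, here is a stripped-down, runnable sketch. `DataConfig` is reduced to a single stand-in field and the other sub-configs are omitted, so the names and defaults here are illustrative only:

```python
from typing import Optional

import attr


@attr.s(auto_attribs=True)
class DataConfig:
    # Stand-in with one field; the real class has many more.
    provider: str = "LabelsReader"


@attr.s(auto_attribs=True)
class TrainingJobConfig:
    # factory= builds a fresh DataConfig per instance (no shared mutable default).
    data: DataConfig = attr.ib(factory=DataConfig)
    name: Optional[str] = attr.ib(default=None)
    filename: Optional[str] = attr.ib(default=None)


job = TrainingJobConfig(name="baseline")
print(job.data.provider)  # LabelsReader
print(attr.asdict(job))   # nested plain dict, convenient for YAML/JSON dumping
```

`attr.asdict` recurses into nested `attrs` classes, which is what makes this style of config easy to serialize to disk.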
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/data_config.py (2)
1-3: Remove unused import.

The `List` type from `typing` is imported but not used in the file. Consider removing it to keep the imports clean. Apply this diff to remove the unused import:

```diff
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
```
38-51: Approve IntensityConfig class and suggest comment update.

The `IntensityConfig` class is well-structured with appropriate attributes for intensity-related augmentation options. However, there's a comment indicating that this part is still under review. Once the review of this section is complete, consider removing or updating the following comment:

```python
# still looking over this part and below (function to check if is greater than 0)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py
- 3-3: `typing.List` imported but unused. Remove unused import: `typing.List` (F401)
- 18-18: Undefined name `PreprocessingConfig` (F821)
- 18-18: Undefined name `PreprocessingConfig` (F821)
- 20-20: Undefined name `AugmentationConfig` (F821)
- 28-28: Undefined name `Union` (F821)
- 35-35: Undefined name `IntensityConfig` (F821)
- 36-36: Undefined name `GeometricConfig` (F821)
🔇 Additional comments (2)
sleap_nn/config/data_config.py (2)
53-67: Approve GeometricConfig class.

The `GeometricConfig` class is well-structured with appropriate attributes for geometric augmentation options. The default values seem reasonable, and the use of Optional types where appropriate is good practice.
1-67: Summary of review.

Overall, the configuration classes in this file are well-structured and provide a comprehensive set of options for data handling, preprocessing, and augmentation in the SLEAP-NN project. The use of `attr.s` with `auto_attribs=True` is a good practice for creating classes with less boilerplate code.

Main points to address:

- Remove the unused `List` import.
- Fix circular import issues by using string annotations for types within the same file.
- Import the `Union` type from `typing`.
- Update or remove the comment in the `IntensityConfig` class once the review is complete.

After addressing these points, the file will be in excellent shape for use in the project.
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/data_config.py (1)
1-3: Remove unused import.

The `List` type is imported but not used in the file. Consider removing it to keep the imports clean. Apply this diff to remove the unused import:

```diff
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py
- 3-3: `typing.List` imported but unused. Remove unused import: `typing.List` (F401)
- 18-18: Undefined name `PreprocessingConfig` (F821)
- 18-18: Undefined name `PreprocessingConfig` (F821)
- 20-20: Undefined name `AugmentationConfig` (F821)
- 28-28: Undefined name `Union` (F821)
- 35-35: Undefined name `IntensityConfig` (F821)
- 36-36: Undefined name `GeometricConfig` (F821)
🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)
38-61: Approve IntensityConfig class.

The `IntensityConfig` class is well-structured with appropriate attributes for intensity augmentation options. The inclusion of validator methods for `uniform_noise_min` and `uniform_noise_max` is a good practice to ensure valid values. The class implementation looks good and provides a comprehensive set of options for intensity-based augmentations.
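A minimal sketch of the validator pattern being praised here, using `attr`'s decorator-style validators. The field names follow the review, but the defaults and bounds are assumptions:

```python
import attr


@attr.s(auto_attribs=True)
class IntensityConfig:
    """Sketch of min/max validation; defaults and bounds are assumed."""

    uniform_noise_min: float = attr.ib(default=0.0)
    uniform_noise_max: float = attr.ib(default=0.04)

    # Validators run automatically inside the generated __init__.
    @uniform_noise_min.validator
    def _check_min(self, attribute, value):
        if value < 0:
            raise ValueError("uniform_noise_min must be >= 0")

    @uniform_noise_max.validator
    def _check_max(self, attribute, value):
        if value > 1:
            raise ValueError("uniform_noise_max must be <= 1")


IntensityConfig()  # defaults pass validation
try:
    IntensityConfig(uniform_noise_min=-0.5)
except ValueError as err:
    print("rejected:", err)
```

Raising from the validator at construction time means a bad config fails fast, before training starts.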
63-77: Approve GeometricConfig class.

The `GeometricConfig` class is well-structured with appropriate attributes for geometric augmentation options. The attribute names are clear and descriptive, and the types are suitable for their intended purposes. The class implementation provides a comprehensive set of options for geometric augmentations, which should cover most common use cases in image augmentation pipelines.
1-77: Overall assessment: Well-structured configuration classes with minor improvements needed.

The `sleap_nn/config/data_config.py` file introduces a set of well-structured configuration classes for data handling, preprocessing, and augmentation. The use of the `attr` library and type hints enhances code readability and maintainability.

Key points:

- The classes provide comprehensive options for various aspects of data processing and augmentation.
- The code follows good practices in terms of class structure and attribute definitions.
- Validator methods in the `IntensityConfig` class add an extra layer of safety.

Minor improvements suggested:

- Remove the unused import (`List`).
- Add the missing import (`Union`).
- Address potential circular import issues with string annotations.

Overall, the file provides a solid foundation for configuring data-related operations in the project. After addressing the minor issues, it will be ready for use.
sleap_nn/config/data_config.py
Outdated
```python
@attr.s(auto_attribs=True)
class PreprocessingConfig:
    is_rgb: bool = True
    max_height: Optional[int] = None
    max_width: Optional[int] = None
    scale: Union[float, Tuple[float, float]] = 1.0
    crop_hw: Optional[Tuple[int, int]] = None
    min_crop_size: int = 32  # to help app work in case of error
```
Approve PreprocessingConfig class and suggest import fix.

The `PreprocessingConfig` class is well-structured with appropriate attributes for preprocessing options. However, the `Union` type is used but not imported. Add the following import at the top of the file:

```python
from typing import Union
```
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/data_config.py (1)
1-3: Remove unused import.

The `List` type is imported from `typing` but not used in the file. Consider removing it to keep the imports clean. Apply this diff to remove the unused import:

```diff
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py
- 3-3: `typing.List` imported but unused. Remove unused import: `typing.List` (F401)
- 25-25: Undefined name `PreprocessingConfig` (F821)
- 25-25: Undefined name `PreprocessingConfig` (F821)
- 27-27: Undefined name `AugmentationConfig` (F821)
- 46-46: Undefined name `Union` (F821)
- 61-61: Undefined name `IntensityConfig` (F821)
- 62-62: Undefined name `GeometricConfig` (F821)
🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)
64-103: LGTM: IntensityConfig class is well-structured.

The `IntensityConfig` class is well-defined with appropriate attributes and validators for uniform noise parameters. The use of the `attr.s` decorator with `auto_attribs=True` is a good practice for creating classes with less boilerplate code.

105-138: LGTM: GeometricConfig class is well-structured.

The `GeometricConfig` class is well-defined with appropriate attributes for geometric augmentations. The use of the `attr.s` decorator with `auto_attribs=True` is consistent with the other classes in this file.

1-138: Overall: Well-structured configuration classes for data handling.

This file introduces a set of well-organized configuration classes for managing various aspects of data handling in machine learning workflows. The use of `attr.s` decorators with `auto_attribs=True` promotes clean and maintainable code. The classes cover essential aspects such as data sources, preprocessing, and augmentation options. A few minor improvements have been suggested regarding imports and forward references. Once these are addressed, the file will provide a robust foundation for configuring data-related parameters in the project.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/data_config.py (2)
1-3: Optimize imports.

Please make the following changes to the import statements:

- Remove the unused import of `List`.
- Add `Union` to the imports from `typing`.

Apply this diff to optimize the imports:

```diff
 import attrs
 from omegaconf import MISSING
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict, Union
```

This change will remove the unused import and add the necessary `Union` type.
94-127: Approve GeometricConfig class and fix typo.

The `GeometricConfig` class is well-structured with appropriate attributes and default values. The docstring provides clear explanations for each attribute. There's a small typo in the docstring for the `erase_ratio_min` attribute. Please fix it as follows:

```diff
- erase_ration_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
+ erase_ratio_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
```

This change will correct the attribute name in the docstring.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py
- 3-3: `typing.List` imported but unused. Remove unused import: `typing.List` (F401)
- 25-25: Undefined name `PreprocessingConfig` (F821)
- 25-25: Undefined name `PreprocessingConfig` (F821)
- 27-27: Undefined name `AugmentationConfig` (F821)
- 46-46: Undefined name `Union` (F821)
- 61-61: Undefined name `IntensityConfig` (F821)
- 62-62: Undefined name `GeometricConfig` (F821)
🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)
30-48: Approve PreprocessingConfig class and remind about Union import.

The `PreprocessingConfig` class is well-structured with appropriate attributes and default values. The docstring provides clear explanations for each attribute. Remember to add the `Union` import as suggested in the previous comment about optimizing imports.
64-92: Approve IntensityConfig class.

The `IntensityConfig` class is well-structured with appropriate attributes, default values, and validators. The docstring provides clear explanations for each attribute. The use of `attrs.field` with validators for `uniform_noise_min`, `uniform_noise_max`, `contrast_min`, and `contrast_max` is a good practice to ensure valid input.
1-127: Overall assessment: Well-structured configuration classes with minor improvements needed.

The `data_config.py` file introduces a set of well-structured configuration classes for managing data parameters in machine learning workflows. The use of the `attrs` library and detailed docstrings contributes to clean and self-documenting code.

Main points of improvement:

- Optimize imports by removing unused ones and adding missing ones.
- Use forward references to resolve potential circular import issues.
- Fix a minor typo in the `GeometricConfig` class docstring.

Once these changes are implemented, the file will be in excellent shape, providing a robust foundation for configuring data-related parameters in your project.
sleap_nn/config/data_config.py
Outdated
```python
@attrs.define
class DataConfig:
    """Data configuration.

    labels: Configuration options related to user labels for training or testing.
    preprocessing: Configuration options related to data preprocessing.
    instance_cropping: Configuration options related to instance cropping for centroid
        and topdown models.
    """

    provider: str = "LabelsReader"
    train_labels_path: str = MISSING
    val_labels_path: str = MISSING
    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
    use_augmentations_train: bool = False
    augmentation_config: Optional[AugmentationConfig] = None
```
🛠️ Refactor suggestion

Use forward references for configuration classes.

The `DataConfig` class looks well-structured, but there are undefined names for `PreprocessingConfig` and `AugmentationConfig`. To avoid potential circular imports, use forward references for these classes. Apply this diff to use forward references:

```diff
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from sleap_nn.config.data_config import PreprocessingConfig, AugmentationConfig

 @attrs.define
 class DataConfig:
     # ... (existing code)
-    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
-    augmentation_config: Optional[AugmentationConfig] = None
+    preprocessing: 'PreprocessingConfig' = attrs.field(factory=lambda: PreprocessingConfig())
+    augmentation_config: Optional['AugmentationConfig'] = None
```

This change will resolve the undefined names issue and prevent potential circular imports.
Committable suggestion was skipped due to low confidence.
sleap_nn/config/data_config.py
Outdated
```python
@attrs.define
class AugmentationConfig:
    """Configuration of Augmentation.

    Attributes:
        random_crop: (Optional) (Dict[str, float]) {"random_crop_p": None, "crop_height": None,
            "crop_width": None}, where random_crop_p is the probability of applying random crop
            and crop_height and crop_width are the desired output size (out_h, out_w) of the crop.
        intensity: (Optional)
        geometric: (Optional)
    """

    random_crop: Optional[Dict[str, Optional[float]]] = None
    intensity: Optional[IntensityConfig] = attrs.field(default=None)
    geometric: Optional[GeometricConfig] = attrs.field(default=None)
```
🛠️ Refactor suggestion

Use forward references for IntensityConfig and GeometricConfig.

The `AugmentationConfig` class is well-structured, but there are undefined names for `IntensityConfig` and `GeometricConfig`. To avoid potential circular imports, use forward references for these classes. Apply this diff to use forward references:

```diff
 @attrs.define
 class AugmentationConfig:
     # ... (existing code)
     random_crop: Optional[Dict[str, Optional[float]]] = None
-    intensity: Optional[IntensityConfig] = attrs.field(default=None)
-    geometric: Optional[GeometricConfig] = attrs.field(default=None)
+    intensity: Optional['IntensityConfig'] = attrs.field(default=None)
+    geometric: Optional['GeometricConfig'] = attrs.field(default=None)
```

This change will resolve the undefined names issue and prevent potential circular imports.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (4)
sleap_nn/config/data_config.py (4)
1-3: Remove unused import.

The `List` type is imported from `typing` but not used in the file. To keep the imports clean, you can remove it. Apply this change:

```diff
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
```
12-35: Approve DataConfig class and suggest forward references.

The `DataConfig` class is well-structured with clear attribute definitions and a comprehensive docstring. However, there are undefined names for `PreprocessingConfig` and `AugmentationConfig`. To resolve the undefined names and prevent potential circular imports, use forward references for these classes. Apply this diff:

```diff
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from sleap_nn.config.data_config import PreprocessingConfig, AugmentationConfig

 @attrs.define
 class DataConfig:
     # ... (existing code)
-    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
-    augmentation_config: Optional[AugmentationConfig] = None
+    preprocessing: 'PreprocessingConfig' = attrs.field(factory=lambda: PreprocessingConfig())
+    augmentation_config: Optional['AugmentationConfig'] = None
```

This change will resolve the undefined names issue and prevent potential circular imports.
38-56: Approve PreprocessingConfig class and add missing import.

The `PreprocessingConfig` class is well-structured with appropriate attributes for preprocessing options. However, the `Union` type is used but not imported. Add the following import at the top of the file:

```diff
-from typing import Optional, Tuple, Dict
+from typing import Optional, Tuple, Dict, Union
```

This will resolve the undefined `Union` type issue.
58-70: Approve AugmentationConfig class and use forward references.

The `AugmentationConfig` class is well-structured, but there are undefined names for `IntensityConfig` and `GeometricConfig`. To resolve the undefined names and prevent potential circular imports, use forward references for these classes. Apply this diff:

```diff
 @attrs.define
 class AugmentationConfig:
     # ... (existing code)
     random_crop: Optional[Dict[str, Optional[float]]] = None
-    intensity: Optional[IntensityConfig] = attrs.field(default=None)
-    geometric: Optional[GeometricConfig] = attrs.field(default=None)
+    intensity: Optional['IntensityConfig'] = attrs.field(default=None)
+    geometric: Optional['GeometricConfig'] = attrs.field(default=None)
```

This change will resolve the undefined names issue and prevent potential circular imports.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py
3-3: `typing.List` imported but unused. Remove unused import: `typing.List` (F401)
33-33: Undefined name `PreprocessingConfig` (F821)
33-33: Undefined name `PreprocessingConfig` (F821)
35-35: Undefined name `AugmentationConfig` (F821)
54-54: Undefined name `Union` (F821)
69-69: Undefined name `IntensityConfig` (F821)
70-70: Undefined name `GeometricConfig` (F821)
🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)
**72-100: Approve `IntensityConfig` class**

The `IntensityConfig` class is well-structured with clear attribute definitions, a comprehensive docstring, and appropriate use of validators for certain attributes. This ensures data integrity and provides clear documentation for users.

The implementation looks good and follows best practices.
**102-135: Approve `GeometricConfig` class**

The `GeometricConfig` class is well-structured with clear attribute definitions, a comprehensive docstring, and appropriate default values for all attributes. This provides a good starting point for users and clear documentation of the available options.

The implementation looks good and follows best practices.
**1-135: Final approval: well-structured configuration classes with minor improvements needed**

Overall, the implementation of the configuration classes in `sleap_nn/config/data_config.py` is well done. The classes are clearly structured, with comprehensive docstrings and appropriate use of the `attrs` library. The suggested improvements mainly revolve around import statements and using forward references to prevent circular imports.

Key points:

- Remove the unused `List` import.
- Add the missing `Union` import.
- Use forward references for `PreprocessingConfig`, `AugmentationConfig`, `IntensityConfig`, and `GeometricConfig` to resolve undefined names and prevent potential circular imports.

After implementing these minor changes, the code will be more robust and maintainable.
🧰 Tools
🪛 Ruff

3-3: `typing.List` imported but unused. Remove unused import: `typing.List` (F401)
33-33: Undefined name `PreprocessingConfig` (F821)
33-33: Undefined name `PreprocessingConfig` (F821)
35-35: Undefined name `AugmentationConfig` (F821)
54-54: Undefined name `Union` (F821)
69-69: Undefined name `IntensityConfig` (F821)
70-70: Undefined name `GeometricConfig` (F821)
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/data_config.py (3)
**1-3: Approve imports with minor suggestion**

The imports look good overall. However, there's an unused import that can be removed. Remove the unused `List` import:

```diff
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
```

🧰 Tools
🪛 Ruff

3-3: `typing.List` imported but unused. Remove unused import: `typing.List` (F401)

**13-36: Approve `DataConfig` class and suggest forward references**

The `DataConfig` class is well-structured with clear documentation. However, there are undefined names for `PreprocessingConfig` and `AugmentationConfig`.

To avoid potential circular imports, use forward references for these classes. Add the following import at the top of the file and modify the class attributes:

```python
from __future__ import annotations
```

Then update the attributes:

```diff
-    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
-    augmentation_config: Optional[AugmentationConfig] = None
+    preprocessing: 'PreprocessingConfig' = attrs.field(factory=lambda: PreprocessingConfig())
+    augmentation_config: Optional['AugmentationConfig'] = None
```

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

34-34: Undefined name `PreprocessingConfig` (F821)
34-34: Undefined name `PreprocessingConfig` (F821)
36-36: Undefined name `AugmentationConfig` (F821)
**60-72: Approve `AugmentationConfig` class and suggest forward references**

The `AugmentationConfig` class is well-structured with clear documentation. However, there are undefined names for `IntensityConfig` and `GeometricConfig`.

To avoid potential circular imports, use forward references for these classes. Modify the class attributes as follows:

```diff
-    intensity: Optional[IntensityConfig] = attrs.field(default=None)
-    geometric: Optional[GeometricConfig] = attrs.field(default=None)
+    intensity: Optional['IntensityConfig'] = attrs.field(default=None)
+    geometric: Optional['GeometricConfig'] = attrs.field(default=None)
```

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

71-71: Undefined name `IntensityConfig` (F821)
72-72: Undefined name `GeometricConfig` (F821)
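To make the forward-reference recommendation above concrete, here is a dependency-free sketch of the pattern using stdlib `dataclasses` as a stand-in for `attrs` (the class names `DataCfg` and `PreprocCfg` are illustrative, not the PR's):

```python
# With postponed evaluation of annotations, a class can be referenced in
# a type annotation before it is defined, which is what breaks the
# circular-import/undefined-name problem flagged by Ruff.
from __future__ import annotations  # all annotations become lazy strings

from dataclasses import dataclass
from typing import Optional


@dataclass
class DataCfg:
    # 'PreprocCfg' is not defined yet at this point; no NameError is
    # raised because the annotation is never evaluated at class creation.
    preprocessing: Optional[PreprocCfg] = None


@dataclass
class PreprocCfg:
    is_rgb: bool = True


cfg = DataCfg(preprocessing=PreprocCfg(is_rgb=False))
```

Note that the annotation is stored as a plain string, so tools that need the real type (e.g. `typing.get_type_hints`) resolve it lazily.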
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py
3-3: `typing.List` imported but unused. Remove unused import: `typing.List` (F401)
34-34: Undefined name `PreprocessingConfig` (F821)
34-34: Undefined name `PreprocessingConfig` (F821)
36-36: Undefined name `AugmentationConfig` (F821)
55-55: Undefined name `Union` (F821)
71-71: Undefined name `IntensityConfig` (F821)
72-72: Undefined name `GeometricConfig` (F821)
🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)
**75-107: Approve `IntensityConfig` class**

The `IntensityConfig` class is well-structured with clear documentation and appropriate use of attrs validators. All attributes have sensible default values, and the validators help ensure valid input.

The implementation looks good, and no changes are necessary.
**110-143: Approve `GeometricConfig` class**

The `GeometricConfig` class is well-structured with clear and detailed documentation for each attribute. The use of Optional types and default values is appropriate, providing flexibility in configuration.

The implementation looks good, and no changes are necessary.
**1-143: Summary and final recommendation**

Overall, the `data_config.py` file is well-structured and provides a comprehensive set of configuration classes for data handling in a machine learning pipeline. The use of the `attrs` library and detailed documentation for each class and attribute is commendable.

To improve the code quality and prevent potential issues, please implement the following changes:

- Remove the unused `List` import.
- Add `from __future__ import annotations` at the top of the file.
- Use forward references for `PreprocessingConfig`, `AugmentationConfig`, `IntensityConfig`, and `GeometricConfig` in the respective class attributes.
- Add `from typing import Union` to resolve the undefined `Union` type.

After implementing these changes, the file will be more robust and less prone to circular import issues.

The overall implementation is good, and with these minor adjustments, it will be excellent.
🧰 Tools
🪛 Ruff
3-3: `typing.List` imported but unused. Remove unused import: `typing.List` (F401)
34-34: Undefined name `PreprocessingConfig` (F821)
34-34: Undefined name `PreprocessingConfig` (F821)
36-36: Undefined name `AugmentationConfig` (F821)
55-55: Undefined name `Union` (F821)
71-71: Undefined name `IntensityConfig` (F821)
72-72: Undefined name `GeometricConfig` (F821)
```python
@attrs.define
class PreprocessingConfig:
    """Configuration of Preprocessing.

    Attributes:
        is_rgb: (bool) True if the image has 3 channels (RGB image). If input has only one channel when this is set to True, then the images from single-channel is replicated along the channel axis. If input has three channels and this is set to False, then we convert the image to grayscale (single-channel) image.
        max_height: (int) Maximum height the image should be padded to. If not provided, the original image size will be retained. Default: None.
        max_width: (int) Maximum width the image should be padded to. If not provided, the original image size will be retained. Default: None.
        scale: (float or List[float]) Factor to resize the image dimensions by, specified as either a float scalar or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions are resized by the same factor.
        crop_hw: (Tuple[int]) Crop height and width of each instance (h, w) for centered-instance model. If None, this would be automatically computed based on the largest instance in the sio.Labels file.
        min_crop_size: (int) Minimum crop size to be used if crop_hw is None.
    """

    is_rgb: bool = True
    max_height: Optional[int] = None
    max_width: Optional[int] = None
    scale: Union[float, Tuple[float, float]] = 1.0
    crop_hw: Optional[Tuple[int, int]] = None
    min_crop_size: int = 32  # to help app work incase of error
```
**Approve `PreprocessingConfig` class and add missing import**

The `PreprocessingConfig` class is well-structured with appropriate attributes and documentation. However, there's a missing import for the `Union` type.

Add the following import at the top of the file:

```python
from typing import Union
```

This will resolve the undefined name issue for the `Union` type used in the `scale` attribute.

🧰 Tools
🪛 Ruff

55-55: Undefined name `Union` (F821)
…ackbone_config based on backbone_type?
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/model_config.py (1)
**1-9: Consider adding type hints to class attributes.**

While the attributes are well-described in the docstring, adding type hints would improve code clarity and enable better static type checking. Consider updating the class definition to include type information for each attribute. For example:

```python
@attr.s(auto_attribs=True)
class ModelConfig:
    backbone: dict
    heads: dict
    base_checkpoint: str
```

Replace `dict` and `str` with more specific types if applicable. This change will enhance code readability and make it easier for developers to understand the expected types of these configuration parameters.

🧰 Tools
🪛 Ruff

1-1: Undefined name `attr` (F821)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
1-1: Undefined name
attr
(F821)
🔇 Additional comments (3)
sleap_nn/config/model_config.py (3)
**1-9: LGTM: Well-structured class definition and docstring.**

The `ModelConfig` class is well-defined using the `@attr.s(auto_attribs=True)` decorator, which is appropriate for automatic attribute management. The docstring provides a clear description of the class purpose and its attributes, following the Google style guide format. This promotes good code documentation and maintainability.

🧰 Tools
🪛 Ruff

1-1: Undefined name `attr` (F821)

**1-9: Overall, good implementation of the `ModelConfig` class.**

The `ModelConfig` class aligns well with the PR objectives of migrating configuration classes and updating them to use the new attrs API. The implementation is clean and well-documented. With the suggested improvements (adding the import statement and considering type hints), this class will provide a solid foundation for model configuration in the project.

🧰 Tools
🪛 Ruff

1-1: Undefined name `attr` (F821)

**1-1: ⚠️ Potential issue — Add missing import statement for `attr`.**

The `attr` module is used in the code but not imported. To resolve the undefined name issue and make the usage of the `@attr.s` decorator valid, please add the following import statement at the beginning of the file:

```python
import attr
```

🧰 Tools
🪛 Ruff

1-1: Undefined name `attr` (F821)
…ackbone_config based on backbone_type?
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/model_config.py (1)
**23-23: Consider using consistent decorators across the file.**

The `UNetConfig` class uses `@attrs.define`, while the main `ModelConfig` class uses `@attr.s`. For consistency, consider using the same decorator style throughout the file. If you're using a newer version of `attrs`, you might want to update all decorators to `@attrs.define`. Otherwise, change this to `@attr.s(auto_attribs=True)` to match the main class.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
13-14: SyntaxError: Expected an expression
15-16: SyntaxError: Expected an expression
🔇 Additional comments (2)
sleap_nn/config/model_config.py (2)
**1-59: Overall structure looks good, with some minor improvements needed.**

The `ModelConfig` class and its nested configuration classes provide a comprehensive and flexible structure for configuring different model architectures. The use of `attrs` for class definitions is a good choice for reducing boilerplate code.

To improve the file:

- Add the missing imports for `attr` and `Enum`.
- Complete the attribute definitions for `pre_trained_weights` and `backbone_config`.
- Consider using consistent decorators across all class definitions.

These changes will enhance the code's correctness and consistency.

🧰 Tools
🪛 Ruff

13-14: SyntaxError: Expected an expression
**1-1: ⚠️ Potential issue — Add missing import for `attr` module.**

The `attr` module is used in this file but not imported. Add the following import at the beginning of the file:

```python
import attr
```

This will resolve the undefined name issue for `attr`.
sleap_nn/config/model_config.py
Outdated
```python
pre_trained_weights: str =
backbone_type: BackboneType = BackboneType.UNET
backbone_config:
```
Complete the attribute definitions for pre_trained_weights
and backbone_config
.
The pre_trained_weights
and backbone_config
attributes are incomplete. Please provide appropriate default values or type annotations for these attributes. For example:
pre_trained_weights: str = ""
backbone_config: Union[UNetConfig, ConvNextConfig, SwinTConfig] = attr.field(default=None)
Make sure to import Union
from the typing
module if you use it.
🧰 Tools
🪛 Ruff
13-14: SyntaxError: Expected an expression
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/model_config.py (3)
20-46
: LGTM! Consider refactoring error messages.The methods in the
ModelConfig
class are well-implemented and provide good initialization and validation logic.To reduce code duplication in error messages, consider creating a helper method for generating error messages:
```python
def _get_weight_error_message(self, backbone_type, valid_weights):
    return f"Invalid pre-trained weights for {backbone_type}. Must be one of {valid_weights}"

# Then use it in validate_pre_trained_weights:
if self.backbone_type == BackboneType.CONVNEXT:
    if self.pre_trained_weights not in convnext_weights:
        raise ValueError(self._get_weight_error_message("ConvNext", convnext_weights))
elif self.backbone_type == BackboneType.SWINT:
    if self.pre_trained_weights not in swint_weights:
        raise ValueError(self._get_weight_error_message("SwinT", swint_weights))
```

This refactoring will make the code more maintainable and reduce the risk of inconsistencies in error messages.
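The helper-message refactor can be exercised standalone; here is a runnable miniature in which module-level functions and a trimmed weight table (both illustrative, not the PR's actual class or full weight lists) stand in for the methods:

```python
# One helper builds every "invalid weights" message, so the per-backbone
# validation branches cannot drift out of sync with each other.
VALID_WEIGHTS = {
    "convnext": {"ConvNeXt_Tiny_Weights", "ConvNeXt_Base_Weights"},
    "swint": {"Swin_T_Weights", "Swin_S_Weights"},
}


def _get_weight_error_message(backbone_type: str, valid_weights: set) -> str:
    return (
        f"Invalid pre-trained weights for {backbone_type}. "
        f"Must be one of {sorted(valid_weights)}"
    )


def validate_pre_trained_weights(backbone_type: str, weights: str) -> None:
    allowed = VALID_WEIGHTS.get(backbone_type)
    if allowed is not None and weights not in allowed:
        raise ValueError(_get_weight_error_message(backbone_type, allowed))


validate_pre_trained_weights("convnext", "ConvNeXt_Tiny_Weights")  # passes silently
try:
    validate_pre_trained_weights("swint", "bogus")
except ValueError as err:
    error_text = str(err)
```

Centralizing the message in one function means a wording change happens in exactly one place.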
🧰 Tools
🪛 Ruff
25-25: Undefined name
BackboneType
(F821)
26-26: Undefined name
UNetConfig
(F821)
27-27: Undefined name
BackboneType
(F821)
28-28: Undefined name
ConvNextConfig
(F821)
29-29: Undefined name
BackboneType
(F821)
30-30: Undefined name
SwinTConfig
(F821)
38-38: Undefined name
BackboneType
(F821)
41-41: Undefined name
BackboneType
(F821)
44-44: Undefined name
BackboneType
(F821)
53-65
: LGTM! Consider adding validation forNone
values.The
UNetConfig
class is well-defined with appropriate attributes for UNet configuration.Consider adding validation for attributes that have
None
as default value, such asmax_stride
andstem_stride
. You could do this in a post-init method:@attrs.define class UNetConfig: # ... existing attributes ... def __attrs_post_init__(self): if self.max_stride is None: # Set a default value or raise an error if it's required self.max_stride = 16 # Example default value if self.stem_stride is None: # Set a default value or raise an error if it's required self.stem_stride = 1 # Example default valueThis ensures that these critical parameters always have valid values.
66-89
: LGTM! Consider reordering attributes for consistency.The
ConvNextConfig
andSwinTConfig
classes are well-defined with appropriate attributes for their respective architectures.For better consistency between the two classes, consider reordering the attributes in
SwinTConfig
to match the order inConvNextConfig
as closely as possible. For example:@attrs.define class SwinTConfig: model_type: str = "tiny" arch: dict = attrs.field(factory=lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]}) in_channels: int = 1 patch_size: list = attrs.field(factory=lambda: [4, 4]) stem_patch_stride: int = 2 window_size: list = attrs.field(factory=lambda: [7, 7]) kernel_size: int = 3 filters_rate: float = 1.5 convs_per_block: int = 2 up_interpolate: bool = TrueThis reordering makes it easier to compare the two configurations at a glance.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
17-17: Undefined name
BackboneType
(F821)
17-17: Undefined name
BackboneType
(F821)
18-18: Undefined name
Union
(F821)
18-18: Undefined name
UNetConfig
(F821)
18-18: Undefined name
ConvNextConfig
(F821)
18-18: Undefined name
SwinTConfig
(F821)
25-25: Undefined name
BackboneType
(F821)
26-26: Undefined name
UNetConfig
(F821)
27-27: Undefined name
BackboneType
(F821)
28-28: Undefined name
ConvNextConfig
(F821)
29-29: Undefined name
BackboneType
(F821)
30-30: Undefined name
SwinTConfig
(F821)
38-38: Undefined name
BackboneType
(F821)
41-41: Undefined name
BackboneType
(F821)
44-44: Undefined name
BackboneType
(F821)
🔇 Additional comments (3)
sleap_nn/config/model_config.py (3)
47-50
: LGTM! Well-defined enum for backbone types.The
BackboneType
enum is correctly defined and provides clear options for the backbone types. This approach enhances type safety and code readability.
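For readers unfamiliar with the pattern, a small self-contained sketch of why the enum helps (the member values mirror the review's `"unet"`/`"convnext"`/`"swint"` strings; the `parse_backbone` helper is hypothetical):

```python
from enum import Enum


class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"


def parse_backbone(name: str) -> BackboneType:
    # Enum lookup by value doubles as validation: an unknown string
    # raises ValueError instead of silently flowing downstream.
    return BackboneType(name)


backbone = parse_backbone("unet")
```

Compared with raw strings, typos like `"unte"` fail loudly at the parsing boundary rather than deep inside model construction.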
1-89
: Overall, excellent implementation with minor suggestions for improvement.The
ModelConfig
class and its associated nested classes provide a comprehensive and well-structured configuration system for model architectures. The use ofattrs
for class definitions andEnum
for backbone types demonstrates good coding practices.Key strengths:
- Clear separation of concerns for different backbone types.
- Proper use of post-initialization and validation methods.
- Effective use of default values and factory functions for complex defaults.
Suggestions for improvement:
- Add import for
Union
fromtyping
.- Refactor error message generation in
validate_pre_trained_weights
.- Add validation for
None
values inUNetConfig
.- Reorder attributes in
SwinTConfig
for consistency withConvNextConfig
.These minor improvements will enhance the overall quality and maintainability of the code.
🧰 Tools
🪛 Ruff
17-17: Undefined name
BackboneType
(F821)
17-17: Undefined name
BackboneType
(F821)
18-18: Undefined name
Union
(F821)
18-18: Undefined name
UNetConfig
(F821)
18-18: Undefined name
ConvNextConfig
(F821)
18-18: Undefined name
SwinTConfig
(F821)
25-25: Undefined name
BackboneType
(F821)
26-26: Undefined name
UNetConfig
(F821)
27-27: Undefined name
BackboneType
(F821)
28-28: Undefined name
ConvNextConfig
(F821)
29-29: Undefined name
BackboneType
(F821)
30-30: Undefined name
SwinTConfig
(F821)
38-38: Undefined name
BackboneType
(F821)
41-41: Undefined name
BackboneType
(F821)
44-44: Undefined name
BackboneType
(F821)
1-18
: LGTM! Consider adding import forUnion
.The imports and main class definition look good. The
Enum
import addresses a previous comment. However, there's a minor improvement we can make:Consider adding the following import at the beginning of the file:
from typing import UnionThis will resolve the undefined name issue for
Union
on line 18.There might be a circular import issue causing some undefined names. Let's verify this:
If this script returns results, it indicates a potential circular import that needs to be addressed.
🧰 Tools
🪛 Ruff
17-17: Undefined name
BackboneType
(F821)
17-17: Undefined name
BackboneType
(F821)
18-18: Undefined name
Union
(F821)
18-18: Undefined name
UNetConfig
(F821)
18-18: Undefined name
ConvNextConfig
(F821)
18-18: Undefined name
SwinTConfig
(F821)
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/model_config.py (2)
55-67
: Consider adding type hints to UNetConfig attributes.To improve code clarity and maintainability, consider adding type hints to the attributes in the
UNetConfig
class. For example:@attrs.define class UNetConfig: in_channels: int = 1 kernel_size: int = 3 filters: int = 32 filters_rate: float = 1.5 max_stride: Optional[int] = None stem_stride: Optional[int] = None middle_block: bool = True up_interpolate: bool = True stacks: int = 3 convs_per_block: int = 2Don't forget to import
Optional
fromtyping
if you use it formax_stride
andstem_stride
.
68-91
: Enhance ConvNextConfig and SwinTConfig with type hints and consider usingattrs.Factory
.
- Add type hints to improve code clarity:
@attrs.define class ConvNextConfig: model_type: str = "tiny" arch: Dict[str, Union[List[int], List[int]]] = attrs.field(factory=lambda: {'depths': [3, 3, 9, 3], 'channels': [96, 192, 384, 768]}) stem_patch_kernel: int = 4 stem_patch_stride: int = 2 in_channels: int = 1 kernel_size: int = 3 filters_rate: float = 1.5 convs_per_block: int = 2 up_interpolate: bool = True @attrs.define class SwinTConfig: model_type: str = "tiny" arch: Dict[str, Union[int, List[int]]] = attrs.field(factory=lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]}) patch_size: List[int] = attrs.field(factory=lambda: [4, 4]) stem_patch_stride: int = 2 window_size: List[int] = attrs.field(factory=lambda: [7, 7]) in_channels: int = 1 kernel_size: int = 3 filters_rate: float = 1.5 convs_per_block: int = 2 up_interpolate: bool = True
- Consider using
attrs.Factory
instead of lambda functions for better readability:from attrs import Factory # In ConvNextConfig arch: Dict[str, Union[List[int], List[int]]] = attrs.field(factory=Factory(lambda: {'depths': [3, 3, 9, 3], 'channels': [96, 192, 384, 768]})) # In SwinTConfig arch: Dict[str, Union[int, List[int]]] = attrs.field(factory=Factory(lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]})) patch_size: List[int] = attrs.field(factory=Factory(lambda: [4, 4])) window_size: List[int] = attrs.field(factory=Factory(lambda: [7, 7]))These changes will improve type checking and make the code more explicit about the expected types for each attribute.
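The reason a factory matters for these defaults can be shown with stdlib `dataclasses` (illustrative class; `attrs.field(factory=...)` serves the same purpose):

```python
from dataclasses import dataclass, field


@dataclass
class SwinSketch:
    # A bare mutable default such as `patch_size: list = [4, 4]` is
    # rejected outright by dataclasses (and can end up shared between
    # instances in libraries that allow it); default_factory builds a
    # fresh list for every instance instead.
    patch_size: list = field(default_factory=lambda: [4, 4])


first = SwinSketch()
second = SwinSketch()
first.patch_size.append(8)  # mutating one instance leaves the other intact
```

Each instance owns its own list, so per-instance tweaks to `patch_size` cannot leak into other configurations.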
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
17-17: Undefined name
BackboneType
(F821)
17-17: Undefined name
BackboneType
(F821)
18-18: Undefined name
Union
(F821)
18-18: Undefined name
UNetConfig
(F821)
18-18: Undefined name
ConvNextConfig
(F821)
18-18: Undefined name
SwinTConfig
(F821)
27-27: Undefined name
BackboneType
(F821)
28-28: Undefined name
UNetConfig
(F821)
29-29: Undefined name
BackboneType
(F821)
30-30: Undefined name
ConvNextConfig
(F821)
31-31: Undefined name
BackboneType
(F821)
32-32: Undefined name
SwinTConfig
(F821)
40-40: Undefined name
BackboneType
(F821)
43-43: Undefined name
BackboneType
(F821)
46-46: Undefined name
BackboneType
(F821)
🔇 Additional comments (1)
sleap_nn/config/model_config.py (1)
1-91
: Overall assessment: Good implementation with room for minor improvements.The
ModelConfig
class and its nested configuration classes provide a well-structured and comprehensive approach to managing model architecture configurations. The use ofattrs
for class definitions is a good choice, making the code more concise and less error-prone.Key strengths:
- Clear separation of concerns for different backbone types.
- Thorough validation of pre-trained weights.
- Use of enums for backbone types, improving type safety.
Suggested improvements:
- Resolve import and undefined name issues.
- Add type hints throughout the file for better code clarity and maintainability.
- Minor refactoring of the
validate_pre_trained_weights
method to reduce code duplication.- Consider using
attrs.Factory
for mutable default values.These changes will enhance the overall quality of the code, making it more robust and easier to maintain in the future.
🧰 Tools
🪛 Ruff
17-17: Undefined name
BackboneType
(F821)
17-17: Undefined name
BackboneType
(F821)
18-18: Undefined name
Union
(F821)
18-18: Undefined name
UNetConfig
(F821)
18-18: Undefined name
ConvNextConfig
(F821)
18-18: Undefined name
SwinTConfig
(F821)
27-27: Undefined name
BackboneType
(F821)
28-28: Undefined name
UNetConfig
(F821)
29-29: Undefined name
BackboneType
(F821)
30-30: Undefined name
ConvNextConfig
(F821)
31-31: Undefined name
BackboneType
(F821)
32-32: Undefined name
SwinTConfig
(F821)
40-40: Undefined name
BackboneType
(F821)
43-43: Undefined name
BackboneType
(F821)
46-46: Undefined name
BackboneType
(F821)
sleap_nn/config/model_config.py
Outdated
```python
import attrs
from enum import Enum


@attrs.define
class ModelConfig:
    """Configurations related to model architecture.

    Attributes:
        init_weight: (str) model weights initialization method. "default" uses kaiming uniform initialization and "xavier" uses Xavier initialization method.
        pre_trained_weights: (str) Pretrained weights file name supported only for ConvNext and SwinT backbones. For ConvNext, one of ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]. For SwinT, one of ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"].
        backbone_type: (str) Backbone architecture for the model to be trained. One of "unet", "convnext" or "swint".
    """

    init_weight: str = "default"
    pre_trained_weights: str = None
    backbone_type: BackboneType = BackboneType.UNET
    backbone_config: Union[UNetConfig, ConvNextConfig, SwinTConfig] = attrs.field(init=False)  # backbone_config can be any of these 3 configurations. init=False lets you set the parameters later (not in initialization)
```
**Add missing imports and address undefined names.**

To resolve the undefined names and improve the overall structure of the file, please make the following changes:

1. Add the missing import for `Union`:

```python
from typing import Union
```

2. Move the `BackboneType` enum definition to the top of the file, just after the imports.

3. Add forward references for the nested configuration classes:

```python
UNetConfig = "ModelConfig.UNetConfig"
ConvNextConfig = "ModelConfig.ConvNextConfig"
SwinTConfig = "ModelConfig.SwinTConfig"
```

These changes will resolve the undefined name issues and improve the overall structure of the file.
🧰 Tools
🪛 Ruff
17-17: Undefined name
BackboneType
(F821)
17-17: Undefined name
BackboneType
(F821)
18-18: Undefined name
Union
(F821)
18-18: Undefined name
UNetConfig
(F821)
18-18: Undefined name
ConvNextConfig
(F821)
18-18: Undefined name
SwinTConfig
(F821)
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/model_config.py (3)
1-2
: Add missing imports for type hints.To improve type checking and code clarity, please add the following imports at the beginning of the file:
from typing import Dict, List, Optional, UnionThese imports are necessary for the type hints used throughout the file.
56-92
: Enhance documentation and type hinting for nested configuration classes.For the
UNetConfig
,ConvNextConfig
, andSwinTConfig
classes:
- Add docstrings to explain the purpose of each class and its attributes.
- Consider adding type hints to all attributes for better code clarity and type checking.
For example, for the
UNetConfig
class:@attrs.define class UNetConfig: """Configuration for UNet backbone. Attributes: in_channels (int): Number of input channels. kernel_size (int): Size of the convolutional kernel. filters (int): Number of filters in the first layer. filters_rate (float): Rate at which the number of filters increases. max_stride (Optional[int]): Maximum stride in the network. stem_stride (Optional[int]): Stride in the stem of the network. middle_block (bool): Whether to include a middle block. up_interpolate (bool): Whether to use interpolation for upsampling. stacks (int): Number of encoder/decoder stacks. convs_per_block (int): Number of convolutions per block. """ in_channels: int = 1 kernel_size: int = 3 filters: int = 32 filters_rate: float = 1.5 max_stride: Optional[int] = None stem_stride: Optional[int] = None middle_block: bool = True up_interpolate: bool = True stacks: int = 3 convs_per_block: int = 2Apply similar improvements to
ConvNextConfig
andSwinTConfig
classes.
94-120
: Improve consistency and documentation for additional configuration classes.For the
HeadConfig
,SingleInstanceConfig
, andConfMapsConfig
classes:
- Add docstrings to
HeadConfig
andSingleInstanceConfig
explaining their purpose and attributes, similar toConfMapsConfig
.- Consider using a consistent style for optional attributes. For example, in
ConfMapsConfig
, you could useattrs.field(default=None)
instead ofOptional[Type] = None
for consistency with other classes.Example for
HeadConfig
:@attrs.define class HeadConfig: """Configuration for model heads. Attributes: head_configs (Dict[str, Optional[Dict]]): A dictionary of head configurations. Keys represent head types, and values are their respective configurations. """ head_configs: Dict[str, Optional[Dict]] = attrs.field( factory=lambda: { "single_instance": None, "centroid": None, "centered_instance": None, "bottomup": None } )Apply similar improvements to
SingleInstanceConfig
.🧰 Tools
🪛 Ruff
96-96: Undefined name
Dict
(F821)
96-96: Undefined name
Optional
(F821)
96-96: Undefined name
Dict
(F821)
107-107: Undefined name
Optional
(F821)
107-107: Undefined name
ConfMapsConfig
(F821)
118-118: Undefined name
Optional
(F821)
118-118: Undefined name
List
(F821)
119-119: Undefined name
Optional
(F821)
120-120: Undefined name
Optional
(F821)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
17-17: Undefined name `BackboneType` (F821)
17-17: Undefined name `BackboneType` (F821)
18-18: Undefined name `Union` (F821)
18-18: Undefined name `UNetConfig` (F821)
18-18: Undefined name `ConvNextConfig` (F821)
18-18: Undefined name `SwinTConfig` (F821)
19-19: Undefined name `HeadConfig` (F821)
19-19: Undefined name `HeadConfig` (F821)
28-28: Undefined name `BackboneType` (F821)
29-29: Undefined name `UNetConfig` (F821)
30-30: Undefined name `BackboneType` (F821)
31-31: Undefined name `ConvNextConfig` (F821)
32-32: Undefined name `BackboneType` (F821)
33-33: Undefined name `SwinTConfig` (F821)
41-41: Undefined name `BackboneType` (F821)
44-44: Undefined name `BackboneType` (F821)
47-47: Undefined name `BackboneType` (F821)
96-96: Undefined name `Dict` (F821)
96-96: Undefined name `Optional` (F821)
96-96: Undefined name `Dict` (F821)
107-107: Undefined name `Optional` (F821)
107-107: Undefined name `ConfMapsConfig` (F821)
118-118: Undefined name `Optional` (F821)
118-118: Undefined name `List` (F821)
119-119: Undefined name `Optional` (F821)
120-120: Undefined name `Optional` (F821)
Actionable comments posted: 5
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
17-17: Undefined name `BackboneType` (F821)
17-17: Undefined name `BackboneType` (F821)
18-18: Undefined name `Union` (F821)
18-18: Undefined name `UNetConfig` (F821)
18-18: Undefined name `ConvNextConfig` (F821)
18-18: Undefined name `SwinTConfig` (F821)
19-19: Undefined name `HeadConfig` (F821)
19-19: Undefined name `HeadConfig` (F821)
28-28: Undefined name `BackboneType` (F821)
29-29: Undefined name `UNetConfig` (F821)
30-30: Undefined name `BackboneType` (F821)
31-31: Undefined name `ConvNextConfig` (F821)
32-32: Undefined name `BackboneType` (F821)
33-33: Undefined name `SwinTConfig` (F821)
41-41: Undefined name `BackboneType` (F821)
44-44: Undefined name `BackboneType` (F821)
47-47: Undefined name `BackboneType` (F821)
96-96: Undefined name `Dict` (F821)
96-96: Undefined name `Optional` (F821)
96-96: Undefined name `Dict` (F821)
108-108: Undefined name `Optional` (F821)
108-108: Undefined name `SingleInstanceConfMapsConfig` (F821)
119-119: Undefined name `Optional` (F821)
119-119: Undefined name `List` (F821)
120-120: Undefined name `Optional` (F821)
121-121: Undefined name `Optional` (F821)
126-126: Undefined name `Optional` (F821)
126-126: Undefined name `CentroidConfMapsConfig` (F821)
137-137: Undefined name `Optional` (F821)
138-138: Undefined name `Optional` (F821)
139-139: Undefined name `Optional` (F821)
144-144: Undefined name `Optional` (F821)
144-144: Undefined name `CenteredInstanceConfMapsConfig` (F821)
156-156: Undefined name `Optional` (F821)
156-156: Undefined name `List` (F821)
157-157: Undefined name `Optional` (F821)
158-158: Undefined name `Optional` (F821)
159-159: Undefined name `Optional` (F821)
🔇 Additional comments (2)
sleap_nn/config/model_config.py (2)

50-92: LGTM: Nested configuration classes.

The nested configuration classes (`BackboneType`, `UNetConfig`, `ConvNextConfig`, and `SwinTConfig`) are well-structured and use appropriate `attrs` decorators. The use of `attrs.field(factory=lambda: ...)` for default dictionaries is a good practice to avoid mutable default arguments.
1-160: Overall assessment: Well-structured configuration system with minor improvements needed.

The `model_config.py` file introduces a comprehensive and well-structured configuration system for a machine learning model. It effectively uses the `attrs` library for class definitions and provides clear docstrings for attributes. The nested configuration classes for different backbone types and head configurations are well-organized.

To further improve the code:

- Add the missing imports for type hints.
- Update the `backbone_type` attribute in `ModelConfig` to use `attrs.field`.
- Consider simplifying the `set_backbone_config` and `validate_pre_trained_weights` methods as suggested.
- Fix the naming inconsistency in the `centroid` class.

These changes will enhance the code's clarity, maintainability, and consistency.
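One way the `validate_pre_trained_weights` simplification could look, as a standalone sketch: the allowed weight names are taken from the PR's docstring for `pre_trained_weights`, but the function shape itself is an assumption, not the PR's code.

```python
# Allowed names per backbone, from the PR docstring for `pre_trained_weights`.
VALID_WEIGHTS = {
    "convnext": {"ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights",
                 "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"},
    "swint": {"Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"},
}

def validate_pre_trained_weights(backbone_type: str, pre_trained_weights):
    """Raise ValueError if the weights name is not valid for the backbone."""
    if pre_trained_weights is None:
        return  # nothing to validate
    allowed = VALID_WEIGHTS.get(backbone_type, set())
    if pre_trained_weights not in allowed:
        raise ValueError(
            f"Invalid pre_trained_weights {pre_trained_weights!r} "
            f"for backbone {backbone_type!r}"
        )
```

A table lookup like this keeps the valid-name lists in one place instead of spreading them across `if`/`elif` branches.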
sleap_nn/config/model_config.py
Outdated
```python
def __attrs_post_init__(self):
    self.backbone_config = self.set_backbone_config()
    self.validate_pre_trained_weights()

# configures back_bone config to one of these types
def set_backbone_config(self):
    if self.backbone_type == BackboneType.UNET:
        return UNetConfig()
    elif self.backbone_type == BackboneType.CONVNEXT:
        return ConvNextConfig()
    elif self.backbone_type == BackboneType.SWINT:
        return SwinTConfig()
    else:
        raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
```
🛠️ Refactor suggestion
Simplify `set_backbone_config` method.

Consider simplifying the `set_backbone_config` method using a dictionary mapping:

```python
def set_backbone_config(self):
    config_map = {
        BackboneType.UNET: UNetConfig,
        BackboneType.CONVNEXT: ConvNextConfig,
        BackboneType.SWINT: SwinTConfig,
    }
    config_class = config_map.get(self.backbone_type)
    if config_class is None:
        raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
    return config_class()
```

This approach reduces repetitive code and makes it easier to add new backbone types in the future.
sleap_nn/config/model_config.py
Outdated
```python
@attrs.define
class HeadConfig:
    head_configs: Dict[str, Optional[Dict]] = attrs.field(
        factory = lambda:{
            "single_instance": None,
            "centroid": None,
            "centered_instance": None,
            "bottomup": None
        }
    )

# Head_config single instance
@attrs.define
class SingleInstanceConfig:
    confmaps: Optional[SingleInstanceConfMapsConfig] = None

@attrs.define
class SingleInstanceConfMapsConfig:
    '''

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    '''
    part_names: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None

# Head_config centroid
@attrs.define
class centroid:
    confmaps: Optional[CentroidConfMapsConfig] = None

@attrs.define
class CentroidConfMapsConfig:
    '''

    Attributes:
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    '''
    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None

# Head_config centered_instance
@attrs.define
class centered_instance:
    confmaps: Optional[CenteredInstanceConfMapsConfig] = None

@attrs.define
class CenteredInstanceConfMapsConfig:
    '''

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    '''
    part_names: Optional[List[str]] = None
    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None
```
LGTM: HeadConfig and related classes.

The `HeadConfig` and related configuration classes (`SingleInstanceConfig`, `SingleInstanceConfMapsConfig`, `CentroidConfMapsConfig`, `CenteredInstanceConfMapsConfig`) are well-structured and use appropriate `attrs` decorators. The docstrings provide clear explanations of the attributes.

However, there's a naming inconsistency in the `centroid` class. Please update the class name to follow PascalCase convention:

```python
@attrs.define
class Centroid:
    confmaps: Optional[CentroidConfMapsConfig] = None
```

This change will maintain consistency with the naming of other classes in the file.
Actionable comments posted: 5
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
17-17: Undefined name `BackboneType` (F821)
17-17: Undefined name `BackboneType` (F821)
18-18: Undefined name `Union` (F821)
18-18: Undefined name `UNetConfig` (F821)
18-18: Undefined name `ConvNextConfig` (F821)
18-18: Undefined name `SwinTConfig` (F821)
19-19: Undefined name `HeadConfig` (F821)
19-19: Undefined name `HeadConfig` (F821)
28-28: Undefined name `BackboneType` (F821)
29-29: Undefined name `UNetConfig` (F821)
30-30: Undefined name `BackboneType` (F821)
31-31: Undefined name `ConvNextConfig` (F821)
32-32: Undefined name `BackboneType` (F821)
33-33: Undefined name `SwinTConfig` (F821)
41-41: Undefined name `BackboneType` (F821)
44-44: Undefined name `BackboneType` (F821)
47-47: Undefined name `BackboneType` (F821)
96-96: Undefined name `Dict` (F821)
96-96: Undefined name `Optional` (F821)
96-96: Undefined name `Dict` (F821)
108-108: Undefined name `Optional` (F821)
108-108: Undefined name `SingleInstanceConfMapsConfig` (F821)
119-119: Undefined name `Optional` (F821)
119-119: Undefined name `List` (F821)
120-120: Undefined name `Optional` (F821)
121-121: Undefined name `Optional` (F821)
126-126: Undefined name `Optional` (F821)
126-126: Undefined name `CentroidConfMapsConfig` (F821)
137-137: Undefined name `Optional` (F821)
138-138: Undefined name `Optional` (F821)
139-139: Undefined name `Optional` (F821)
144-144: Undefined name `Optional` (F821)
144-144: Undefined name `CenteredInstanceConfMapsConfig` (F821)
156-156: Undefined name `Optional` (F821)
156-156: Undefined name `List` (F821)
157-157: Undefined name `Optional` (F821)
158-158: Undefined name `Optional` (F821)
159-159: Undefined name `Optional` (F821)
164-164: Undefined name `Optional` (F821)
164-164: Undefined name `BottomUpConfMapsConfig` (F821)
165-165: Undefined name `Optional` (F821)
165-165: Undefined name `PAFConfig` (F821)
177-177: Undefined name `Optional` (F821)
177-177: Undefined name `List` (F821)
178-178: Undefined name `Optional` (F821)
179-179: Undefined name `Optional` (F821)
180-180: Undefined name `Optional` (F821)
192-192: Undefined name `Optional` (F821)
192-192: Undefined name `List` (F821)
193-193: Undefined name `Optional` (F821)
194-194: Undefined name `Optional` (F821)
195-195: Undefined name `Optional` (F821)
🔇 Additional comments (1)
sleap_nn/config/model_config.py (1)

1-196: Overall, well-structured configuration classes with room for improvement.

The file contains a comprehensive set of configuration classes for various aspects of a machine learning model. The use of `attrs` for class definitions and the detailed docstrings are commendable. However, there are a few areas for improvement:

- Add missing imports from the `typing` module to resolve undefined names.
- Consider moving nested classes (BackboneType, UNetConfig, ConvNextConfig, SwinTConfig) to the module level for better modularity.
- Fix the naming inconsistency in the `centroid` class.
- Simplify the `validate_pre_trained_weights` method in the ModelConfig class.

Addressing these points will enhance the overall quality and maintainability of the code.
```python
import attrs
from enum import Enum
```
Add missing imports from `typing` module.

To resolve the undefined names flagged by the static analysis tool, please add the following imports at the beginning of the file:

```python
from typing import Union, Optional, List, Dict
```

This will address many of the undefined name issues throughout the file.
sleap_nn/config/model_config.py
Outdated
```python
@attrs.define
class ModelConfig:
    """Configurations related to model architecture.

    Attributes:
        init_weight: (str) model weights initialization method. "default" uses kaiming uniform initialization and "xavier" uses Xavier initialization method.
        pre_trained_weights: (str) Pretrained weights file name supported only for ConvNext and SwinT backbones. For ConvNext, one of ["ConvNeXt_Base_Weights","ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]. For SwinT, one of ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"].
        backbone_type: (str) Backbone architecture for the model to be trained. One of "unet", "convnext" or "swint".

    """

    init_weight: str = "default"
    pre_trained_weights: str = None
    backbone_type: BackboneType = BackboneType.UNET
    backbone_config: Union[UNetConfig, ConvNextConfig, SwinTConfig] = attrs.field(init=False)  # backbone_config can be any of these 3 configurations. init=False lets you set the parameters later (not in initialization)
    head_configs: HeadConfig = attrs.field(factory=HeadConfig)
```
🛠️ Refactor suggestion
Consider moving `BackboneType` enum outside the `ModelConfig` class.

The `BackboneType` enum is currently defined inside the `ModelConfig` class. It's generally more common and easier to use if enums are defined at the module level. Consider moving it outside and above the `ModelConfig` class definition.
sleap_nn/config/model_config.py
Outdated
```python
@attrs.define
class UNetConfig:
    in_channels: int = 1
    kernel_size: int = 3
    filters: int = 32
    filters_rate: float = 1.5
    max_stride: int = None
    stem_stride: int = None
    middle_block: bool = True
    up_interpolate: bool = True
    stacks: int = 3
    convs_per_block: int = 2

@attrs.define
class ConvNextConfig:
    model_type: str = "tiny"  # Options: tiny, small, base, large
    arch: dict = attrs.field(factory=lambda: {'depths': [3, 3, 9, 3], 'channels': [96, 192, 384, 768]})
    stem_patch_kernel: int = 4
    stem_patch_stride: int = 2
    in_channels: int = 1
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True

@attrs.define
class SwinTConfig:
    model_type: str = "tiny"  # Options: tiny, small, base
    arch: dict = attrs.field(factory=lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]})
    patch_size: list = attrs.field(factory=lambda: [4, 4])
    stem_patch_stride: int = 2
    window_size: list = attrs.field(factory=lambda: [7, 7])
    in_channels: int = 1
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True
```
🛠️ Refactor suggestion
Consider moving backbone config classes outside `ModelConfig`.

The `UNetConfig`, `ConvNextConfig`, and `SwinTConfig` classes are currently nested within the `ModelConfig` class. This structure might make it harder to use these configs independently or to import them in other parts of the codebase. Consider moving these classes to the module level, just after the `ModelConfig` class definition. This would improve modularity and make the configs more accessible.

Example:

```python
@attrs.define
class ModelConfig:
    # ... (existing ModelConfig code) ...

@attrs.define
class UNetConfig:
    # ... (existing UNetConfig code) ...

@attrs.define
class ConvNextConfig:
    # ... (existing ConvNextConfig code) ...

@attrs.define
class SwinTConfig:
    # ... (existing SwinTConfig code) ...
```

This change would make the code structure more flat and easier to navigate.
sleap_nn/config/model_config.py
Outdated
```python
@attrs.define
class centroid:
    confmaps: Optional[CentroidConfMapsConfig] = None
```
Fix naming inconsistency in `centroid` class.

The `centroid` class is using lowercase naming, which is inconsistent with the PascalCase naming convention used for other classes in this file. Please update the class name to follow the PascalCase convention:

```python
@attrs.define
class Centroid:
    confmaps: Optional[CentroidConfMapsConfig] = None
```

This change will maintain consistency with the naming of other classes in the file.
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py
1-1: `omegaconf.OmegaConf` imported but unused. Remove unused import: `omegaconf.OmegaConf` (F401)
30-30: Undefined name `attrs` (F821)
sleap_nn/config/trainer_config.py
Outdated
```python
class TrainerConfig:
    """Configuration of Trainer.

    Attributes:
        train_data_loader: (Note: Any parameters from Torch's DataLoader could be used.)
        val_data_loader: (Similar to train_data_loader)
        model_ckpt: (Note: Any parameters from Lightning's ModelCheckpoint could be used.)
        trainer_devices: (int) Number of devices to train on (int), which devices to train on (list or str), or "auto" to select automatically.
        trainer_accelerator: (str) One of the ("cpu", "gpu", "tpu", "ipu", "auto"). "auto" recognises the machine the model is running on and chooses the appropriate accelerator for the Trainer to be connected to.
        enable_progress_bar: (bool) When True, enables printing the logs during training.
        steps_per_epoch: (int) Minimum number of iterations in a single epoch. (Useful if model is trained with very few data points). Refer limit_train_batches parameter of Torch Trainer. If None, the number of iterations depends on the number of samples in the train dataset.
        max_epochs: (int) Maxinum number of epochs to run.
        seed: (int) Seed value for the current experiment.
        use_wandb: (bool) True to enable wandb logging.
        save_ckpt: (bool) True to enable checkpointing.
        save_ckpt_path: (str) Directory path to save the training config and checkpoint files. Default: "./"
        resume_ckpt_path: (str) Path to .ckpt file from which training is resumed. Default: None.
        wandb: (Only if use_wandb is True, else skip this)
    """
```
Define the attributes within the `TrainerConfig` class.

The `TrainerConfig` class currently lacks attribute definitions. You need to define the attributes using `attrs` fields to hold the configuration parameters.

Add the attribute definitions and necessary imports:

```diff
+from typing import Any, Union, List, Optional

 class TrainerConfig:
     """Configuration of Trainer.

     Attributes:
         train_data_loader: (Note: Any parameters from Torch's DataLoader could be used.)
         val_data_loader: (Similar to train_data_loader)
         model_ckpt: (Note: Any parameters from Lightning's ModelCheckpoint could be used.)
         trainer_devices: (int) Number of devices to train on (int), which devices to train on (list or str), or "auto" to select automatically.
         trainer_accelerator: (str) One of the ("cpu", "gpu", "tpu", "ipu", "auto"). "auto" recognises the machine the model is running on and chooses the appropriate accelerator for the Trainer to be connected to.
         enable_progress_bar: (bool) When True, enables printing the logs during training.
         steps_per_epoch: (int) Minimum number of iterations in a single epoch. (Useful if model is trained with very few data points). Refer limit_train_batches parameter of Torch Trainer. If None, the number of iterations depends on the number of samples in the train dataset.
         max_epochs: (int) Maxinum number of epochs to run.
         seed: (int) Seed value for the current experiment.
         use_wandb: (bool) True to enable wandb logging.
         save_ckpt: (bool) True to enable checkpointing.
         save_ckpt_path: (str) Directory path to save the training config and checkpoint files. Default: "./"
         resume_ckpt_path: (str) Path to .ckpt file from which training is resumed. Default: None.
         wandb: (Only if use_wandb is True, else skip this)
     """
+    train_data_loader: Any = attrs.field()
+    val_data_loader: Any = attrs.field()
+    model_ckpt: Any = attrs.field()
+    trainer_devices: Union[int, List[int], str] = attrs.field(default="auto")
+    trainer_accelerator: str = attrs.field(default="auto")
+    enable_progress_bar: bool = attrs.field(default=True)
+    steps_per_epoch: Optional[int] = attrs.field(default=None)
+    max_epochs: int = attrs.field(default=100)
+    seed: int = attrs.field(default=42)
+    use_wandb: bool = attrs.field(default=False)
+    save_ckpt: bool = attrs.field(default=True)
+    save_ckpt_path: str = attrs.field(default="./")
+    resume_ckpt_path: Optional[str] = attrs.field(default=None)
+    wandb: Any = attrs.field(default=None)
```

Committable suggestion was skipped due to low confidence.
Actionable comments posted: 4
🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/trainer_config.py (1)
55-57: Use consistent quote style for docstrings.

For consistency with the rest of the file, use double quotes for multi-line docstrings instead of single quotes.

Replace the single quotes with double quotes:

```python
"""
train and val data_loader: (Note: Any parameters from Torch's DataLoader could be used.)

Attributes:
    ...
"""
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py
1-1: `omegaconf.OmegaConf` imported but unused. Remove unused import: `omegaconf.OmegaConf` (F401)
30-30: Undefined name `attrs` (F821)
50-50: Undefined name `TrainDataLoader` (F821)
50-50: Undefined name `attrs` (F821)
50-50: Undefined name `DataLoader` (F821)
51-51: Undefined name `ValDataLoader` (F821)
51-51: Undefined name `attrs` (F821)
51-51: Undefined name `DataLoader` (F821)
53-53: Undefined name `attrs` (F821)
🔇 Additional comments (3)
sleap_nn/config/trainer_config.py (3)
3-27
: LGTM: Well-documented module purpose.The module-level docstring provides a clear and comprehensive explanation of the purpose and design of the configuration classes. It effectively communicates the intended use and benefits of this approach.
32-49
: LGTM: Well-documentedTrainerConfig
class.The docstring for the
TrainerConfig
class provides clear and detailed explanations for each attribute, which is excellent for maintainability and usability.
1-66: Summary: Good implementation with minor improvements needed. Overall, the implementation of the TrainerConfig and DataLoader classes provides a solid foundation for managing training job parameters. The code is well-documented and structured. To improve it further:
- Address the unused import and missing attrs import.
- Ensure consistency in attribute definitions using attrs.field.
- Define or import the TrainDataLoader and ValDataLoader types.
- Use consistent quote style for docstrings.
These changes will enhance the code's maintainability and type safety.
🧰 Tools
🪛 Ruff
1-1: omegaconf.OmegaConf imported but unused. Remove unused import: omegaconf.OmegaConf (F401)
30-30: Undefined name
attrs
(F821)
50-50: Undefined name
TrainDataLoader
(F821)
50-50: Undefined name
attrs
(F821)
50-50: Undefined name
DataLoader
(F821)
51-51: Undefined name
ValDataLoader
(F821)
51-51: Undefined name
attrs
(F821)
51-51: Undefined name
DataLoader
(F821)
53-53: Undefined name
attrs
(F821)
""" | ||
|
||
|
||
@attrs.define |
Add missing import for attrs.
The attrs module is used but not imported. Add this import at the top of the file:
import attrs
🧰 Tools
🪛 Ruff
30-30: Undefined name
attrs
(F821)
sleap_nn/config/trainer_config.py
Outdated
train_data_loader: TrainDataLoader = attrs.field(factory=DataLoader)
val_data_loader: ValDataLoader = attrs.field(factory=DataLoader)
Define or import the TrainDataLoader and ValDataLoader types.
The TrainDataLoader and ValDataLoader types are used but not defined or imported. Ensure these types are properly defined or imported to maintain type consistency.
If these types are defined elsewhere in the project, import them at the beginning of the file:
from .data_config import TrainDataLoader, ValDataLoader # Adjust the import path as needed
If they are aliases for the DataLoader class, consider using type aliases:
from typing import TypeAlias
TrainDataLoader: TypeAlias = DataLoader
ValDataLoader: TypeAlias = DataLoader
🧰 Tools
🪛 Ruff
50-50: Undefined name
TrainDataLoader
(F821)
50-50: Undefined name
attrs
(F821)
50-50: Undefined name
DataLoader
(F821)
51-51: Undefined name
ValDataLoader
(F821)
51-51: Undefined name
attrs
(F821)
51-51: Undefined name
DataLoader
(F821)
sleap_nn/config/trainer_config.py
Outdated
batch_size: int = 1
shuffle: bool=False
num_workers: int=0
🛠️ Refactor suggestion
Use attrs.field for attribute definitions in the DataLoader class.
For consistency with the TrainerConfig class, use attrs.field to define the attributes in the DataLoader class.
Apply this diff to use attrs.field:
- batch_size: int = 1
- shuffle: bool=False
- num_workers: int=0
+ batch_size: int = attrs.field(default=1)
+ shuffle: bool = attrs.field(default=False)
+ num_workers: int = attrs.field(default=0)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
batch_size: int = attrs.field(default=1)
shuffle: bool = attrs.field(default=False)
num_workers: int = attrs.field(default=0)
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py
63-64: SyntaxError: Expected an expression
🔇 Additional comments (1)
sleap_nn/config/trainer_config.py (1)
75-77: 🛠️ Refactor suggestion: Use attrs.field for attribute definitions in DataLoaderConfig.
For consistency with other classes and to leverage attrs features, use attrs.field to define the attributes in the DataLoaderConfig class. Apply this diff to use attrs.field:
- batch_size: int = 1
- shuffle: bool=False
- num_workers: int=0
+ batch_size: int = attrs.field(default=1)
+ shuffle: bool = attrs.field(default=False)
+ num_workers: int = attrs.field(default=0)
Likely invalid or redundant comment.
@@ -0,0 +1,88 @@
from omegaconf import OmegaConf
Add missing imports and remove unused import.
- Remove the unused import of OmegaConf.
- Add the necessary imports for attrs and the typing module.
Apply this diff to fix the imports:
-from omegaconf import OmegaConf
+import attrs
+from typing import Union, List, Optional
Also applies to: 30-30
sleap_nn/config/trainer_config.py
Outdated
wandb:
Complete the 'wandb' attribute definition.
The 'wandb' attribute is incomplete. It should be properly defined with a type and default value.
Consider updating the 'wandb' attribute as follows:
wandb: Optional[Any] = attrs.field(default=None)
Also, add from typing import Any to the import statement if not already present.
🧰 Tools
🪛 Ruff
63-64: SyntaxError: Expected an expression
sleap_nn/config/trainer_config.py
Outdated
save_top_k: int = 1
save_last: Optional[bool]=None
🛠️ Refactor suggestion
Use attrs.field for attribute definitions in ModelCkptConfig.
For consistency with other classes and to leverage attrs features, use attrs.field to define the attributes in the ModelCkptConfig class.
Apply this diff to use attrs.field:
- save_top_k: int = 1
- save_last: Optional[bool]=None
+ save_top_k: int = attrs.field(default=1)
+ save_last: Optional[bool] = attrs.field(default=None)
📝 Committable suggestion
save_top_k: int = attrs.field(default=1)
save_last: Optional[bool] = attrs.field(default=None)
Actionable comments posted: 5
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py
2-2: omegaconf.OmegaConf imported but unused. Remove unused import: omegaconf.OmegaConf (F401)
55-55: Undefined name
DataLoaderConfig
(F821)
55-55: Undefined name
DataLoaderConfig
(F821)
56-56: Undefined name
DataLoaderConfig
(F821)
56-56: Undefined name
DataLoaderConfig
(F821)
57-57: Undefined name
ModelCkptConfig
(F821)
57-57: Undefined name
ModelCkptConfig
(F821)
58-58: Undefined name
Union
(F821)
58-58: Undefined name
List
(F821)
61-61: Undefined name
Optional
(F821)
63-63: Undefined name
Optional
(F821)
67-67: Undefined name
Optional
(F821)
68-68: Undefined name
Optional
(F821)
68-68: Undefined name
WandBConfig
(F821)
69-69: Undefined name
Optional
(F821)
69-69: Undefined name
OptimizerConfig
(F821)
69-69: Undefined name
OptimizerConfig
(F821)
70-70: Undefined name
LRSchedulerConfig
(F821)
70-70: Undefined name
LRSchedulerConfig
(F821)
71-71: Undefined name
EarlyStoppingConfig
(F821)
71-71: Undefined name
attr
(F821)
71-71: Undefined name
EarlyStoppingConfig
(F821)
104-104: Undefined name
Optional
(F821)
119-119: Undefined name
Optional
(F821)
120-120: Undefined name
Optional
(F821)
121-121: Undefined name
Optional
(F821)
122-122: Undefined name
Optional
(F821)
123-123: Undefined name
Optional
(F821)
124-124: Undefined name
Optional
(F821)
125-125: Undefined name
Optional
(F821)
125-125: Undefined name
List
(F821)
157-157: Undefined name
Union
(F821)
157-157: Undefined name
List
(F821)
🔇 Additional comments (5)
sleap_nn/config/trainer_config.py (5)
4-28: LGTM: Well-written file-level docstring. The file-level docstring provides a clear explanation of the purpose and design principles of the configuration classes. It effectively communicates the separation of concerns between parameter specification and implementation.
106-125: LGTM: Well-structured WandBConfig class. The WandBConfig class is well-defined with appropriate use of attrs.define and Optional type hints for its attributes. The docstring provides clear explanations for each attribute.
🧰 Tools
🪛 Ruff
119-119: Undefined name
Optional
(F821)
120-120: Undefined name
Optional
(F821)
121-121: Undefined name
Optional
(F821)
122-122: Undefined name
Optional
(F821)
123-123: Undefined name
Optional
(F821)
124-124: Undefined name
Optional
(F821)
125-125: Undefined name
Optional
(F821)
125-125: Undefined name
List
(F821)
1-170: Summary: Well-structured configuration classes with minor improvements needed. Overall, this file introduces a comprehensive set of configuration classes for training job parameters. The design is solid, and the documentation is thorough. To improve the code:
- Ensure consistency in attribute definitions by using attrs.field across all classes.
- Add missing type hints and remove unused imports.
- Fix the early_stopping attribute in TrainerConfig to use attrs.field.
These changes will enhance the consistency and type safety of the configuration classes.
🧰 Tools
🪛 Ruff
2-2: omegaconf.OmegaConf imported but unused. Remove unused import: omegaconf.OmegaConf (F401)
55-55: Undefined name
DataLoaderConfig
(F821)
55-55: Undefined name
DataLoaderConfig
(F821)
56-56: Undefined name
DataLoaderConfig
(F821)
56-56: Undefined name
DataLoaderConfig
(F821)
57-57: Undefined name
ModelCkptConfig
(F821)
57-57: Undefined name
ModelCkptConfig
(F821)
58-58: Undefined name
Union
(F821)
58-58: Undefined name
List
(F821)
61-61: Undefined name
Optional
(F821)
63-63: Undefined name
Optional
(F821)
67-67: Undefined name
Optional
(F821)
68-68: Undefined name
Optional
(F821)
68-68: Undefined name
WandBConfig
(F821)
69-69: Undefined name
Optional
(F821)
69-69: Undefined name
OptimizerConfig
(F821)
69-69: Undefined name
OptimizerConfig
(F821)
70-70: Undefined name
LRSchedulerConfig
(F821)
70-70: Undefined name
LRSchedulerConfig
(F821)
71-71: Undefined name
EarlyStoppingConfig
(F821)
71-71: Undefined name
attr
(F821)
71-71: Undefined name
EarlyStoppingConfig
(F821)
104-104: Undefined name
Optional
(F821)
119-119: Undefined name
Optional
(F821)
120-120: Undefined name
Optional
(F821)
121-121: Undefined name
Optional
(F821)
122-122: Undefined name
Optional
(F821)
123-123: Undefined name
Optional
(F821)
124-124: Undefined name
Optional
(F821)
125-125: Undefined name
Optional
(F821)
125-125: Undefined name
List
(F821)
157-157: Undefined name
Union
(F821)
157-157: Undefined name
List
(F821)
82-93: 🛠️ Refactor suggestion: Use attrs.field for attribute definitions in DataLoaderConfig.
For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the DataLoaderConfig class. Apply this diff:
- batch_size: int = 1
- shuffle: bool=False
- num_workers: int=0
+ batch_size: int = attrs.field(default=1)
+ shuffle: bool = attrs.field(default=False)
+ num_workers: int = attrs.field(default=0)
Likely invalid or redundant comment.
95-104: 🛠️ Refactor suggestion: Use attrs.field for attribute definitions in ModelCkptConfig.
For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the ModelCkptConfig class. Apply this diff:
- save_top_k: int = 1
- save_last: Optional[bool]=None
+ save_top_k: int = attrs.field(default=1)
+ save_last: Optional[bool] = attrs.field(default=None)
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff
104-104: Undefined name
Optional
(F821)
sleap_nn/config/trainer_config.py
Outdated
@attrs.define
class OptimizerConfig:
    '''optimizer configuration

    lr: (float) Learning rate of type float. Default: 1e-3
    amsgrad: (bool) Enable AMSGrad with the optimizer. Default: False
    '''
    lr: float = 1e-3
    amsgrad: bool = False
🛠️ Refactor suggestion
Use attrs.field for attribute definitions in OptimizerConfig.
For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the OptimizerConfig class.
Apply this diff:
- lr: float = 1e-3
- amsgrad: bool = False
+ lr: float = attrs.field(default=1e-3)
+ amsgrad: bool = attrs.field(default=False)
📝 Committable suggestion
@attrs.define
class OptimizerConfig:
    '''optimizer configuration

    lr: (float) Learning rate of type float. Default: 1e-3
    amsgrad: (bool) Enable AMSGrad with the optimizer. Default: False
    '''
    lr: float = attrs.field(default=1e-3)
    amsgrad: bool = attrs.field(default=False)
sleap_nn/config/trainer_config.py
Outdated
@attrs.define
class EarlyStoppingConfig:
    '''early_stopping configuration

    Attributes:
        stop_training_on_plateau: (bool) True if early stopping should be enabled.
        min_delta: (float) Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than or equal to min_delta, will count as no improvement.
        patience: (int) Number of checks with no improvement after which training will be stopped. Under the default configuration, one check happens after every training epoch.
    '''
    stop_training_on_plateau: bool = False
    min_delta: float = 0.0
    patience: int = 1
🛠️ Refactor suggestion
Use attrs.field for attribute definitions in EarlyStoppingConfig.
For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the EarlyStoppingConfig class.
Apply this diff:
- stop_training_on_plateau: bool = False
- min_delta: float = 0.0
- patience: int = 1
+ stop_training_on_plateau: bool = attrs.field(default=False)
+ min_delta: float = attrs.field(default=0.0)
+ patience: int = attrs.field(default=1)
📝 Committable suggestion
@attrs.define
class EarlyStoppingConfig:
    '''early_stopping configuration

    Attributes:
        stop_training_on_plateau: (bool) True if early stopping should be enabled.
        min_delta: (float) Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than or equal to min_delta, will count as no improvement.
        patience: (int) Number of checks with no improvement after which training will be stopped. Under the default configuration, one check happens after every training epoch.
    '''
    stop_training_on_plateau: bool = attrs.field(default=False)
    min_delta: float = attrs.field(default=0.0)
    patience: int = attrs.field(default=1)
sleap_nn/config/trainer_config.py
Outdated
@attrs.define
class TrainerConfig:
    """Configuration of Trainer.

    Attributes:
        train_data_loader: (Note: Any parameters from Torch's DataLoader could be used.)
        val_data_loader: (Similar to train_data_loader)
        model_ckpt: (Note: Any parameters from Lightning's ModelCheckpoint could be used.)
        trainer_devices: (int) Number of devices to train on (int), which devices to train on (list or str), or "auto" to select automatically.
        trainer_accelerator: (str) One of the ("cpu", "gpu", "tpu", "ipu", "auto"). "auto" recognises the machine the model is running on and chooses the appropriate accelerator for the Trainer to be connected to.
        enable_progress_bar: (bool) When True, enables printing the logs during training.
        steps_per_epoch: (int) Minimum number of iterations in a single epoch. (Useful if model is trained with very few data points). Refer limit_train_batches parameter of Torch Trainer. If None, the number of iterations depends on the number of samples in the train dataset.
        max_epochs: (int) Maximum number of epochs to run.
        seed: (int) Seed value for the current experiment.
        use_wandb: (bool) True to enable wandb logging.
        save_ckpt: (bool) True to enable checkpointing.
        save_ckpt_path: (str) Directory path to save the training config and checkpoint files. Default: "./"
        resume_ckpt_path: (str) Path to .ckpt file from which training is resumed. Default: None.
        wandb: (Only if use_wandb is True, else skip this)
        optimizer_name: (str) Optimizer to be used. One of ["Adam", "AdamW"].
        optimizer:
        lr_scheduler:
        early_stopping:
    """
    train_data_loader: DataLoaderConfig = attrs.field(factory=DataLoaderConfig)
    val_data_loader: DataLoaderConfig = attrs.field(factory=DataLoaderConfig)
    model_ckpt: ModelCkptConfig = attrs.field(factory=ModelCkptConfig)
    trainer_devices: Union[int, List[int], str] = "auto"
    trainer_accelerator: str="auto"
    enable_progress_bar: bool = True
    steps_per_epoch: Optional[int] = None
    max_epochs: int = 10
    seed: Optional[int] = None
    use_wandb: bool = False
    save_ckpt: bool = False
    save_ckpt_path: str = "./"
    resume_ckpt_path: Optional[str] = None
    wandb: Optional[WandBConfig] = attrs.field(init=False)
    optimizer: Optional[OptimizerConfig] = attrs.field(factory=OptimizerConfig)
    lr_scheduler: LRSchedulerConfig = attrs.field(factory=LRSchedulerConfig)
    early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
Fix the early_stopping attribute and add missing type hints.
- The early_stopping attribute uses attr.field instead of attrs.field.
- Add type hints for trainer_devices, trainer_accelerator, and other attributes to resolve static analysis warnings.
Apply these changes:
- trainer_devices: Union[int, List[int], str] = "auto"
- trainer_accelerator: str="auto"
+ trainer_devices: Union[int, List[int], str] = attrs.field(default="auto")
+ trainer_accelerator: str = attrs.field(default="auto")
# ... (apply similar changes to other attributes)
- early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
+ early_stopping: EarlyStoppingConfig = attrs.field(factory=EarlyStoppingConfig)
Committable suggestion was skipped due to low confidence.
🧰 Tools
🪛 Ruff
55-55: Undefined name
DataLoaderConfig
(F821)
55-55: Undefined name
DataLoaderConfig
(F821)
56-56: Undefined name
DataLoaderConfig
(F821)
56-56: Undefined name
DataLoaderConfig
(F821)
57-57: Undefined name
ModelCkptConfig
(F821)
57-57: Undefined name
ModelCkptConfig
(F821)
58-58: Undefined name
Union
(F821)
58-58: Undefined name
List
(F821)
61-61: Undefined name
Optional
(F821)
63-63: Undefined name
Optional
(F821)
67-67: Undefined name
Optional
(F821)
68-68: Undefined name
Optional
(F821)
68-68: Undefined name
WandBConfig
(F821)
69-69: Undefined name
Optional
(F821)
69-69: Undefined name
OptimizerConfig
(F821)
69-69: Undefined name
OptimizerConfig
(F821)
70-70: Undefined name
LRSchedulerConfig
(F821)
70-70: Undefined name
LRSchedulerConfig
(F821)
71-71: Undefined name
EarlyStoppingConfig
(F821)
71-71: Undefined name
attr
(F821)
71-71: Undefined name
EarlyStoppingConfig
(F821)
sleap_nn/config/trainer_config.py
Outdated
@attrs.define
class LRSchedulerConfig:
    '''lr_scheduler configuration

    Attributes:
        mode: (str) One of "min", "max". In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: "min".
        threshold: (float) Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
        threshold_mode: (str) One of "rel", "abs". In rel mode, dynamic_threshold = best * ( 1 + threshold ) in max mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: "rel".
        cooldown: (int) Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0
        patience: (int) Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the third epoch if the loss still hasn’t improved then. Default: 10.
        factor: (float) Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
        min_lr: (float or List[float]) A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
    '''

    mode: str = "min"
    threshold: float = 1e-4
    threshold_mode: str = "rel"
    cooldown: int = 0
    patience: int = 10
    factor: float = 0.1
    min_lr: Union[float, List[float]] = 0.0
🛠️ Refactor suggestion
Use attrs.field for attribute definitions in LRSchedulerConfig.
For consistency with the TrainerConfig class and to leverage attrs features, use attrs.field to define the attributes in the LRSchedulerConfig class.
Apply this diff:
- mode: str = "min"
- threshold: float = 1e-4
- threshold_mode: str = "rel"
- cooldown: int = 0
- patience: int = 10
- factor: float = 0.1
- min_lr: Union[float, List[float]] = 0.0
+ mode: str = attrs.field(default="min")
+ threshold: float = attrs.field(default=1e-4)
+ threshold_mode: str = attrs.field(default="rel")
+ cooldown: int = attrs.field(default=0)
+ patience: int = attrs.field(default=10)
+ factor: float = attrs.field(default=0.1)
+ min_lr: Union[float, List[float]] = attrs.field(default=0.0)
📝 Committable suggestion
@attrs.define
class LRSchedulerConfig:
    '''lr_scheduler configuration

    Attributes:
        mode: (str) One of "min", "max". In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: "min".
        threshold: (float) Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
        threshold_mode: (str) One of "rel", "abs". In rel mode, dynamic_threshold = best * ( 1 + threshold ) in max mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: "rel".
        cooldown: (int) Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0
        patience: (int) Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the third epoch if the loss still hasn't improved then. Default: 10.
        factor: (float) Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
        min_lr: (float or List[float]) A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
    '''
    mode: str = attrs.field(default="min")
    threshold: float = attrs.field(default=1e-4)
    threshold_mode: str = attrs.field(default="rel")
    cooldown: int = attrs.field(default=0)
    patience: int = attrs.field(default=10)
    factor: float = attrs.field(default=0.1)
    min_lr: Union[float, List[float]] = attrs.field(default=0.0)
Note: The static analysis hints indicate that Union and List are undefined. To address this, add the following import at the beginning of the file:
from typing import Union, List
🧰 Tools
🪛 Ruff
157-157: Undefined name
Union
(F821)
157-157: Undefined name
List
(F821)
Actionable comments posted: 5
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
17-17: Undefined name
BackboneType
(F821)
17-17: Undefined name
BackboneType
(F821)
18-18: Undefined name
Union
(F821)
18-18: Undefined name
UNetConfig
(F821)
18-18: Undefined name
ConvNextConfig
(F821)
18-18: Undefined name
SwinTConfig
(F821)
19-19: Undefined name
HeadConfig
(F821)
19-19: Undefined name
HeadConfig
(F821)
28-28: Undefined name
BackboneType
(F821)
29-29: Undefined name
UNetConfig
(F821)
30-30: Undefined name
BackboneType
(F821)
31-31: Undefined name
ConvNextConfig
(F821)
32-32: Undefined name
BackboneType
(F821)
33-33: Undefined name
SwinTConfig
(F821)
41-41: Undefined name
BackboneType
(F821)
44-44: Undefined name
BackboneType
(F821)
47-47: Undefined name
BackboneType
(F821)
94-94: Undefined name
oneof
(F821)
95-95: Undefined name
attr
(F821)
110-110: Undefined name
Optional
(F821)
110-110: Undefined name
SingleInstanceConfig
(F821)
111-111: Undefined name
Optional
(F821)
111-111: Undefined name
CentroidConfig
(F821)
112-112: Undefined name
Optional
(F821)
112-112: Undefined name
CenteredInstanceConfig
(F821)
113-113: Undefined name
Optional
(F821)
113-113: Undefined name
BottomUpConfig
(F821)
118-118: Undefined name
Optional
(F821)
118-118: Undefined name
SingleInstanceConfMapsConfig
(F821)
123-123: Undefined name
Optional
(F821)
123-123: Undefined name
CentroidConfMapsConfig
(F821)
128-128: Undefined name
Optional
(F821)
128-128: Undefined name
CenteredInstanceConfMapsConfig
(F821)
133-133: Undefined name
Optional
(F821)
133-133: Undefined name
BottomUpConfMapsConfig
(F821)
134-134: Undefined name
Optional
(F821)
134-134: Undefined name
PAFConfig
(F821)
145-145: Undefined name
Optional
(F821)
145-145: Undefined name
List
(F821)
146-146: Undefined name
Optional
(F821)
147-147: Undefined name
Optional
(F821)
158-158: Undefined name
Optional
(F821)
159-159: Undefined name
Optional
(F821)
160-160: Undefined name
Optional
(F821)
172-172: Undefined name
Optional
(F821)
172-172: Undefined name
List
(F821)
173-173: Undefined name
Optional
(F821)
174-174: Undefined name
Optional
(F821)
175-175: Undefined name
Optional
(F821)
187-187: Undefined name
Optional
(F821)
187-187: Undefined name
List
(F821)
188-188: Undefined name
Optional
(F821)
189-189: Undefined name
Optional
(F821)
190-190: Undefined name
Optional
(F821)
202-202: Undefined name
Optional
(F821)
202-202: Undefined name
List
(F821)
203-203: Undefined name
Optional
(F821)
204-204: Undefined name
Optional
(F821)
205-205: Undefined name
Optional
(F821)
🔇 Additional comments (1)
sleap_nn/config/model_config.py (1)
1-3
:⚠️ Potential issueAdd missing imports from typing module.
Add the following imports to resolve undefined names and improve type hints:
import attrs from enum import Enum +from typing import Union, Optional, List
Likely invalid or redundant comment.
sleap_nn/config/model_config.py
Outdated
@attrs.define
class SingleInstanceConfMapsConfig:
    '''

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    '''
    part_names: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None


@attrs.define
class CentroidConfMapsConfig:
    '''

    Attributes:
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    '''
    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None


@attrs.define
class CenteredInstanceConfMapsConfig:
    '''

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    '''
    part_names: Optional[List[str]] = None
    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None


@attrs.define
class BottomUpConfMapsConfig():
    '''

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
        loss_weight: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models.
    '''
    part_names: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None
    loss_weight: Optional[float] = None
🛠️ Refactor suggestion
Consider creating a base configuration class for shared attributes.
The configuration classes (SingleInstanceConfMapsConfig, CentroidConfMapsConfig, etc.) share common attributes like sigma
and output_stride
. Consider creating a base class:
@attrs.define
class BaseConfMapsConfig:
sigma: Optional[float] = None
output_stride: Optional[float] = None
@attrs.define
class SingleInstanceConfMapsConfig(BaseConfMapsConfig):
part_names: Optional[List[str]] = None
# Similar for other classes
This would reduce code duplication and make maintenance easier.
🧰 Tools
🪛 Ruff
145-145, 172-172, 187-187: Undefined names `Optional`, `List` (F821)
146-147, 158-160, 173-175, 188-190: Undefined name `Optional` (F821)
"""

init_weight: str = "default"
pre_trained_weights: str = None
Fix type annotation for pre_trained_weights.
The attribute should use Optional[str] since it can be None:
- pre_trained_weights: str = None
+ pre_trained_weights: Optional[str] = None
Committable suggestion was skipped due to low confidence.
sleap_nn/config/model_config.py (Outdated)

@oneof
@attr.s(auto_attribs=True)
class HeadsConfig:
Update HeadsConfig decorator and class definition.
- The `@oneof` decorator is undefined.
- The class uses the older `attr.s` style instead of `attrs.define`.

-@oneof
-@attr.s(auto_attribs=True)
+@attrs.define
 class HeadsConfig:
Committable suggestion was skipped due to low confidence.
🧰 Tools
🪛 Ruff
94-94: Undefined name `oneof` (F821)
95-95: Undefined name `attr` (F821)
sleap_nn/config/model_config.py (Outdated)

class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = 'convnext'
    SWINT = 'swint'
🛠️ Refactor suggestion
Consider moving BackboneType enum outside ModelConfig.
Moving the enum outside would improve reusability and follow Python's common patterns:
from enum import Enum
+class BackboneType(Enum):
+ UNET = "unet"
+ CONVNEXT = 'convnext'
+ SWINT = 'swint'
+
@attrs.define
class ModelConfig:
- class BackboneType(Enum):
- UNET = "unet"
- CONVNEXT = 'convnext'
- SWINT = 'swint'
Committable suggestion was skipped due to low confidence.
sleap_nn/config/model_config.py (Outdated)

def validate_pre_trained_weights(self):
    convnext_weights = ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]
    swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]

    if self.backbone_type == BackboneType.CONVNEXT:
        if self.pre_trained_weights not in convnext_weights:
            raise ValueError(f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}")
    elif self.backbone_type == BackboneType.SWINT:
        if self.pre_trained_weights not in swint_weights:
            raise ValueError(f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}")
    elif self.backbone_type == BackboneType.UNET and self.pre_trained_weights is not None:
        raise ValueError("UNet does not support pre-trained weights.")
🛠️ Refactor suggestion
Simplify validation logic using a dictionary mapping.
The validation logic can be made more maintainable using a dictionary:
def validate_pre_trained_weights(self):
- convnext_weights = ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]
- swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]
-
- if self.backbone_type == BackboneType.CONVNEXT:
- if self.pre_trained_weights not in convnext_weights:
- raise ValueError(f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}")
- elif self.backbone_type == BackboneType.SWINT:
- if self.pre_trained_weights not in swint_weights:
- raise ValueError(f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}")
- elif self.backbone_type == BackboneType.UNET and self.pre_trained_weights is not None:
- raise ValueError("UNet does not support pre-trained weights.")
+ valid_weights = {
+ BackboneType.CONVNEXT: ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"],
+ BackboneType.SWINT: ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"],
+ BackboneType.UNET: None
+ }
+ allowed_weights = valid_weights[self.backbone_type]
+ if allowed_weights is None and self.pre_trained_weights is not None:
+ raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
+ elif allowed_weights and self.pre_trained_weights not in allowed_weights:
+ raise ValueError(f"Invalid pre-trained weights for {self.backbone_type.value}. Must be one of {allowed_weights}")
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def validate_pre_trained_weights(self):
    valid_weights = {
        BackboneType.CONVNEXT: ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"],
        BackboneType.SWINT: ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"],
        BackboneType.UNET: None
    }
    allowed_weights = valid_weights[self.backbone_type]
    if allowed_weights is None and self.pre_trained_weights is not None:
        raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
    elif allowed_weights and self.pre_trained_weights not in allowed_weights:
        raise ValueError(f"Invalid pre-trained weights for {self.backbone_type.value}. Must be one of {allowed_weights}")
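The mapping can also be exercised standalone; the sketch below (a module-level function rather than a method, purely for illustration) reuses the enum and weight names from the suggestion:

```python
from enum import Enum

class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"

# Allowed pretrained-weight names per backbone; None means "not supported".
VALID_WEIGHTS = {
    BackboneType.CONVNEXT: ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights",
                            "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"],
    BackboneType.SWINT: ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"],
    BackboneType.UNET: None,
}

def validate_pre_trained_weights(backbone_type, pre_trained_weights):
    allowed = VALID_WEIGHTS[backbone_type]
    if allowed is None and pre_trained_weights is not None:
        raise ValueError(f"{backbone_type.value} does not support pre-trained weights.")
    if allowed and pre_trained_weights not in allowed:
        raise ValueError(f"Invalid pre-trained weights for {backbone_type.value}. "
                         f"Must be one of {allowed}")

validate_pre_trained_weights(BackboneType.SWINT, "Swin_T_Weights")  # ok
validate_pre_trained_weights(BackboneType.UNET, None)               # ok
```

Note that, like the original, this rejects `pre_trained_weights=None` for ConvNext/SwinT; if None should mean "train from scratch" for those backbones too, an early `if pre_trained_weights is None: return` would be needed.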
🧰 Tools
🪛 Ruff
41-41, 44-44, 47-47: Undefined name `BackboneType` (F821)
Actionable comments posted: 11
🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/trainer_config.py (1)
31-54: Complete the documentation for optimizer, lr_scheduler, and early_stopping attributes. The class docstring is missing descriptions for some attributes.

Add these descriptions to the docstring:

 wandb: (Only if use_wandb is True, else skip this)
 optimizer_name: (str) Optimizer to be used. One of ["Adam", "AdamW"].
-optimizer:
-lr_scheduler:
-early_stopping:
+optimizer: (OptimizerConfig) Configuration for the optimizer.
+lr_scheduler: (LRSchedulerConfig) Configuration for the learning rate scheduler.
+early_stopping: (EarlyStoppingConfig) Configuration for early stopping criteria.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- sleap_nn/config/model_config.py (1 hunks)
- sleap_nn/config/trainer_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
19-19: Undefined name `BackboneType` (F821)
20-20: Undefined names `Union`, `UNetConfig`, `ConvNextConfig`, `SwinTConfig` (F821)
23-23: Undefined name `HeadConfig` (F821)
32-37: Undefined names `BackboneType`, `UNetConfig`, `ConvNextConfig`, `SwinTConfig` (F821)
51-51, 56-56, 62-62: Undefined name `BackboneType` (F821)
166-167: Undefined names `oneof`, `attr` (F821)
182-185: Undefined names `Optional`, `SingleInstanceConfig`, `CentroidConfig`, `CenteredInstanceConfig`, `BottomUpConfig` (F821)
191-210: Undefined names `Optional`, `SingleInstanceConfMapsConfig`, `CentroidConfMapsConfig`, `CenteredInstanceConfMapsConfig`, `BottomUpConfMapsConfig`, `PAFConfig` (F821)
223-291: Undefined names `Optional`, `List` (F821)
sleap_nn/config/trainer_config.py
2-2: `omegaconf.OmegaConf` imported but unused — remove unused import (F401)
56-58: Undefined names `DataLoaderConfig`, `ModelCkptConfig` (F821)
59-72: Undefined names `Union`, `List`, `Optional`, `WandBConfig`, `OptimizerConfig`, `LRSchedulerConfig`, `EarlyStoppingConfig`, `attr` (F821)
110-133: Undefined names `Optional`, `List` (F821)
169-169: Undefined names `Union`, `List` (F821)
🔇 Additional comments (5)
sleap_nn/config/trainer_config.py (2)
4-28: LGTM! Well-documented module purpose. The module docstring clearly explains the purpose and design principles of these configuration classes.

74-83: LGTM! Clean post-initialization logic. The `__attrs_post_init__` method correctly handles the conditional initialization of the `wandb` configuration.

sleap_nn/config/model_config.py (3)
18-18: ⚠️ Potential issue — Use Optional type for pre_trained_weights. The attribute allows None but is typed as str.

- pre_trained_weights: str = None
+ pre_trained_weights: Optional[str] = None

Likely invalid or redundant comment.

1-4: ⚠️ Potential issue — Add missing imports from typing module. The code uses type annotations but lacks the required imports. Add the following imports:

 import attrs
 from enum import Enum
+from typing import Dict, List, Optional, Union

Likely invalid or redundant comment.

166-168: ⚠️ Potential issue — Update HeadsConfig decorator. The class uses deprecated attrs style and an undefined `oneof` decorator.

-@oneof
-@attr.s(auto_attribs=True)
+@attrs.define
 class HeadsConfig:

Note: If the `oneof` decorator is required for validation, please ensure it's properly imported or implemented.

🧰 Tools
🪛 Ruff
166-166: Undefined name `oneof` (F821)
167-167: Undefined name `attr` (F821)
lr: float = 1e-3
amsgrad: bool = False
🛠️ Refactor suggestion
Use attrs.field for OptimizerConfig attributes.
For consistency with other classes, use attrs.field
for attribute definitions.
- lr: float = 1e-3
- amsgrad: bool = False
+ lr: float = attrs.field(default=1e-3)
+ amsgrad: bool = attrs.field(default=False)
📝 Committable suggestion

lr: float = attrs.field(default=1e-3)
amsgrad: bool = attrs.field(default=False)
stop_training_on_plateau: bool = False
min_delta: float = 0.0
patience: int = 1
🛠️ Refactor suggestion
Use attrs.field for EarlyStoppingConfig attributes.
For consistency with other classes, use attrs.field
for attribute definitions.
- stop_training_on_plateau: bool = False
- min_delta: float = 0.0
- patience: int = 1
+ stop_training_on_plateau: bool = attrs.field(default=False)
+ min_delta: float = attrs.field(default=0.0)
+ patience: int = attrs.field(default=1)
📝 Committable suggestion

stop_training_on_plateau: bool = attrs.field(default=False)
min_delta: float = attrs.field(default=0.0)
patience: int = attrs.field(default=1)
batch_size: int = 1
shuffle: bool = False
num_workers: int = 0
🛠️ Refactor suggestion
Use attrs.field for DataLoaderConfig attributes.
For consistency with other classes, use attrs.field
for attribute definitions.
- batch_size: int = 1
- shuffle: bool = False
- num_workers: int = 0
+ batch_size: int = attrs.field(default=1)
+ shuffle: bool = attrs.field(default=False)
+ num_workers: int = attrs.field(default=0)
📝 Committable suggestion

batch_size: int = attrs.field(default=1)
shuffle: bool = attrs.field(default=False)
num_workers: int = attrs.field(default=0)
save_top_k: int = 1
save_last: Optional[bool] = None
🛠️ Refactor suggestion
Use attrs.field for ModelCkptConfig attributes.
For consistency with other classes, use attrs.field
for attribute definitions.
- save_top_k: int = 1
- save_last: Optional[bool] = None
+ save_top_k: int = attrs.field(default=1)
+ save_last: Optional[bool] = attrs.field(default=None)
Committable suggestion was skipped due to low confidence.
🧰 Tools
🪛 Ruff
110-110: Undefined name `Optional` (F821)

backbone_config: Union[UNetConfig, ConvNextConfig, SwinTConfig] = attrs.field(
    init=False
)  # backbone_config can be any of these 3 configurations. init=False lets you set the parameters later (not in initialization)
head_configs: HeadConfig = attrs.field(factory=HeadConfig)
Fix inconsistent class reference.
The attribute references HeadConfig but the class is named HeadsConfig.
- head_configs: HeadConfig = attrs.field(factory=HeadConfig)
+ head_configs: HeadsConfig = attrs.field(factory=HeadsConfig)
📝 Committable suggestion

head_configs: HeadsConfig = attrs.field(factory=HeadsConfig)
🧰 Tools
🪛 Ruff
23-23: Undefined name `HeadConfig` (F821)
sleap_nn/config/model_config.py (Outdated)

kernel_size: int = 3
filters: int = 32
filters_rate: float = 1.5
max_stride: int = None
Use Optional type for max_stride.
The attribute allows None but is typed as int.
- max_stride: int = None
+ max_stride: Optional[int] = None
Committable suggestion was skipped due to low confidence.
sleap_nn/config/model_config.py (Outdated)

class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"
🛠️ Refactor suggestion
Move BackboneType enum to module level.
The enum should be at module level for better reusability and to follow Python conventions.
- class BackboneType(Enum):
- UNET = "unet"
- CONVNEXT = "convnext"
- SWINT = "swint"
+class BackboneType(Enum):
+ UNET = "unet"
+ CONVNEXT = "convnext"
+ SWINT = "swint"
📝 Committable suggestion

class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"
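A small illustration of why a module-level enum helps: it can be imported by other modules and supports lookup by the config string directly:

```python
from enum import Enum

class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"

# Lookup by value turns a raw config string into a typed enum member:
backbone = BackboneType("convnext")
print(backbone is BackboneType.CONVNEXT)  # True

# An unknown string fails fast with a clear error:
# BackboneType("resnet")  # raises ValueError
```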
Actionable comments posted: 4
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py
2-2: `enum.Enum` imported but unused — remove unused import (F401)
18-19: Undefined names `BackboneConfig`, `HeadConfig` (F821)
35-35, 40-40, 46-46: Undefined name `BackboneType` (F821)
52-52, 164-164: Undefined name `oneof` (F821)
63-65: Undefined names `Optional`, `UNetConfig`, `ConvNextConfig`, `SwinTConfig` (F821)
180-208: Undefined names `Optional`, `SingleInstanceConfig`, `CentroidConfig`, `CenteredInstanceConfig`, `BottomUpConfig`, `SingleInstanceConfMapsConfig`, `CentroidConfMapsConfig`, `CenteredInstanceConfMapsConfig`, `BottomUpConfMapsConfig`, `PAFConfig` (F821)
221-289: Undefined names `Optional`, `List` (F821)
🔇 Additional comments (2)
sleap_nn/config/model_config.py (2)
1-4: Add missing imports for type hints and decorators.
Add the following imports to resolve undefined names and improve type safety:
import attrs
from enum import Enum
+from typing import Optional, List, Dict, Union
+from omegaconf import MISSING
Let's verify if omegaconf is a project dependency:
✅ Verification successful
Add missing imports for type hints and decorators
Add the following imports to resolve undefined names and improve type safety:
import attrs
from enum import Enum
+from typing import Optional, List, Dict, Union
+from omegaconf import MISSING
The verification confirms:
- `omegaconf` is a project dependency through `hydra-core` in `pyproject.toml`
- `omegaconf` is used across the codebase, including other config files
- Specifically, `MISSING` from `omegaconf` is already used in `sleap_nn/config/data_config.py`
- The project follows a consistent pattern of using omegaconf for configuration management
The suggested imports are appropriate and align with the project's existing practices.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Check if omegaconf is listed in requirements
rg -l "omegaconf" "requirements.*"
Length of output: 89
Script:
#!/bin/bash
# Let's check for setup.py, pyproject.toml, or any other dependency files
fd -e toml -e txt -e py "setup|requirements|pyproject"
# Also search for omegaconf usage in the codebase
rg -l "omegaconf"
Length of output: 1002
Script:
#!/bin/bash
# Let's check the content of pyproject.toml for dependencies
cat pyproject.toml
# Also check environment files for omegaconf
grep -h "omegaconf" environment*.yml
Length of output: 1667
Script:
#!/bin/bash
# Let's check if hydra-core (which includes omegaconf) is actually used in the config files
rg "from hydra|import hydra" sleap_nn/config/
rg "from omegaconf|import omegaconf" sleap_nn/config/
# Also check how MISSING is used in the config files
rg "MISSING" sleap_nn/config/
Length of output: 455
🧰 Tools
🪛 Ruff
2-2: `enum.Enum` imported but unused — remove unused import (F401)
52-66: Fix `oneof` decorator usage. The `@oneof` decorator is undefined. This appears to be a custom decorator that needs to be imported or defined.

Let's check if this decorator exists in the codebase:
🧰 Tools
🪛 Ruff
52-52: Undefined name `oneof` (F821)
63-65: Undefined names `Optional`, `UNetConfig`, `ConvNextConfig`, `SwinTConfig` (F821)
kernel_size: int = 3
filters: int = 32
filters_rate: float = 1.5
max_stride: int = None
Fix type hint for `max_stride` in UNetConfig. The `max_stride` attribute allows None but is typed as `int`.

- max_stride: int = None
+ max_stride: Optional[int] = None
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
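To see why the `Optional` annotation matters beyond linting: structured-config systems such as omegaconf read these annotations at runtime, and a non-optional field whose default is `None` is inconsistent (OmegaConf's structured configs typically reject it with a validation error). A minimal stdlib sketch, reusing the UNetConfig field names from the hunk above:

```python
from dataclasses import dataclass
from typing import Optional, get_type_hints


@dataclass
class UNetConfig:
    kernel_size: int = 3
    filters: int = 32
    filters_rate: float = 1.5
    # None is a legal value ("not set"), so the annotation must say so:
    max_stride: Optional[int] = None


cfg = UNetConfig()
hints = get_type_hints(UNetConfig)
```

So the suggested change is a functional fix for runtime-validated configs, not just a style nit.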
@attrs.define
class ModelConfig:
    """Configurations related to model architecture.

    Attributes:
        init_weight: (str) model weights initialization method. "default" uses kaiming uniform initialization and "xavier" uses Xavier initialization method.
        pre_trained_weights: (str) Pretrained weights file name supported only for ConvNext and SwinT backbones. For ConvNext, one of ["ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights", "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights"]. For SwinT, one of ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"].
        backbone_config: initialize either UNetConfig, ConvNextConfig, or SwinTConfig based on input from backbone_type.
        head_configs: (Dict) Dictionary with head configs for the model to be trained. Note: Configs should be provided only for the model to train and others should be None.
    """

    init_weight: str = "default"
    pre_trained_weights: str = None
    backbone_config: BackboneConfig = attrs.field(factory=BackboneConfig)
    head_configs: HeadConfig = attrs.field(factory=HeadConfig)
Fix type hints and class references in ModelConfig.
- The pre_trained_weights should be Optional[str] since it accepts None
- The head_configs references HeadConfig but the class is named HeadsConfig
- The BackboneType enum is missing
Apply these fixes:
+class BackboneType(Enum):
+ UNET = "unet"
+ CONVNEXT = "convnext"
+ SWINT = "swint"
+
@attrs.define
class ModelConfig:
"""Configurations related to model architecture.
...
"""
init_weight: str = "default"
- pre_trained_weights: str = None
+ pre_trained_weights: Optional[str] = None
backbone_config: BackboneConfig = attrs.field(factory=BackboneConfig)
- head_configs: HeadConfig = attrs.field(factory=HeadConfig)
+ head_configs: HeadsConfig = attrs.field(factory=HeadsConfig)
Committable suggestion was skipped due to low confidence.
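Putting the three fixes together, a self-contained sketch — stdlib dataclasses stand in for attrs here, and `BackboneConfig` is reduced to a stub, so any names beyond what the review shows are hypothetical:

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"


@dataclass
class BackboneConfig:
    # Stub for the PR's BackboneConfig; real fields omitted.
    backbone_type: BackboneType = BackboneType.UNET


@dataclass
class ModelConfig:
    init_weight: str = "default"
    # None is a valid value, hence Optional[str] rather than str:
    pre_trained_weights: Optional[str] = None
    backbone_config: BackboneConfig = field(default_factory=BackboneConfig)


cfg = ModelConfig()
```

The same shape holds with `attrs.define` and `attrs.field(factory=...)` in place of the dataclass machinery.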
🧰 Tools
🪛 Ruff
18-18: Undefined name BackboneConfig (F821)
18-18: Undefined name BackboneConfig (F821)
19-19: Undefined name HeadConfig (F821)
19-19: Undefined name HeadConfig (F821)
@attrs.define
class SingleInstanceConfMapsConfig:
    """Single Instance configuration map

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    """

    part_names: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None


@attrs.define
class CentroidConfMapsConfig:
    """Centroid configuration map

    Attributes:
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    """

    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None


@attrs.define
class CenteredInstanceConfMapsConfig:
    """Centered Instance configuration map

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        anchor_part: (int) Note: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
    """

    part_names: Optional[List[str]] = None
    anchor_part: Optional[int] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None


@attrs.define
class BottomUpConfMapsConfig:
    """Bottomup configuration map

    Attributes:
        part_names: (List[str]) None if nodes from sio.Labels file can be used directly. Else provide text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
        loss_weight: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models.
    """

    part_names: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None
    loss_weight: Optional[float] = None


@attrs.define
class PAFConfig:
    """PAF configuration map

    Attributes:
        edges: (List[str]) None if edges from sio.Labels file can be used directly. Note: Only for 'PartAffinityFieldsHead'. List of indices (src, dest) that form an edge.
        sigma: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
        output_stride: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
        loss_weight: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models.
    """

    edges: Optional[List[str]] = None
    sigma: Optional[float] = None
    output_stride: Optional[float] = None
    loss_weight: Optional[float] = None
🛠️ Refactor suggestion
Consider creating a base configuration class for common attributes.
The configuration classes share common attributes (sigma, output_stride, loss_weight). Consider creating a base class to reduce code duplication.
+@attrs.define
+class BaseConfMapsConfig:
+ """Base configuration for confidence maps.
+
+ Attributes:
+ sigma: Spread of the Gaussian distribution of the confidence maps.
+ output_stride: The stride of the output confidence maps relative to the input image.
+ loss_weight: Weight of the loss term during training.
+ """
+ sigma: Optional[float] = None
+ output_stride: Optional[float] = None
+ loss_weight: Optional[float] = None
+
@attrs.define
-class SingleInstanceConfMapsConfig:
+class SingleInstanceConfMapsConfig(BaseConfMapsConfig):
"""Single Instance configuration map"""
part_names: Optional[List[str]] = None
- sigma: Optional[float] = None
- output_stride: Optional[float] = None
# Similar changes for other configuration classes
Committable suggestion was skipped due to low confidence.
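The suggested base-class refactor can be sketched with stdlib dataclasses (attrs behaves the same way for this purpose). One caveat worth noting: inherited fields come before subclass fields in the generated `__init__`, so every base field needs a default — which holds here, since all of them default to None.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BaseConfMapsConfig:
    """Shared confidence-map settings (sketch of the suggested refactor)."""
    sigma: Optional[float] = None
    output_stride: Optional[float] = None
    loss_weight: Optional[float] = None


@dataclass
class SingleInstanceConfMapsConfig(BaseConfMapsConfig):
    part_names: Optional[List[str]] = None


@dataclass
class PAFConfig(BaseConfMapsConfig):
    edges: Optional[List[str]] = None


cfg = SingleInstanceConfMapsConfig(sigma=5.0, part_names=["head", "thorax"])
```

Each subclass then only declares its own fields, and the shared ones stay consistent across all five head configs.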
🧰 Tools
🪛 Ruff
221-221: Undefined name Optional (F821)
221-221: Undefined name List (F821)
222-222: Undefined name Optional (F821)
223-223: Undefined name Optional (F821)
236-236: Undefined name Optional (F821)
237-237: Undefined name Optional (F821)
238-238: Undefined name Optional (F821)
252-252: Undefined name Optional (F821)
252-252: Undefined name List (F821)
253-253: Undefined name Optional (F821)
254-254: Undefined name Optional (F821)
255-255: Undefined name Optional (F821)
269-269: Undefined name Optional (F821)
269-269: Undefined name List (F821)
270-270: Undefined name Optional (F821)
271-271: Undefined name Optional (F821)
272-272: Undefined name Optional (F821)
286-286: Undefined name Optional (F821)
286-286: Undefined name List (F821)
287-287: Undefined name Optional (F821)
288-288: Undefined name Optional (F821)
289-289: Undefined name Optional (F821)
def validate_pre_trained_weights(self):
    convnext_weights = [
        "ConvNeXt_Base_Weights",
        "ConvNeXt_Tiny_Weights",
        "ConvNeXt_Small_Weights",
        "ConvNeXt_Large_Weights",
    ]
    swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]

    if self.backbone_type == BackboneType.CONVNEXT:
        if self.pre_trained_weights not in convnext_weights:
            raise ValueError(
                f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
            )
    elif self.backbone_type == BackboneType.SWINT:
        if self.pre_trained_weights not in swint_weights:
            raise ValueError(
                f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
            )
    elif (
        self.backbone_type == BackboneType.UNET
        and self.pre_trained_weights is not None
    ):
        raise ValueError("UNet does not support pre-trained weights.")
🛠️ Refactor suggestion
Simplify pre-trained weights validation using a mapping.
The validation logic can be simplified and made more maintainable using a dictionary mapping.
def validate_pre_trained_weights(self):
- convnext_weights = [
- "ConvNeXt_Base_Weights",
- "ConvNeXt_Tiny_Weights",
- "ConvNeXt_Small_Weights",
- "ConvNeXt_Large_Weights",
- ]
- swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]
-
- if self.backbone_type == BackboneType.CONVNEXT:
- if self.pre_trained_weights not in convnext_weights:
- raise ValueError(
- f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
- )
- elif self.backbone_type == BackboneType.SWINT:
- if self.pre_trained_weights not in swint_weights:
- raise ValueError(
- f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
- )
- elif (
- self.backbone_type == BackboneType.UNET
- and self.pre_trained_weights is not None
- ):
- raise ValueError("UNet does not support pre-trained weights.")
+ VALID_WEIGHTS = {
+ BackboneType.CONVNEXT: [
+ "ConvNeXt_Base_Weights",
+ "ConvNeXt_Tiny_Weights",
+ "ConvNeXt_Small_Weights",
+ "ConvNeXt_Large_Weights",
+ ],
+ BackboneType.SWINT: [
+ "Swin_T_Weights",
+ "Swin_S_Weights",
+ "Swin_B_Weights"
+ ],
+ BackboneType.UNET: None
+ }
+
+ allowed_weights = VALID_WEIGHTS.get(self.backbone_type)
+ if allowed_weights is None and self.pre_trained_weights is not None:
+ raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
+ elif allowed_weights and self.pre_trained_weights not in allowed_weights:
+ raise ValueError(
+ f"Invalid pre-trained weights for {self.backbone_type.value}. "
+ f"Must be one of {allowed_weights}"
+ )
🧰 Tools
🪛 Ruff
35-35: Undefined name BackboneType (F821)
40-40: Undefined name BackboneType (F821)
46-46: Undefined name BackboneType (F821)
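The mapping-based validation can be exercised standalone. This sketch keeps the reviewer's table but moves it to module level; whether None should also be accepted for ConvNext/SwinT is left open by the original diff, and this version permits it — an assumption, not the PR's behavior.

```python
from enum import Enum
from typing import Dict, List, Optional


class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"


# None means "pre-trained weights not supported for this backbone".
VALID_WEIGHTS: Dict[BackboneType, Optional[List[str]]] = {
    BackboneType.CONVNEXT: [
        "ConvNeXt_Base_Weights", "ConvNeXt_Tiny_Weights",
        "ConvNeXt_Small_Weights", "ConvNeXt_Large_Weights",
    ],
    BackboneType.SWINT: ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"],
    BackboneType.UNET: None,
}


def validate_pre_trained_weights(
    backbone_type: BackboneType, weights: Optional[str]
) -> None:
    allowed = VALID_WEIGHTS.get(backbone_type)
    if allowed is None:
        if weights is not None:
            raise ValueError(
                f"{backbone_type.value} does not support pre-trained weights."
            )
    elif weights is not None and weights not in allowed:
        raise ValueError(
            f"Invalid pre-trained weights for {backbone_type.value}; "
            f"must be one of {allowed}"
        )


validate_pre_trained_weights(BackboneType.SWINT, "Swin_T_Weights")  # ok
```

Adding a new backbone then only requires a new table entry, not another elif branch.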
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/data_config.py (2)
21-21: Address TODO comment in docstring. The TODO comment about changing chunk_size in inference checkpoints should be tracked properly.
Would you like me to create a GitHub issue to track this TODO item?
65-75: Enhance class documentation. The documentation could be more detailed by describing what types of augmentations are supported and how they interact.
Consider expanding the docstring to include:
- Purpose of intensity vs. geometric augmentations
- How these configurations are used in the training pipeline
- Examples of common configurations
🧰 Tools
🪛 Ruff
74-74: Undefined name IntensityConfig (F821)
75-75: Undefined name GeometricConfig (F821)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py
3-3: typing.List imported but unused. Remove unused import (F401)
3-3: typing.Dict imported but unused. Remove unused import (F401)
34-34: Undefined name PreprocessingConfig (F821)
34-34: Undefined name PreprocessingConfig (F821)
36-36: Undefined name AugmentationConfig (F821)
60-60: Undefined name Union (F821)
74-74: Undefined name IntensityConfig (F821)
75-75: Undefined name GeometricConfig (F821)
🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)
39-41: LGTM!
The validation function is well-implemented with a clear error message and is used consistently throughout the configuration classes.
44-62: LGTM!
The PreprocessingConfig class is well-structured with:
- Clear documentation for each attribute
- Appropriate default values
- Type hints for optional fields
🧰 Tools
🪛 Ruff
60-60: Undefined name Union (F821)
78-110: LGTM!
The IntensityConfig class is well-implemented with:
- Comprehensive parameter validation
- Appropriate default values
- Clear documentation
import attrs
from omegaconf import MISSING
from typing import Optional, Tuple, List, Dict


"""Serializable configuration classes for specifying all data configuration parameters.

These configuration classes are intended to specify all
the parameters required to initialize the data config.
"""
Optimize imports
The imports need the following adjustments:
- Remove unused imports: List and Dict
- Add missing import: Union from typing
import attrs
from omegaconf import MISSING
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Union
🧰 Tools
🪛 Ruff
3-3: typing.List imported but unused. Remove unused import (F401)
3-3: typing.Dict imported but unused. Remove unused import (F401)
preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
use_augmentations_train: bool = False
augmentation_config: Optional[AugmentationConfig] = None
Use forward references for configuration classes
To avoid circular imports and undefined names, use string literals for type hints of configuration classes.
Apply this diff:
- preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
- augmentation_config: Optional[AugmentationConfig] = None
+ preprocessing: 'PreprocessingConfig' = attrs.field(
+ factory=lambda: PreprocessingConfig()
+ )
+ augmentation_config: Optional['AugmentationConfig'] = None
🧰 Tools
🪛 Ruff
34-34: Undefined name PreprocessingConfig (F821)
34-34: Undefined name PreprocessingConfig (F821)
36-36: Undefined name AugmentationConfig (F821)
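String-literal forward references are enough on their own to silence these F821 errors: the annotation is never evaluated at class-creation time. A stdlib sketch showing that the referencing class can even be defined before the classes it names (the lambda factory defers the runtime lookup as well; the field set here is a simplified stand-in for the PR's configs):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataConfig:
    # String annotations are stored as-is, not looked up when the class
    # is created, so definition order / circular imports stop mattering.
    preprocessing: "PreprocessingConfig" = field(
        default_factory=lambda: PreprocessingConfig()
    )
    use_augmentations_train: bool = False
    augmentation_config: Optional["AugmentationConfig"] = None


@dataclass
class PreprocessingConfig:
    max_height: Optional[int] = None


@dataclass
class AugmentationConfig:
    use_intensity: bool = False


cfg = DataConfig()  # by instantiation time, both classes exist
```

Note the asymmetry: the string annotation only defers the *type hint*; the `default_factory` is a runtime reference, which is why it is wrapped in a lambda here.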
@attrs.define
class GeometricConfig:
    """
    Configuration of Geometric (Optional)

    Attributes:
        rotation: (float) Angles in degrees as a scalar float of the amount of rotation. A random angle in (-rotation, rotation) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation.
        scale: (float) scaling factor interval. If (a, b) represents isotropic scaling, the scale is randomly sampled from the range a <= scale <= b. If (a, b, c, d), the scale is randomly sampled from the range a <= scale_x <= b, c <= scale_y <= d Default: None.
        translate_width: (float) Maximum absolute fraction for horizontal translation. For example, if translate_width=a, then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a. Will not translate by default.
        translate_height: (float) Maximum absolute fraction for vertical translation. For example, if translate_height=a, then vertical shift is randomly sampled in the range -img_height * a < dy < img_height * a. Will not translate by default.
        affine_p: (float) Probability of applying random affine transformations. Default=0.0
        erase_scale_min: (float) Minimum value of range of proportion of erased area against input image. Default: 0.0001.
        erase_scale_max: (float) Maximum value of range of proportion of erased area against input image. Default: 0.01.
        erase_ration_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
        erase_ratio_max: (float) Maximum value of range of aspect ratio of erased area. Default: 1.
        erase_p: (float) Probability of applying random erase. Default=0.0
        mixup_lambda: (float) min-max value of mixup strength. Default is 0-1. Default: None.
        mixup_p: (float) Probability of applying random mixup v2. Default=0.0
        input_key: (str) Can be image or instance. The input_key instance expects the KorniaAugmenter to follow the InstanceCropper else image otherwise for default.
        random_crop_p: (float) Probability of applying random crop.
        random_crop_height: (int) Desired output height of the random crop.
        random_crop_width: (int) Desired output height of the random crop.
    """

    rotation: float = 0.0
    scale: Optional[Tuple[float, float, float, float]] = None
    translate_width: float = 0.0
    translate_height: float = 0.0
    affine_p: float = attrs.field(default=0.0, validator=validate_proportion)
    erase_scale_min: float = 0.0001
    erase_scale_max: float = 0.01
    erase_ratio_min: float = 1.0
    erase_ratio_max: float = 1.0
    erase_p: float = attrs.field(default=0.0, validator=validate_proportion)
    mixup_lambda: Optional[float] = None
    mixup_p: float = attrs.field(default=0.0, validator=validate_proportion)
    input_key: str = "image"
    random_crop_p: Optional[float] = None
    random_crop_height: Optional[int] = None
    random_crop_width: Optional[int] = None
Fix documentation and attribute name issues
There are two issues in the GeometricConfig class:
- Typo in attribute documentation: "erase_ration_min" should be "erase_ratio_min"
- Documentation error: random_crop_width's description says "output height" instead of "output width"
Apply this diff to fix the documentation:
- erase_ration_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
+ erase_ratio_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
- random_crop_width: (int) Desired output height of the random crop.
+ random_crop_width: (int) Desired output width of the random crop.
Actionable comments posted: 6
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
sleap_nn/config/trainer_config.py
(1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py
56-56: Undefined name DataLoaderConfig
(F821)
56-56: Undefined name DataLoaderConfig
(F821)
57-57: Undefined name DataLoaderConfig
(F821)
57-57: Undefined name DataLoaderConfig
(F821)
58-58: Undefined name ModelCkptConfig
(F821)
58-58: Undefined name ModelCkptConfig
(F821)
59-59: Undefined name Union
(F821)
59-59: Undefined name List
(F821)
62-62: Undefined name Optional
(F821)
64-64: Undefined name Optional
(F821)
68-68: Undefined name Optional
(F821)
69-69: Undefined name Optional
(F821)
69-69: Undefined name WandBConfig
(F821)
70-70: Undefined name Optional
(F821)
70-70: Undefined name OptimizerConfig
(F821)
70-70: Undefined name OptimizerConfig
(F821)
71-71: Undefined name LRSchedulerConfig
(F821)
71-71: Undefined name LRSchedulerConfig
(F821)
72-72: Undefined name EarlyStoppingConfig
(F821)
72-72: Undefined name attr
(F821)
72-72: Undefined name EarlyStoppingConfig
(F821)
85-85: Undefined name Dict
(F821)
85-85: Undefined name Text
(F821)
85-85: Undefined name Any
(F821)
85-85: Undefined name TrainerJobConfig
(F821)
96-96: Undefined name Omega
(F821)
99-99: Undefined name Text
(F821)
99-99: Undefined name TrainingJobConfig
(F821)
108-108: Undefined name json
(F821)
112-112: Undefined name Text
(F821)
112-112: Undefined name TrainingJobConfig
(F821)
125-125: Undefined name Text
(F821)
125-125: Undefined name TrainingJobConfig
(F821)
140-140: Undefined name Text
(F821)
140-140: Undefined name TrainingJobConfig
(F821)
153-153: Undefined name DictConfig
(F821)
168-168: Undefined name json
(F821)
179-179: Undefined name Text
(F821)
188-188: Undefined name Text
(F821)
222-222: Undefined name Optional
(F821)
239-239: Undefined name Optional
(F821)
240-240: Undefined name Optional
(F821)
241-241: Undefined name Optional
(F821)
242-242: Undefined name Optional
(F821)
243-243: Undefined name Optional
(F821)
244-244: Undefined name Optional
(F821)
245-245: Undefined name Optional
(F821)
245-245: Undefined name List
(F821)
281-281: Undefined name Union
(F821)
281-281: Undefined name List
(F821)
🔇 Additional comments (2)
sleap_nn/config/trainer_config.py (2)
1-2: Add missing imports for typing and json modules.
The code uses types from the typing module (e.g., Union, List, Optional, Dict, Any, Text) and the json module, but these are not imported. Add the necessary imports at the top of the file to resolve the undefined names.
Apply this diff to add the missing imports:
import attrs
+from typing import Any, Dict, List, Optional, Text, Union
+import json
from omegaconf import OmegaConf
Likely invalid or redundant comment.
85-112: Correct return type annotations to match the class name.
The methods from_dict, from_json, and from_yaml have return type annotations of "TrainingJobConfig", but the class is named TrainerConfig. Update the return type annotations to "TrainerConfig" for consistency.
Apply this diff to correct the return type annotations:
 @classmethod
- def from_dict(cls, config_dict: Dict[Text, Any]) -> "TrainingJobConfig":
+ def from_dict(cls, config_dict: Dict[Text, Any]) -> "TrainerConfig":
 @classmethod
- def from_json(cls, json_data: Text) -> "TrainingJobConfig":
+ def from_json(cls, json_data: Text) -> "TrainerConfig":
 @classmethod
- def from_yaml(cls, yaml_data: Text) -> "TrainingJobConfig":
+ def from_yaml(cls, yaml_data: Text) -> "TrainerConfig":
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff
85-85: Undefined name Dict
(F821)
85-85: Undefined name Text
(F821)
85-85: Undefined name Any
(F821)
85-85: Undefined name TrainerJobConfig
(F821)
96-96: Undefined name Omega
(F821)
99-99: Undefined name Text
(F821)
99-99: Undefined name TrainingJobConfig
(F821)
108-108: Undefined name json
(F821)
112-112: Undefined name Text
(F821)
112-112: Undefined name TrainingJobConfig
(F821)
sleap_nn/config/trainer_config.py
Outdated
""" | ||
# Convert dictionary to an OmegaConf config, then instantiate from it. | ||
config = OmegaConf.create(config_dict) | ||
return Omega.to_object(config, cls) |
Replace Omega with OmegaConf to fix undefined name.
In the from_dict method, Omega.to_object is used, but Omega is undefined. It should be OmegaConf.to_object to correctly utilize the omegaconf module.
Apply this diff to correct the method call:
- return Omega.to_object(config, cls)
+ return OmegaConf.to_object(config, cls)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
return Omega.to_object(config, cls)
return OmegaConf.to_object(config, cls)
🧰 Tools
🪛 Ruff
96-96: Undefined name Omega
(F821)
sleap_nn/config/trainer_config.py
Outdated
    wandb: Optional[WandBConfig] = attrs.field(init=False)
    optimizer: Optional[OptimizerConfig] = attrs.field(factory=OptimizerConfig)
    lr_scheduler: LRSchedulerConfig = attrs.field(factory=LRSchedulerConfig)
    early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
Correct typo in attrs.field usage.
The attribute early_stopping uses attr.field instead of attrs.field. Replace attr.field with attrs.field to fix the typo and ensure consistency.
Apply this diff to fix the typo:
- early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
+ early_stopping: EarlyStoppingConfig = attrs.field(factory=EarlyStoppingConfig)
early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
early_stopping: EarlyStoppingConfig = attrs.field(factory=EarlyStoppingConfig)
🧰 Tools
🪛 Ruff
72-72: Undefined name EarlyStoppingConfig
(F821)
72-72: Undefined name attr
(F821)
72-72: Undefined name EarlyStoppingConfig
(F821)
sleap_nn/config/trainer_config.py
Outdated
    trainer_devices: Union[int, List[int], str] = "auto"
    trainer_accelerator: str = "auto"
🛠️ Refactor suggestion
Use attrs.field for attribute definitions for consistency.
In TrainerConfig, the attributes trainer_devices and trainer_accelerator are assigned default values directly instead of using attrs.field. For consistency and to leverage attrs features, use attrs.field to define these attributes.
Apply this diff to use attrs.field:
- trainer_devices: Union[int, List[int], str] = "auto"
- trainer_accelerator: str = "auto"
+ trainer_devices: Union[int, List[int], str] = attrs.field(default="auto")
+ trainer_accelerator: str = attrs.field(default="auto")
trainer_devices: Union[int, List[int], str] = "auto"
trainer_accelerator: str = "auto"
trainer_devices: Union[int, List[int], str] = attrs.field(default="auto")
trainer_accelerator: str = attrs.field(default="auto")
🧰 Tools
🪛 Ruff
59-59: Undefined name Union
(F821)
59-59: Undefined name List
(F821)
sleap_nn/config/trainer_config.py
Outdated
        config = OmegaConf.load(filename)
        return OmegaConf.to_object(config, cls)

    def to_dict(self) -> DictConfig:
Import DictConfig from the omegaconf module.
The method to_dict returns DictConfig, but DictConfig is not imported. Add the import at the top of the file to resolve the undefined name.
Apply this diff to import DictConfig:
 from omegaconf import OmegaConf
+from omegaconf import DictConfig
Or combine the imports:
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, DictConfig
def to_dict(self) -> DictConfig:

from omegaconf import OmegaConf, DictConfig
def to_dict(self) -> DictConfig:
🧰 Tools
🪛 Ruff
153-153: Undefined name DictConfig
(F821)
sleap_nn/config/trainer_config.py
Outdated
    def load_json(cls, filename: Text) -> "TrainingJobConfig":
        """Load a training job configuration from a json file.

        Arguments:
            filename: Path to a training job configuration JSON file or a directory
                containing `"training_job.json"`.

        Returns:
            A TrainingJobConfig instance parsed from the json file.
        """
        with open(filename, "r") as f:
            json_data = f.read()
        return cls.from_json(json_data)

    @classmethod
    def load_yaml(cls, filename:Text) -> "TrainingJobConfig":
Update return type annotations in loader methods.
The methods load_json and load_yaml also have incorrect return type annotations referencing "TrainingJobConfig". Update them to "TrainerConfig" to match the class name.
Apply this diff to correct the return type annotations:
@classmethod
- def load_json(cls, filename: Text) -> "TrainingJobConfig":
+ def load_json(cls, filename: Text) -> "TrainerConfig":
@classmethod
- def load_yaml(cls, filename: Text) -> "TrainingJobConfig":
+ def load_yaml(cls, filename: Text) -> "TrainerConfig":
    def load_json(cls, filename: Text) -> "TrainingJobConfig":
        """Load a training job configuration from a json file.
        Arguments:
            filename: Path to a training job configuration JSON file or a directory
                containing `"training_job.json"`.
        Returns:
            A TrainingJobConfig instance parsed from the json file.
        """
        with open(filename, "r") as f:
            json_data = f.read()
        return cls.from_json(json_data)
    @classmethod
    def load_yaml(cls, filename:Text) -> "TrainingJobConfig":

    def load_json(cls, filename: Text) -> "TrainerConfig":
        """Load a training job configuration from a json file.
        Arguments:
            filename: Path to a training job configuration JSON file or a directory
                containing `"training_job.json"`.
        Returns:
            A TrainingJobConfig instance parsed from the json file.
        """
        with open(filename, "r") as f:
            json_data = f.read()
        return cls.from_json(json_data)
    @classmethod
    def load_yaml(cls, filename:Text) -> "TrainerConfig":
🧰 Tools
🪛 Ruff
125-125: Undefined name Text
(F821)
125-125: Undefined name TrainingJobConfig
(F821)
140-140: Undefined name Text
(F821)
140-140: Undefined name TrainingJobConfig
(F821)
Actionable comments posted: 4
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
sleap_nn/config/trainer_config.py
(1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py
56-56: Undefined name DataLoaderConfig
(F821)
56-56: Undefined name DataLoaderConfig
(F821)
57-57: Undefined name DataLoaderConfig
(F821)
57-57: Undefined name DataLoaderConfig
(F821)
58-58: Undefined name ModelCkptConfig
(F821)
58-58: Undefined name ModelCkptConfig
(F821)
59-59: Undefined name Union
(F821)
59-59: Undefined name List
(F821)
62-62: Undefined name Optional
(F821)
64-64: Undefined name Optional
(F821)
68-68: Undefined name Optional
(F821)
69-69: Undefined name Optional
(F821)
69-69: Undefined name WandBConfig
(F821)
70-70: Undefined name Optional
(F821)
70-70: Undefined name OptimizerConfig
(F821)
70-70: Undefined name OptimizerConfig
(F821)
71-71: Undefined name LRSchedulerConfig
(F821)
71-71: Undefined name LRSchedulerConfig
(F821)
72-72: Undefined name EarlyStoppingConfig
(F821)
72-72: Undefined name attr
(F821)
72-72: Undefined name EarlyStoppingConfig
(F821)
85-85: Undefined name Dict
(F821)
85-85: Undefined name Text
(F821)
85-85: Undefined name Any
(F821)
96-96: Undefined name Omega
(F821)
99-99: Undefined name Text
(F821)
108-108: Undefined name json
(F821)
112-112: Undefined name Text
(F821)
125-125: Undefined name Text
(F821)
140-140: Undefined name Text
(F821)
153-153: Undefined name DictConfig
(F821)
168-168: Undefined name json
(F821)
179-179: Undefined name Text
(F821)
188-188: Undefined name Text
(F821)
223-223: Undefined name Optional
(F821)
240-240: Undefined name Optional
(F821)
241-241: Undefined name Optional
(F821)
242-242: Undefined name Optional
(F821)
243-243: Undefined name Optional
(F821)
244-244: Undefined name Optional
(F821)
245-245: Undefined name Optional
(F821)
246-246: Undefined name Optional
(F821)
246-246: Undefined name List
(F821)
282-282: Undefined name Union
(F821)
282-282: Undefined name List
(F821)
🔇 Additional comments (6)
sleap_nn/config/trainer_config.py (6)
96-96: Fix incorrect OmegaConf usage.
The Omega.to_object call is incorrect. It should be OmegaConf.to_object.
- return Omega.to_object(config, cls)
+ return OmegaConf.to_object(config, cls)
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff
96-96: Undefined name Omega
(F821)
295-297: 🛠️ Refactor suggestion
Use attrs.field in EarlyStoppingConfig for consistency.
- stop_training_on_plateau: bool = False
- min_delta: float = 0.0
- patience: int = 1
+ stop_training_on_plateau: bool = attrs.field(default=False)
+ min_delta: float = attrs.field(default=0.0)
+ patience: int = attrs.field(default=1)
Likely invalid or redundant comment.
258-259: 🛠️ Refactor suggestion
Use attrs.field in OptimizerConfig for consistency.
- lr: float = 1e-3
- amsgrad: bool = False
+ lr: float = attrs.field(default=1e-3)
+ amsgrad: bool = attrs.field(default=False)
Likely invalid or redundant comment.
208-210: 🛠️ Refactor suggestion
Use attrs.field in DataLoaderConfig for consistency.
For consistency with TrainerConfig and to leverage attrs features, use attrs.field for attribute definitions.
- batch_size: int = 1
- shuffle: bool = False
- num_workers: int = 0
+ batch_size: int = attrs.field(default=1)
+ shuffle: bool = attrs.field(default=False)
+ num_workers: int = attrs.field(default=0)
Likely invalid or redundant comment.
222-223: 🛠️ Refactor suggestion
Use attrs.field in ModelCkptConfig for consistency.
- save_top_k: int = 1
- save_last: Optional[bool] = None
+ save_top_k: int = attrs.field(default=1)
+ save_last: Optional[bool] = attrs.field(default=None)
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff
223-223: Undefined name Optional
(F821)
72-72: Fix incorrect attrs import usage.
The attr.field usage is incorrect and inconsistent with other field definitions in the class.
- early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
+ early_stopping: EarlyStoppingConfig = attrs.field(factory=EarlyStoppingConfig)
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff
72-72: Undefined name EarlyStoppingConfig
(F821)
72-72: Undefined name attr
(F821)
72-72: Undefined name EarlyStoppingConfig
(F821)
sleap_nn/config/trainer_config.py
Outdated
        return OmegaConf.to_object(config, cls)

    @classmethod
    def load_json(cls, filename: Text) -> "TrainerConfig":
Fix incorrect return type hints.
The return type hints in load_json
and load_yaml
methods incorrectly reference "TrainingJobConfig" instead of "TrainerConfig".
- def load_json(cls, filename: Text) -> "TrainingJobConfig":
+ def load_json(cls, filename: Text) -> "TrainerConfig":
- def load_yaml(cls, filename: Text) -> "TrainingJobConfig":
+ def load_yaml(cls, filename: Text) -> "TrainerConfig":
Also applies to: 140-140
🧰 Tools
🪛 Ruff
125-125: Undefined name Text
(F821)
sleap_nn/config/trainer_config.py
Outdated
        optimizer_name: (str) Optimizer to be used. One of ["Adam", "AdamW"].
        optimizer:
        lr_scheduler:
        early_stopping:
    """
Add missing optimizer_name attribute.
The optimizer_name
attribute is documented but not implemented in the class.
optimizer: Optional[OptimizerConfig] = attrs.field(factory=OptimizerConfig)
+ optimizer_name: str = attrs.field(default="Adam")
lr_scheduler: LRSchedulerConfig = attrs.field(factory=LRSchedulerConfig)
Committable suggestion skipped: line range outside the PR's diff.
import attrs
from omegaconf import OmegaConf
Add missing imports.
Several required imports are missing. Add these imports at the top of the file:
import attrs
from omegaconf import OmegaConf
+from omegaconf import DictConfig
+from typing import Any, Dict, List, Optional, Text, Union
+import json
import attrs
from omegaconf import OmegaConf

import attrs
from omegaconf import OmegaConf
from omegaconf import DictConfig
from typing import Any, Dict, List, Optional, Text, Union
import json
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
sleap_nn/config/trainer_config.py
(1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/trainer_config.py
56-56: Undefined name DataLoaderConfig
(F821)
56-56: Undefined name DataLoaderConfig
(F821)
57-57: Undefined name DataLoaderConfig
(F821)
57-57: Undefined name DataLoaderConfig
(F821)
58-58: Undefined name ModelCkptConfig
(F821)
58-58: Undefined name ModelCkptConfig
(F821)
59-59: Undefined name Union
(F821)
59-59: Undefined name List
(F821)
62-62: Undefined name Optional
(F821)
64-64: Undefined name Optional
(F821)
68-68: Undefined name Optional
(F821)
69-69: Undefined name Optional
(F821)
69-69: Undefined name WandBConfig
(F821)
70-70: Undefined name Optional
(F821)
70-70: Undefined name OptimizerConfig
(F821)
70-70: Undefined name OptimizerConfig
(F821)
71-71: Undefined name LRSchedulerConfig
(F821)
71-71: Undefined name LRSchedulerConfig
(F821)
72-72: Undefined name EarlyStoppingConfig
(F821)
72-72: Undefined name attr
(F821)
72-72: Undefined name EarlyStoppingConfig
(F821)
85-85: Undefined name Text
(F821)
98-98: Undefined name Text
(F821)
120-120: Undefined name Text
(F821)
155-155: Undefined name Optional
(F821)
172-172: Undefined name Optional
(F821)
173-173: Undefined name Optional
(F821)
174-174: Undefined name Optional
(F821)
175-175: Undefined name Optional
(F821)
176-176: Undefined name Optional
(F821)
177-177: Undefined name Optional
(F821)
178-178: Undefined name Optional
(F821)
178-178: Undefined name List
(F821)
214-214: Undefined name Union
(F821)
214-214: Undefined name List
(F821)
🔇 Additional comments (7)
sleap_nn/config/trainer_config.py (7)
154-155: 🛠️ Refactor suggestion
Use attrs.field in ModelCkptConfig for consistency.
For consistency with other classes and to leverage attrs features, use attrs.field for attribute definitions.
- save_top_k: int = 1
- save_last: Optional[bool] = None
+ save_top_k: int = attrs.field(default=1)
+ save_last: Optional[bool] = attrs.field(default=None)
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff
155-155: Undefined name Optional
(F821)
1-2: Add missing imports.
Several required imports are missing. Add these imports at the top of the file:
import attrs
from omegaconf import OmegaConf
+from omegaconf import DictConfig
+from typing import Any, Dict, List, Optional, Text, Union
+import json
Likely invalid or redundant comment.
72-72: Fix incorrect attrs import usage.
The early_stopping attribute uses attr.field instead of attrs.field. This will cause a runtime error.
- early_stopping: EarlyStoppingConfig = attr.field(factory=EarlyStoppingConfig)
+ early_stopping: EarlyStoppingConfig = attrs.field(factory=EarlyStoppingConfig)
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff
72-72: Undefined name EarlyStoppingConfig
(F821)
72-72: Undefined name attr
(F821)
72-72: Undefined name EarlyStoppingConfig
(F821)
227-229: 🛠️ Refactor suggestion
Use attrs.field in EarlyStoppingConfig for consistency.
For consistency with other classes and to leverage attrs features, use attrs.field for attribute definitions.
- stop_training_on_plateau: bool = False
- min_delta: float = 0.0
- patience: int = 1
+ stop_training_on_plateau: bool = attrs.field(default=False)
+ min_delta: float = attrs.field(default=0.0)
+ patience: int = attrs.field(default=1)
Likely invalid or redundant comment.
50-54: Implement documented optimizer_name attribute.
The optimizer_name attribute is documented in the class docstring but not implemented in the class definition.
+ optimizer_name: str = attrs.field(default="Adam")
optimizer: Optional[OptimizerConfig] = attrs.field(factory=OptimizerConfig)
Likely invalid or redundant comment.
140-142: 🛠️ Refactor suggestion
Use attrs.field in DataLoaderConfig for consistency.
For consistency with other classes and to leverage attrs features, use attrs.field for attribute definitions.
- batch_size: int = 1
- shuffle: bool = False
- num_workers: int = 0
+ batch_size: int = attrs.field(default=1)
+ shuffle: bool = attrs.field(default=False)
+ num_workers: int = attrs.field(default=0)
Likely invalid or redundant comment.
59-60: 🛠️ Refactor suggestion
Use attrs.field for consistency.
For consistency with other attributes and to leverage attrs features, use attrs.field for these attributes.
- trainer_devices: Union[int, List[int], str] = "auto"
- trainer_accelerator: str = "auto"
+ trainer_devices: Union[int, List[int], str] = attrs.field(default="auto")
+ trainer_accelerator: str = attrs.field(default="auto")
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff
59-59: Undefined name Union
(F821)
59-59: Undefined name List
(F821)
sleap_nn/config/trainer_config.py
Outdated
    mode: str = "min"
    threshold: float = 1e-4
    threshold_mode: str = "rel"
    cooldown: int = 0
    patience: int = 10
    factor: float = 0.1
    min_lr: Union[float, List[float]] = 0.0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Use attrs.field in LRSchedulerConfig for consistency.
For consistency with other classes and to leverage attrs features, use attrs.field
for attribute definitions.
- mode: str = "min"
- threshold: float = 1e-4
- threshold_mode: str = "rel"
- cooldown: int = 0
- patience: int = 10
- factor: float = 0.1
- min_lr: Union[float, List[float]] = 0.0
+ mode: str = attrs.field(default="min")
+ threshold: float = attrs.field(default=1e-4)
+ threshold_mode: str = attrs.field(default="rel")
+ cooldown: int = attrs.field(default=0)
+ patience: int = attrs.field(default=10)
+ factor: float = attrs.field(default=0.1)
+ min_lr: Union[float, List[float]] = attrs.field(default=0.0)
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff
214-214: Undefined name Union
(F821)
214-214: Undefined name List
(F821)
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (5)
tests/config/test_trainer_config.py (1)
13-27: Consider adding edge case tests for DataLoaderConfig. While the basic functionality is well tested, consider adding tests for:
- Negative batch sizes
- Very large batch sizes
- Negative number of workers
Example addition:
def test_dataloader_config_edge_cases():
    with pytest.raises(ValueError):
        OmegaConf.structured(DataLoaderConfig(batch_size=-1))
    with pytest.raises(ValueError):
        OmegaConf.structured(DataLoaderConfig(num_workers=-1))
sleap_nn/config/trainer_config.py (4)
43-45: Use attrs.field() for attribute definitions. For consistency with other classes and to leverage attrs features like validation, use attrs.field().
- batch_size: int = 1
- shuffle: bool = False
- num_workers: int = 0
+ batch_size: int = attrs.field(default=1)
+ shuffle: bool = attrs.field(default=False)
+ num_workers: int = attrs.field(default=0)
117-117: Improve type hint for min_lr attribute. The min_lr attribute is validated to be either float or List[float], but typed as Any. Update the type hint to match the validation.
- min_lr: Any = 0.0
+ min_lr: Union[float, List[float]] = attrs.field(default=0.0)
173-173: Improve type hint for trainer_devices attribute. The trainer_devices attribute has specific validation rules but is typed as Any. Update the type hint to match the validation.
- trainer_devices: Any = "auto"
+ trainer_devices: Union[int, List[int], str] = attrs.field(
+     default="auto",
+     validator=attrs.validators.instance_of((int, list, str))
+ )
253-254: Add error handling for file operations. The file write operation should handle potential IOErrors.
- with open(filename, "w") as f:
-     f.write(self.to_yaml())
+ try:
+     with open(filename, "w") as f:
+         f.write(self.to_yaml())
+ except IOError as e:
+     raise IOError(f"Failed to save configuration to (unknown): {e}")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
sleap_nn/config/trainer_config.py
(1 hunks)tests/config/test_trainer_config.py
(1 hunks)
🧰 Additional context used
🪛 Ruff (0.7.0)
sleap_nn/config/trainer_config.py
3-3: typing.Union
imported but unused
Remove unused import: typing.Union
(F401)
tests/config/test_trainer_config.py
1-1: pytest
imported but unused
Remove unused import: pytest
(F401)
🔇 Additional comments (2)
tests/config/test_trainer_config.py (1)
1-11: LGTM! Ignore the unused import warning.
The pytest import is required for the test framework to function properly, even though it's not explicitly used in the code. The static analysis warning can be safely ignored.
🧰 Tools
🪛 Ruff (0.7.0)
1-1: pytest
imported but unused
Remove unused import: pytest
(F401)
sleap_nn/config/trainer_config.py (1)
1-30: LGTM! Well-documented module with clear design philosophy.
🧰 Tools
🪛 Ruff (0.7.0)
3-3: typing.Union
imported but unused
Remove unused import: typing.Union
(F401)
def test_lr_scheduler_config():
    # Check default values
    conf = OmegaConf.structured(LRSchedulerConfig)
    assert conf.mode == "min"
    assert conf.threshold == 1e-4
    assert conf.patience == 10

    # Test customization
    custom_conf = OmegaConf.structured(LRSchedulerConfig(mode="max", patience=5, factor=0.5))
    assert custom_conf.mode == "max"
    assert custom_conf.patience == 5
    assert custom_conf.factor == 0.5
🛠️ Refactor suggestion
Enhance LRSchedulerConfig test coverage.
The current tests don't validate the mode field constraints. Consider adding tests for:
- Invalid mode values
- Boundary conditions for patience and threshold
- Edge cases for the factor value (0.0, 1.0, >1.0)
Example addition:
def test_lr_scheduler_config_validation():
    with pytest.raises(ValueError):
        OmegaConf.structured(LRSchedulerConfig(mode="invalid"))
    with pytest.raises(ValueError):
        OmegaConf.structured(LRSchedulerConfig(factor=0.0))
def test_trainer_config():
    # Check default values
    conf = OmegaConf.structured(TrainerConfig)
    assert conf.train_data_loader.batch_size == 1
    assert conf.val_data_loader.shuffle is False
    assert conf.model_ckpt.save_top_k == 1
    assert conf.optimizer.lr == 1e-3
    assert conf.lr_scheduler.mode == "min"
    assert conf.early_stopping.patience == 1
    assert conf.use_wandb is False
    assert conf.save_ckpt_path == "./"

    # Test customization
    custom_conf = OmegaConf.structured(
        TrainerConfig(
            max_epochs=20,
            train_data_loader=DataLoaderConfig(batch_size=32),
            optimizer=OptimizerConfig(lr=0.01),
            use_wandb=True,
        )
    )
    assert custom_conf.max_epochs == 20
    assert custom_conf.train_data_loader.batch_size == 32
    assert custom_conf.optimizer.lr == 0.01
    assert custom_conf.use_wandb is True
🛠️ Refactor suggestion
Enhance TrainerConfig integration testing.
While the basic integration is tested, consider adding:
- Tests for config relationships (e.g., early stopping patience vs lr scheduler patience)
- Validation of config serialization/deserialization
- Tests for invalid configuration combinations
Example additions:
def test_trainer_config_relationships():
    # Test that early stopping patience is less than lr scheduler patience
    conf = OmegaConf.structured(
        TrainerConfig(
            early_stopping=EarlyStoppingConfig(patience=10),
            lr_scheduler=LRSchedulerConfig(patience=5)
        )
    )
    with pytest.raises(ValueError):
        conf.validate()  # Assuming validate() method exists

def test_trainer_config_serialization():
    conf = TrainerConfig(max_epochs=20)
    yaml_str = OmegaConf.to_yaml(conf)
    loaded_conf = OmegaConf.create(yaml_str)
    assert OmegaConf.to_container(conf) == OmegaConf.to_container(loaded_conf)
    entity: Optional[str] = None
    project: Optional[str] = None
    name: Optional[str] = None
    api_key: Optional[str] = None
Secure sensitive data in serialization.
The api_key in WandBConfig should be masked when serialized to prevent accidental exposure in configuration files.
Consider implementing custom serialization for the api_key field:
@attrs.define
class WandBConfig:
    # ... other attributes ...
    api_key: Optional[str] = attrs.field(
        default=None,
        repr=lambda value: '****' if value else None,
        metadata={'sensitive': True}
    )
Then update the to_yaml method in TrainerConfig to handle sensitive fields:
def to_yaml(self) -> str:
    config = self.to_dict()
    if self.wandb and self.wandb.api_key:
        config['wandb']['api_key'] = '****'
    return OmegaConf.to_yaml(config)
This is the first PR of #75. Here we want to:
Summary by CodeRabbit
New Features
Bug Fixes
Documentation