-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Pafs shape #121
Fix Pafs shape #121
Conversation
WalkthroughThe pull request introduces significant changes to the tensor shapes returned by several functions in the Changes
Possibly related PRs
Suggested reviewers
Poem
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #121 +/- ##
==========================================
+ Coverage 96.64% 97.61% +0.96%
==========================================
Files 23 39 +16
Lines 1818 3982 +2164
==========================================
+ Hits 1757 3887 +2130
- Misses 61 95 +34 ☔ View full report in Codecov by Sentry. |
6b672c5
to
99ddf86
Compare
2be7004
to
26ecbd5
Compare
26ecbd5
to
24eada7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (4)
sleap_nn/data/edge_maps.py (4)
149-149
: Update docstring to reflect new output tensor shape inmake_pafs
At line 149, the function's docstring describes the output tensor shape as
(n_edges, 2, grid_height, grid_width)
. Ensure that this documentation accurately reflects the function's updated return value for clarity.
195-195
: Correct output shape description inmake_multi_pafs
docstringAt line 195, the docstring specifies the output tensor shape as
(n_edges, 2, grid_height, grid_width)
. Verify that this matches the actual output of the function after the recent changes.
273-274
: Update return shape ingenerate_pafs
documentationIn lines 273 to 274, the return shape is documented as
(n_edges, 2, grid_height, grid_width)
. Ensure this reflects the actual output shape of the function and update any dependent documentation accordingly.
341-342
: Adjust class docstring to reflect new PAF shapeIn the
PartAffinityFieldsGenerator
class docstring at lines 341 to 342, update the tensor shape description to(n_edges, 2, height, width)
to match the updated output format and aid in developer understanding.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (8)
sleap_nn/data/edge_maps.py
(8 hunks)sleap_nn/training/model_trainer.py
(2 hunks)tests/data/test_custom_datasets.py
(3 hunks)tests/data/test_edge_maps.py
(4 hunks)tests/data/test_pipelines.py
(3 hunks)tests/data/test_streaming_datasets.py
(2 hunks)tests/inference/test_predictors.py
(1 hunks)tests/training/test_model_trainer.py
(2 hunks)
🔇 Additional comments (17)
tests/data/test_edge_maps.py (4)
84-105
: Verify correctness of updated test assertions
In the test_make_pafs
function, the expected output values in the assertions from lines 84 to 105 have been updated. Ensure that these values accurately reflect the intended outputs given the changes in tensor shapes and data computations.
143-164
: Confirm updated assertions for multi-instance PAFs
In the test_make_multi_pafs
function, the assertions from lines 143 to 164 have been modified to match the new output tensor shapes. Verify that these changes are correct and consistent with the updated function behavior.
207-207
: Ensure expected PAF shape in test_generate_pafs
At line 207, the assertion checks that pafs.shape == (1, 2, 192, 192)
. Confirm that this expected shape aligns with the changes made to the generate_pafs
function and the overall tensor shape conventions.
219-219
: Validate output shape in test_part_affinity_fields_generator
At line 219, the assertion verifies that part_affinity_fields
has the shape (1, 2, 192, 192)
. Ensure that this shape is consistent with the new data structures and that the generator produces outputs matching this expected shape.
tests/data/test_streaming_datasets.py (2)
62-62
: Update assertion to match new PAF tensor shape
At line 62, the assertion expects samples[0]["part_affinity_fields"].shape
to be (2, 50, 50)
. Verify that this shape is correct based on the recent changes to the PAF tensor dimensions in the dataset.
106-106
: Adjust assertion for PAF shape after augmentation
At line 106, the assertion checks for samples[0]["part_affinity_fields"].shape == (2, 75, 75)
. Confirm that this shape accurately reflects the output after applying random crop augmentation and the updated tensor structure.
sleap_nn/data/edge_maps.py (5)
203-203
: Initialize PAF tensor with updated dimensions
At line 203, the PAF tensor is initialized with shape (n_edges, 2, grid_height, grid_width)
. This adjustment ensures consistency with the new tensor shape used throughout the function.
318-318
: Confirm assertion for PAF tensor shape
At line 318, the assertion checks that pafs.shape == (n_edges, 2, grid_height, grid_width)
. Verify that this assertion is correct given the updated tensor shapes to prevent potential runtime errors.
321-322
: Validate reshaping logic when flattening channels
In lines 321 to 322, when flatten_channels
is True, the PAF tensor is reshaped to (n_edges * 2, grid_height, grid_width)
. Ensure that this reshaping correctly handles the updated tensor dimensions and maintains data integrity.
437-437
: Assert correct PAF tensor shape after adjustments
At line 437, the assertion verifies pafs.shape == (n_edges, 2, grid_height, grid_width)
. Confirm that this assertion is accurate based on the recent changes to the tensor dimensions.
440-441
: Ensure proper reshaping of PAFs when flattening
In lines 440 to 441, the PAF tensor is reshaped when flatten_channels
is True. Verify that the reshaping logic accurately reflects the new tensor structure to prevent data corruption.
tests/data/test_custom_datasets.py (1)
55-55
: LGTM: PAF shape assertions correctly updated to channel-first format
The test assertions for part_affinity_fields
shape have been consistently updated across all test cases to use the channel-first format (n_edges*2, height, width)
, aligning with PyTorch's conventions and the PR objectives.
Also applies to: 97-97, 172-172
tests/data/test_pipelines.py (1)
589-589
: LGTM: PAF shape assertions correctly updated in pipeline tests
The test assertions for part_affinity_fields
shape in the BottomUpPipeline tests have been consistently updated to use the channel-first format (n_edges*2, height, width)
, maintaining consistency with the PR objectives.
Also applies to: 633-633, 712-712
tests/training/test_model_trainer.py (2)
688-701
: Good addition: Validation step testing for TopDownCenteredInstanceModel
Added comprehensive validation step testing to ensure model output shapes and loss calculations are correct during validation. This improves test coverage.
779-791
: Good addition: Validation step testing for CentroidModel
Added comprehensive validation step testing to ensure model output shapes and loss calculations are correct during validation. This improves test coverage.
sleap_nn/training/model_trainer.py (2)
Line range hint 1172-1178
: LGTM: Consistent PAFs shape handling in validation step.
The validation step mirrors the training step changes, maintaining consistency in PAFs tensor shape handling throughout the model pipeline.
Line range hint 1153-1159
: LGTM: PAFs shape change in training step.
The removal of tensor permutation aligns with the PR objective to use channel-first ordering (n_edges*2, height, width) for PAFs. The loss calculation now directly uses the model output without shape transformation.
Let's verify the tensor shapes in the data pipeline:
✅ Verification successful
The data pipeline documentation confirms channel-first ordering (n_edges*2, grid_height, grid_width) when flatten_channels is True. Let's verify the model's output shape matches this format to ensure compatibility with the loss calculation.
The previous searches show PAFs are used with shape indexing in inference, but we need to verify the model's output shape consistency with the data pipeline's documented shape.
LGTM: PAFs shape change is consistent with data pipeline
Based on the findings:
- Data pipeline documentation confirms PAFs shape as (n_edges*2, grid_height, grid_width) when flatten_channels is True
- The removal of permutation operations in the training step aligns with this channel-first format
- The loss calculation directly uses the model output, indicating shape compatibility
The changes in the training step are correct and maintain consistency with the documented tensor shapes in the data pipeline. The direct usage of PAFs output in loss calculation without permutation is appropriate as both the model output and target tensors follow the same channel-first ordering.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check PAFs tensor shape in data pipeline
# Expect: Consistent channel-first ordering
# Search for PAFs shape definition in data pipeline
rg -A 5 "part_affinity_fields.*shape" sleap_nn/data/
Length of output: 1853
Script:
#!/bin/bash
# Let's verify the model's output shape and loss calculation
# Search for PAFs shape in model definition and forward pass
rg -A 5 "PartAffinityFieldsHead.*forward" sleap_nn/
# Also check if there are any shape transformations in the model
rg -B 2 -A 2 "permute.*paf" sleap_nn/
Length of output: 93
Script:
#!/bin/bash
# Let's try a different approach to find the PAFs head implementation
ast-grep --pattern 'class PartAffinityFieldsHead {
$$$
forward($_) {
$$$
}
$$$
}'
# Also search for any PAF-related shape operations
rg -B 2 -A 2 "\.shape.*paf" sleap_nn/
Length of output: 454
Script:
#!/bin/bash
# Let's try one final search to find the model implementation
rg -l "PartAffinityFieldsHead" sleap_nn/
# And look for any tensor operations in the model trainer
rg -A 5 "class BottomUpModel" sleap_nn/training/model_trainer.py
Length of output: 458
This PR modifies the shape of pafs generated in the bottom-up pipeline. Currently, the shape of pafs generated in the data pipeline is
(height, width, n_edges*2)
. Since, torch supports channel-first ordering, the pafs output of the torch model is of the shape(n_edges*2, height, width)
. In this PR, we fix the shape of pafs generated in the data pipeline by generating pafs with shape(n_edges*2, height, width)
to be consistent with the output of torch model.Summary by CodeRabbit
New Features
Predictor
class with new scenarios and error handling.TopDownCenteredInstanceModel
andCentroidModel
in the training tests.Bug Fixes
Tests