
Fix Pafs shape #121

Merged — 5 commits merged into main on Dec 10, 2024
Conversation

@gitttt-1234 (Contributor) commented Nov 22, 2024

This PR modifies the shape of the PAFs generated in the bottom-up pipeline. Currently, the data pipeline generates PAFs with shape (height, width, n_edges*2), whereas the torch model, which uses channel-first ordering, outputs PAFs with shape (n_edges*2, height, width). This PR changes the data pipeline to generate PAFs with shape (n_edges*2, height, width) so the targets are consistent with the torch model output.
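For illustration only (a minimal sketch, not the sleap_nn code; all names and sizes here are made up): converting a channel-last PAF tensor from the old data-pipeline layout into the channel-first layout the torch model emits.

import torch

# Hypothetical example, not the actual sleap_nn implementation.
height, width, n_edges = 192, 192, 1

# Old data-pipeline layout: channel-last, with edges and x/y components
# flattened into the last axis.
pafs_hwc = torch.zeros(height, width, n_edges * 2)   # (H, W, n_edges*2)

# New layout: channel-first, matching the torch model output.
pafs_chw = pafs_hwc.permute(2, 0, 1)                  # (n_edges*2, H, W)
assert pafs_chw.shape == (n_edges * 2, height, width)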

Summary by CodeRabbit

  • New Features

    • Updated tensor shapes for part affinity fields and edge maps in various functions.
    • Enhanced testing for the Predictor class with new scenarios and error handling.
    • Added validation steps for TopDownCenteredInstanceModel and CentroidModel in the training tests.
  • Bug Fixes

    • Adjusted shape assertions in multiple test cases to reflect new output formats for part affinity fields and confidence maps.
  • Tests

    • Expanded test coverage for model checkpoint loading and validation steps across different models.
    • Updated expected output shapes in several dataset and pipeline tests.

coderabbitai bot (Contributor) commented Nov 22, 2024

Walkthrough

The pull request introduces significant changes to the tensor shapes returned by several functions in the sleap_nn/data/edge_maps.py file, specifically for generating part affinity fields (PAFs) and edge maps. The tensor shapes have been altered from (grid_height, grid_width, n_edges, 2) to (n_edges, 2, grid_height, grid_width). Corresponding modifications have been made in the BottomUpModel class within sleap_nn/training/model_trainer.py, where the permutation of tensor dimensions has been removed. Additionally, tests across multiple files have been updated to reflect these changes in expected output shapes.
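As a rough sketch of the effect on the training and validation steps (assumed shapes, not the actual BottomUpModel code): once the data pipeline emits channel-first PAFs, the loss can be computed against the model output directly, with no permute of the targets.

import torch
import torch.nn as nn

# Illustrative only; shapes are assumptions (batch=4, n_edges=1, 64x64 grid).
mse = nn.MSELoss()
pred_pafs = torch.randn(4, 2, 64, 64)   # model output: (batch, n_edges*2, H, W)
gt_pafs = torch.randn(4, 2, 64, 64)     # pipeline target in the same channel-first layout
loss = mse(pred_pafs, gt_pafs)          # previously the target had to be permuted first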

Changes

File | Change Summary
sleap_nn/data/edge_maps.py | Updated tensor shapes in make_pafs, make_multi_pafs, and generate_pafs functions. Added dimension permutation in make_pafs. Updated docstrings.
sleap_nn/training/model_trainer.py | Removed tensor dimension permutation in the BottomUpModel class's training_step and validation_step methods.
tests/data/test_custom_datasets.py | Updated expected shapes for part_affinity_fields and confidence_maps in various test cases.
tests/data/test_edge_maps.py | Modified expected output structures for make_pafs, make_multi_pafs, and generate_pafs in tests.
tests/data/test_pipelines.py | Adjusted expected shapes for part_affinity_fields in BottomUpPipeline tests. Added a new key in assertions.
tests/data/test_streaming_datasets.py | Updated shape assertions for part_affinity_fields and confidence_maps in streaming dataset tests.
tests/inference/test_predictors.py | Enhanced tests for the Predictor class, focusing on model checkpoint loading and error handling.
tests/training/test_model_trainer.py | Added validation checks for TopDownCenteredInstanceModel and CentroidModel, ensuring output shapes and loss calculations are correct.

Suggested reviewers

  • talmo

Poem

🐇 In the land of code where tensors play,
Shapes are shifting in a grand ballet.
From heights and widths to edges anew,
Our data dances, vibrant and true.
With tests refined and models bright,
We hop along, coding with delight! 🌟


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 847bee8 and 39bc560.

📒 Files selected for processing (1)
  • tests/training/test_model_trainer.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/training/test_model_trainer.py


codecov bot commented Nov 22, 2024

Codecov Report

Attention: Patch coverage is 90.00000% with 1 line in your changes missing coverage. Please review.

Project coverage is 97.61%. Comparing base (f093ce2) to head (39bc560).
Report is 28 commits behind head on main.

Files with missing lines | Patch % | Lines
sleap_nn/training/model_trainer.py | 50.00% | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #121      +/-   ##
==========================================
+ Coverage   96.64%   97.61%   +0.96%     
==========================================
  Files          23       39      +16     
  Lines        1818     3982    +2164     
==========================================
+ Hits         1757     3887    +2130     
- Misses         61       95      +34     


@gitttt-1234 force-pushed the divya/torch-dataset-pipeline branch 2 times, most recently from 6b672c5 to 99ddf86, on December 9, 2024 at 22:22
@gitttt-1234 changed the base branch from divya/torch-dataset-pipeline to main on December 9, 2024 at 22:29
coderabbitai bot (Contributor) left a comment


Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (4)
sleap_nn/data/edge_maps.py (4)

149-149: Update docstring to reflect new output tensor shape in make_pafs

At line 149, the function's docstring describes the output tensor shape as (n_edges, 2, grid_height, grid_width). Ensure that this documentation accurately reflects the function's updated return value for clarity.


195-195: Correct output shape description in make_multi_pafs docstring

At line 195, the docstring specifies the output tensor shape as (n_edges, 2, grid_height, grid_width). Verify that this matches the actual output of the function after the recent changes.


273-274: Update return shape in generate_pafs documentation

In lines 273 to 274, the return shape is documented as (n_edges, 2, grid_height, grid_width). Ensure this reflects the actual output shape of the function and update any dependent documentation accordingly.


341-342: Adjust class docstring to reflect new PAF shape

In the PartAffinityFieldsGenerator class docstring at lines 341 to 342, update the tensor shape description to (n_edges, 2, height, width) to match the updated output format and aid in developer understanding.
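To make the documented convention concrete, here is an illustrative stub (not the real make_pafs signature or implementation) whose return shape matches what the updated docstrings describe:

import torch

def make_pafs_stub(n_edges: int, grid_height: int, grid_width: int) -> torch.Tensor:
    """Illustrative stand-in for the documented return value.

    Returns:
        Part affinity fields as a float32 tensor of shape
        (n_edges, 2, grid_height, grid_width).
    """
    return torch.zeros(n_edges, 2, grid_height, grid_width, dtype=torch.float32)

assert make_pafs_stub(3, 64, 64).shape == (3, 2, 64, 64)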

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between f7d89a3 and 847bee8.

📒 Files selected for processing (8)
  • sleap_nn/data/edge_maps.py (8 hunks)
  • sleap_nn/training/model_trainer.py (2 hunks)
  • tests/data/test_custom_datasets.py (3 hunks)
  • tests/data/test_edge_maps.py (4 hunks)
  • tests/data/test_pipelines.py (3 hunks)
  • tests/data/test_streaming_datasets.py (2 hunks)
  • tests/inference/test_predictors.py (1 hunks)
  • tests/training/test_model_trainer.py (2 hunks)
🔇 Additional comments (17)
tests/data/test_edge_maps.py (4)

84-105: Verify correctness of updated test assertions

In the test_make_pafs function, the expected output values in the assertions from lines 84 to 105 have been updated. Ensure that these values accurately reflect the intended outputs given the changes in tensor shapes and data computations.


143-164: Confirm updated assertions for multi-instance PAFs

In the test_make_multi_pafs function, the assertions from lines 143 to 164 have been modified to match the new output tensor shapes. Verify that these changes are correct and consistent with the updated function behavior.


207-207: Ensure expected PAF shape in test_generate_pafs

At line 207, the assertion checks that pafs.shape == (1, 2, 192, 192). Confirm that this expected shape aligns with the changes made to the generate_pafs function and the overall tensor shape conventions.


219-219: Validate output shape in test_part_affinity_fields_generator

At line 219, the assertion verifies that part_affinity_fields has the shape (1, 2, 192, 192). Ensure that this shape is consistent with the new data structures and that the generator produces outputs matching this expected shape.

tests/data/test_streaming_datasets.py (2)

62-62: Update assertion to match new PAF tensor shape

At line 62, the assertion expects samples[0]["part_affinity_fields"].shape to be (2, 50, 50). Verify that this shape is correct based on the recent changes to the PAF tensor dimensions in the dataset.


106-106: Adjust assertion for PAF shape after augmentation

At line 106, the assertion checks for samples[0]["part_affinity_fields"].shape == (2, 75, 75). Confirm that this shape accurately reflects the output after applying random crop augmentation and the updated tensor structure.

sleap_nn/data/edge_maps.py (5)

203-203: Initialize PAF tensor with updated dimensions

At line 203, the PAF tensor is initialized with shape (n_edges, 2, grid_height, grid_width). This adjustment ensures consistency with the new tensor shape used throughout the function.
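A minimal sketch of the accumulation pattern described here (assumed behavior, not the repository source): the multi-instance PAF tensor starts as zeros in the channel-first layout, and per-instance fields are summed into it.

import torch

n_edges, grid_height, grid_width, n_instances = 1, 64, 64, 3
pafs = torch.zeros(n_edges, 2, grid_height, grid_width)
for _ in range(n_instances):
    # Stand-in for a per-instance make_pafs result in the same layout.
    instance_pafs = torch.rand(n_edges, 2, grid_height, grid_width)
    pafs += instance_pafs
assert pafs.shape == (n_edges, 2, grid_height, grid_width)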


318-318: Confirm assertion for PAF tensor shape

At line 318, the assertion checks that pafs.shape == (n_edges, 2, grid_height, grid_width). Verify that this assertion is correct given the updated tensor shapes to prevent potential runtime errors.


321-322: Validate reshaping logic when flattening channels

In lines 321 to 322, when flatten_channels is True, the PAF tensor is reshaped to (n_edges * 2, grid_height, grid_width). Ensure that this reshaping correctly handles the updated tensor dimensions and maintains data integrity.
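A sketch of the flattening step (assumed, not copied from edge_maps.py): with the edge and x/y axes adjacent at the front of the tensor, collapsing them into a single channel axis is a plain reshape.

import torch

n_edges, grid_height, grid_width = 1, 192, 192
pafs = torch.rand(n_edges, 2, grid_height, grid_width)
pafs_flat = pafs.reshape(n_edges * 2, grid_height, grid_width)   # (n_edges*2, H, W)
assert pafs_flat.shape == (n_edges * 2, grid_height, grid_width)

For a contiguous tensor this reshape is a view rather than a copy, which is one practical benefit of keeping the edge and x/y dimensions leading in the channel-first layout.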


437-437: Assert correct PAF tensor shape after adjustments

At line 437, the assertion verifies pafs.shape == (n_edges, 2, grid_height, grid_width). Confirm that this assertion is accurate based on the recent changes to the tensor dimensions.


440-441: Ensure proper reshaping of PAFs when flattening

In lines 440 to 441, the PAF tensor is reshaped when flatten_channels is True. Verify that the reshaping logic accurately reflects the new tensor structure to prevent data corruption.

tests/data/test_custom_datasets.py (1)

55-55: LGTM: PAF shape assertions correctly updated to channel-first format

The test assertions for part_affinity_fields shape have been consistently updated across all test cases to use the channel-first format (n_edges*2, height, width), aligning with PyTorch's conventions and the PR objectives.

Also applies to: 97-97, 172-172

tests/data/test_pipelines.py (1)

589-589: LGTM: PAF shape assertions correctly updated in pipeline tests

The test assertions for part_affinity_fields shape in the BottomUpPipeline tests have been consistently updated to use the channel-first format (n_edges*2, height, width), maintaining consistency with the PR objectives.

Also applies to: 633-633, 712-712

tests/training/test_model_trainer.py (2)

688-701: Good addition: Validation step testing for TopDownCenteredInstanceModel

Added comprehensive validation step testing to ensure model output shapes and loss calculations are correct during validation. This improves test coverage.


779-791: Good addition: Validation step testing for CentroidModel

Added comprehensive validation step testing to ensure model output shapes and loss calculations are correct during validation. This improves test coverage.
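As a hypothetical sketch of the pattern such a test follows (the model, names, and shapes below are all made up; this is not the sleap_nn test code): run a validation step on a dummy batch and check that the loss is a finite scalar.

import torch

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1)

    def validation_step(self, batch):
        preds = self.net(batch["image"])  # (batch, channels, H, W)
        return torch.nn.functional.mse_loss(preds, batch["confidence_maps"])

model = DummyModel()
batch = {
    "image": torch.rand(2, 1, 64, 64),
    "confidence_maps": torch.rand(2, 2, 64, 64),
}
loss = model.validation_step(batch)
assert loss.ndim == 0 and torch.isfinite(loss)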

sleap_nn/training/model_trainer.py (2)

Line range hint 1172-1178: LGTM: Consistent PAFs shape handling in validation step.

The validation step mirrors the training step changes, maintaining consistency in PAFs tensor shape handling throughout the model pipeline.


Line range hint 1153-1159: LGTM: PAFs shape change in training step.

The removal of tensor permutation aligns with the PR objective to use channel-first ordering (n_edges*2, height, width) for PAFs. The loss calculation now directly uses the model output without shape transformation.

Let's verify the tensor shapes in the data pipeline:

✅ Verification successful

The data pipeline documentation confirms channel-first ordering (n_edges*2, grid_height, grid_width) when flatten_channels is True. Let's verify the model's output shape matches this format to ensure compatibility with the loss calculation.


The previous searches show PAFs are used with shape indexing in inference, but we need to verify the model's output shape consistency with the data pipeline's documented shape.


LGTM: PAFs shape change is consistent with data pipeline

Based on the findings:

  1. Data pipeline documentation confirms PAFs shape as (n_edges*2, grid_height, grid_width) when flatten_channels is True
  2. The removal of permutation operations in the training step aligns with this channel-first format
  3. The loss calculation directly uses the model output, indicating shape compatibility

The changes in the training step are correct and maintain consistency with the documented tensor shapes in the data pipeline. The direct usage of PAFs output in loss calculation without permutation is appropriate as both the model output and target tensors follow the same channel-first ordering.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check PAFs tensor shape in data pipeline
# Expect: Consistent channel-first ordering

# Search for PAFs shape definition in data pipeline
rg -A 5 "part_affinity_fields.*shape" sleap_nn/data/

Length of output: 1853


Script:

#!/bin/bash
# Let's verify the model's output shape and loss calculation
# Search for PAFs shape in model definition and forward pass
rg -A 5 "PartAffinityFieldsHead.*forward" sleap_nn/

# Also check if there are any shape transformations in the model
rg -B 2 -A 2 "permute.*paf" sleap_nn/

Length of output: 93


Script:

#!/bin/bash
# Let's try a different approach to find the PAFs head implementation
ast-grep --pattern 'class PartAffinityFieldsHead {
  $$$
  forward($_) {
    $$$
  }
  $$$
}'

# Also search for any PAF-related shape operations
rg -B 2 -A 2 "\.shape.*paf" sleap_nn/

Length of output: 454


Script:

#!/bin/bash
# Let's try one final search to find the model implementation
rg -l "PartAffinityFieldsHead" sleap_nn/

# And look for any tensor operations in the model trainer
rg -A 5 "class BottomUpModel" sleap_nn/training/model_trainer.py

Length of output: 458

Review comments on tests/inference/test_predictors.py (3 threads, resolved)
@gitttt-1234 merged commit bc75aaf into main on Dec 10, 2024
6 of 7 checks passed
@gitttt-1234 deleted the divya/fix-pafs-shape branch on December 10, 2024 at 05:56