SemanticSegmentationTask: add class-wise metrics #2130
Given that most metrics of interest are broken (e.g., all of them when I'm saying this because these are only the issues I've found so far, but I've also noticed other suspicious things like the fact that my classwise recall values are not the same as those in the confusion matrix when you normalize it with respect to ground truth (I haven't checked if this is also the case with precision, i.e. when the matrix is normalized column-wise). I'm also pretty confident that if all of this is wrong then micro averaging is probably wrong too. It should be pretty easy to compute all these metrics straight from the confusion matrix (assuming it at least is correct), and I've actually tried to reimplement them this way, but it hasn't really been a priority because I've found that all these wrong (?) values are basically a lower bound of the actual ones. If you look at the official implementations, this is actually what they are doing, and my guess is that they have a bug in their logic later on. But indeed all these metrics inherit from I'm actually pretty dumbfounded these issues are not a top priority for the TorchMetrics team and instead they focus on adding to their docs but to each their own…
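For reference, a minimal sketch of what computing these metrics straight from the confusion matrix could look like. It is not the TorchMetrics implementation: the function names `classwise_metrics` and `micro_accuracy` are hypothetical, the matrix is assumed to be a plain nested list with rows as ground truth and columns as predictions, and empty classes are arbitrarily scored as 0.

```python
def classwise_metrics(cm):
    """Per-class precision, recall and IoU from a square confusion matrix.

    Assumption: cm[i][j] counts pixels whose ground-truth class is i and
    whose predicted class is j (rows = ground truth, columns = predictions).
    """
    n = len(cm)
    row_sums = [sum(cm[i]) for i in range(n)]  # ground-truth pixels per class
    col_sums = [sum(cm[i][j] for i in range(n)) for j in range(n)]  # predicted pixels per class

    def safe_div(num, den):
        # Assumption: a class with an empty denominator is scored as 0.0.
        return num / den if den else 0.0

    precision, recall, iou = [], [], []
    for k in range(n):
        tp = cm[k][k]
        fp = col_sums[k] - tp  # predicted as k, but ground truth differs
        fn = row_sums[k] - tp  # ground truth is k, but predicted otherwise
        precision.append(safe_div(tp, tp + fp))  # diagonal of the column-normalized matrix
        recall.append(safe_div(tp, tp + fn))     # diagonal of the row-normalized matrix
        iou.append(safe_div(tp, tp + fp + fn))
    return precision, recall, iou


def micro_accuracy(cm):
    """Micro-averaged accuracy: correctly classified pixels over all pixels."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[k][k] for k in range(len(cm)))
    return correct / total if total else 0.0


# Toy 2-class example with made-up counts.
cm = [[50, 10], [5, 35]]
precision, recall, iou = classwise_metrics(cm)
print(precision, recall, iou, micro_accuracy(cm))
```

With this convention, per-class recall is exactly the diagonal of the confusion matrix normalized with respect to ground truth, which is the consistency check described above; precision is the diagonal of the column-normalized matrix.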
@DimitrisMantas good call on my ignoring the
Sure, that makes sense; please excuse the rant haha.
Applied
Note that Val is unaffected:
For a task with 2 classes there are a grand total of
I just set to be explicit but I think that PyTorch Lightning or torchmetrics auto sets on_epoch to be False for training and True for all else.
You need to set both
@DimitrisMantas now just performing
Not sure about this failing test
Must be an issue with one of the minimum versions of the package, since it's passing for the other tests.
We can definitely increase the min version of torchmetrics if we need to.
@robmarkcole I can confirm the recommended approach yields consistent results.
Sorry it's taken me so long to review. I was originally hung up on the hack required to support Only remaining concern is that the code required to loop over all metrics and averages actually makes the code more complicated and difficult to read than avoiding loops entirely. If we want to add new metrics in the future, it looks non-straightforward. I wonder if we can loop over averages only and still keep things simple. I would also really like to see this done for
@adamjstewart please see above comment on using
I think it's better practice to do
To clarify if we want
That's a good question. I normally set the train to only record steps and not the full epoch average, because depending on how many metrics you have and the size of your train set this can use up a lot of memory. I think for now we can set both as true.
From memory I think if using both the loss is named
This LGTM. Any concerns with merging this @adamjstewart?
I don't think these concerns have been addressed: #2130 (comment)
I agree. I think just making a list of the metrics instead of using loops is probably more readable, and it makes it easier to add or remove individual metrics based on a user's needs.
Merging main really messed this branch up; will create a new branch/PR.
Coming back to this to say I just reimplemented this in a private project, would love to have this upstream!
Honestly don't know when I will get around to it. The last time I had a few hours I ended up debugging dataset issues.
@@ -74,8 +74,8 @@ dependencies = [
"timm>=0.4.12",
# torch 1.13+ required by torchvision
"torch>=1.13",
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
"torchmetrics>=0.10",
# torchmetrics 1.1.1+ required for average argument to MeanAveragePrecision
According to Lightning-AI/torchmetrics@63c7bbe the argument didn't exist until 1.2
Addresses #2121 for segmentation. Mostly copied from @isaaccorley as here - he is additionally passing `on_epoch=True`, which is also adopted here.