Add ChangeDetectionTask #2422

Conversation
```python
from .base import BaseTask

...

class FocalJaccardLoss(nn.Module):
```
Should go in `torchgeo/losses/`?
That makes sense, I'll put this and `XEntJaccardLoss` there.
```yaml
loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 13
```
If stacking the channels, should this be `2 * 13`?
We are no longer stacking channels; we are making all time series datasets (including change detection) B x T x C x H x W.
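To make the layout concrete, a minimal sketch (shapes are hypothetical, e.g. a Sentinel-2 pre/post pair) of what such a batch would look like:

```python
import torch

# Hypothetical batch: 8 samples, 2 timesteps (pre/post), 13 Sentinel-2 bands, 256 x 256 tiles.
pre = torch.rand(8, 13, 256, 256)
post = torch.rand(8, 13, 256, 256)

# Stack along a new time dimension instead of concatenating channels:
# the result is B x T x C x H x W rather than B x (T * C) x H x W.
image = torch.stack([pre, post], dim=1)
print(image.shape)  # torch.Size([8, 2, 13, 256, 256])
```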
OK, and it is not necessary to configure T?
`in_channels` is multiplied by 2 in `configure_models` when initializing the U-Net.
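For illustration, a rough sketch of how this could look, assuming `segmentation_models_pytorch` and that the two timesteps are flattened into the channel dimension before the encoder; the class name, parameter names, and defaults here are assumptions, not the PR's exact code:

```python
import segmentation_models_pytorch as smp
import torch
from torch import nn


class BiTemporalUnetSketch(nn.Module):
    """Illustrative only: a U-Net whose in_channels is doubled for two timesteps."""

    def __init__(self, backbone: str = 'resnet18', in_channels: int = 13, num_classes: int = 2) -> None:
        super().__init__()
        # in_channels is multiplied by 2 because the two timesteps are
        # flattened into the channel dimension before the encoder.
        self.model = smp.Unet(
            encoder_name=backbone,
            encoder_weights=None,
            in_channels=2 * in_channels,
            classes=num_classes,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x arrives as B x T x C x H x W with T == 2; fold T into C for the encoder.
        b, t, c, h, w = x.shape
        return self.model(x.view(b, t * c, h, w))
```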
```python
labels: Optional labels to use for classes in metrics,
    e.g. ["background", "change"]
loss: Name of the loss function, currently supports
    'ce', 'jaccard', 'focal' or 'focal-jaccard' loss.
```
And `ce-jaccard`?
Yes, I missed that, thanks.
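For reference, the corrected docstring line would then read something like (exact wording assumed):

```python
loss: Name of the loss function, currently supports
    'ce', 'jaccard', 'focal', 'ce-jaccard' or 'focal-jaccard' loss.
```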
```python
class XEntJaccardLoss(nn.Module):
    """XEntJaccardLoss."""
```
Is there a paper ref for this?
Not that I'm aware of, and I didn't find one from a quick search. This loss function was already present in the code I started from.
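For readers wondering what such a combined loss looks like, here is one plausible formulation (a sketch assuming `segmentation_models_pytorch`'s `JaccardLoss`; the actual implementation in the PR may weight or compose the terms differently):

```python
import segmentation_models_pytorch as smp
import torch
from torch import nn


class XEntJaccardLoss(nn.Module):
    """Cross-entropy plus Jaccard loss (sketch; exact formulation in the PR may differ)."""

    def __init__(self) -> None:
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.jaccard = smp.losses.JaccardLoss(mode='multiclass')

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Simple unweighted sum of the two terms; a weighted combination is also plausible.
        return self.ce(preds, target) + self.jaccard(preds, target)
```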
""" | ||
loss: str = self.hparams["loss"] | ||
ignore_index = self.hparams["ignore_index"] | ||
if loss == "ce": |
Noting that for change detection, BCE may be more appropriate.
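As a hedged sketch of what the binary route could look like, assuming one logit channel per pixel and a {0, 1} change mask (shapes and names are illustrative):

```python
import torch
from torch import nn

# Hypothetical binary change setup: one logit channel per pixel and a {0, 1} change mask.
criterion = nn.BCEWithLogitsLoss()

logits = torch.randn(8, 1, 256, 256)                  # raw model outputs
mask = torch.randint(0, 2, (8, 1, 256, 256)).float()  # ground-truth change map

loss = criterion(logits, mask)
```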
I wonder if we should limit the scope to change between two timesteps and binary change; then we can use binary metrics and provide a template for the plot methods. I say this because this is by far the most common change detection task. It might also simplify the augmentations approach? Treating it as a video sequence seems overkill.
I'm personally okay with this, although @hfangcat has recent work using multiple pre-event images that would be nice to support someday (could be a subclass if necessary).
Again, this would probably be fine as a starting point, although I would someday like to make all trainers support binary/multiclass/multilabel, e.g., #2219.
Could also do this in the datasets (at least for benchmark NonGeoDatasets). We're also trying to remove explicit plotting in the trainers: #2184
Agreed.
I actually like the video augmentations, but let me loop in the Kornia folks to get their opinion: @edgarriba @johnnv1
Correct, see #2382 for the big picture (I think I also sent you a recording of my presented plan).
Can you try
@ashnair1 would this work directly with
I will go ahead and make changes so this targets binary change with two timesteps; sounds like a good starting point.
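If the task is restricted to binary change, the metrics could then come straight from `torchmetrics`' binary variants; a sketch (the exact metric set is an assumption):

```python
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryF1Score,
    BinaryJaccardIndex,
    BinaryPrecision,
    BinaryRecall,
)

# Hypothetical metric collection for a binary change detection trainer.
metrics = MetricCollection(
    {
        'Accuracy': BinaryAccuracy(),
        'Precision': BinaryPrecision(),
        'Recall': BinaryRecall(),
        'F1Score': BinaryF1Score(),
        'IoU': BinaryJaccardIndex(),
    }
)
```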
I tried this and it didn't get rid of the other dimension. I also looked into
I was going to add plotting in the trainer, but would you rather not then? What would this look like in the dataset?
Perhaps there should even be a base class ChangeDetection and subclasses for BinaryChangeDetection etc?
That's exactly what I'm trying to undo in #2219.
We can copy-n-paste the
See
This PR is to add a change detection trainer as mentioned in #2382.
Key points/items to discuss:

- `AugmentationSequential` doesn't work, but can be combined with `VideoSequential` to support the temporal dimension (see Kornia docs). I overrode `self.aug` in the `OSCDDataModule` to do this, but I'm not sure if this should be incorporated into the `BaseDataModule` instead.
- `VideoSequential` adds a temporal dimension to the mask. Not sure if there is a way to avoid this, or if this is desirable, but I added an if statement to the `AugmentationSequential` wrapper to check for and remove this added dimension.
- The `_RandomNCrop` augmentation does not work for time series data. I'm not sure how to modify `_RandomNCrop` to fix this and would appreciate some help/guidance.

cc @robmarkcole
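For context, a sketch of the `VideoSequential` + `AugmentationSequential` combination described above, following the pattern in the Kornia docs; the specific augmentations, shapes, and the squeeze of the extra mask dimension are assumptions for illustration, not the PR's exact code:

```python
import torch
import kornia.augmentation as K

# Hypothetical override of self.aug: wrap VideoSequential inside AugmentationSequential
# so image (B x T x C x H x W) and mask are augmented together.
aug = K.AugmentationSequential(
    K.VideoSequential(
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        data_format='BTCHW',
        same_on_frame=True,
    ),
    data_keys=['input', 'mask'],
)

image = torch.rand(4, 2, 13, 256, 256)                # B x T x C x H x W
mask = torch.randint(0, 2, (4, 1, 256, 256)).float()  # B x 1 x H x W change map

out_image, out_mask = aug(image, mask)
# As described above, the mask may come back with an extra temporal dimension
# (e.g. B x 1 x 1 x H x W); if so, squeeze it out before computing the loss.
if out_mask.ndim == 5:
    out_mask = out_mask.squeeze(1)
```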