From a66d7c400eba7e44053fe0e3ed474df165c694d0 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Fri, 8 Nov 2024 18:37:31 +0000 Subject: [PATCH 01/16] starting from PR #1760 --- torchgeo/trainers/__init__.py | 2 + torchgeo/trainers/change.py | 280 ++++++++++++++++++++++++++++++++++ 2 files changed, 282 insertions(+) create mode 100644 torchgeo/trainers/change.py diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index be4fb4a03db..fe79a37398d 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -5,6 +5,7 @@ from .base import BaseTask from .byol import BYOLTask +from .change import ChangeDetectionTask from .classification import ClassificationTask, MultiLabelClassificationTask from .detection import ObjectDetectionTask from .iobench import IOBenchTask @@ -15,6 +16,7 @@ __all__ = ( # Supervised + 'ChangeDetectionTask', 'ClassificationTask', 'MultiLabelClassificationTask', 'ObjectDetectionTask', diff --git a/torchgeo/trainers/change.py b/torchgeo/trainers/change.py new file mode 100644 index 00000000000..a39455c5caa --- /dev/null +++ b/torchgeo/trainers/change.py @@ -0,0 +1,280 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Trainers for change detection.""" + +import os +import warnings +from typing import Any, List, Optional, Union + +import segmentation_models_pytorch as smp +import torch +import torch.nn as nn +from torch import Tensor +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + MulticlassAccuracy, + MulticlassF1Score, + MulticlassJaccardIndex, +) +from torchmetrics.wrappers import ClasswiseWrapper +from torchvision.models._api import WeightsEnum + +from ..models import FCSiamConc, FCSiamDiff, get_weight +from . import utils +from .base import BaseTask + + +class FocalJaccardLoss(nn.Module): + def __init__(self): + super().__init__() + self.focal_loss = smp.losses.FocalLoss( + mode="multiclass", normalized=True) + self.jaccard_loss = smp.losses.JaccardLoss(mode="multiclass") + + def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + return self.focal_loss(preds, targets) + self.jaccard_loss(preds, targets) + + +class XEntJaccardLoss(nn.Module): + def __init__(self): + super().__init__() + self.ce_loss = nn.CrossEntropyLoss() + self.jaccard_loss = smp.losses.JaccardLoss(mode="multiclass") + + def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + return self.ce_loss(preds, targets) + self.jaccard_loss(preds, targets) + + +class ChangeDetectionTask(BaseTask): + """Change Detection.""" + + def __init__( + self, + model: str = "unet", + backbone: str = "resnet50", + weights: Optional[Union[WeightsEnum, str, bool]] = None, + in_channels: int = 3, + num_classes: int = 2, + class_weights: Optional[Tensor] = None, + labels: Optional[List[str]] = None, + loss: str = "ce-jaccard", + ignore_index: Optional[int] = None, + lr: float = 1e-3, + patience: int = 10, + freeze_backbone: bool = False, + freeze_decoder: bool = False, + ) -> None: + """Inititalize a new ChangeDetectionTask instance. + + Args: + model: Name of the model to use. + backbone: Name of the `timm + `__ or `smp + `__ backbone to use. + weights: Initial model weights. Either a weight enum, the string + representation of a weight enum, True for ImageNet weights, False or + None for random weights, or the path to a saved model state dict. FCN + model does not support pretrained weights. Pretrained ViT weight enums + are not supported yet. + in_channels: Number of input channels to model. + num_classes: Number of prediction classes. + class_weights: Optional rescaling weight given to each + class and used with 'ce' loss. + labels: Optional labels to use for classes in metrics + e.g. ["background", "change"] + loss: Name of the loss function, currently supports + 'ce', 'jaccard', 'focal' or 'focal-jaccard' loss. + ignore_index: Optional integer class index to ignore in the loss and + metrics. + lr: Learning rate for optimizer. + patience: Patience for learning rate scheduler. + freeze_backbone: Freeze the backbone network to fine-tune the + decoder and segmentation head. + freeze_decoder: Freeze the decoder network to linear probe + the segmentation head. + + .. versionadded: 0.6 + """ + if ignore_index is not None and loss == "jaccard": + warnings.warn( + "ignore_index has no effect on training when loss='jaccard'", + UserWarning, + ) + + self.weights = weights + super().__init__(ignore="weights") + + def configure_losses(self) -> None: + """Initialize the loss criterion. + + Raises: + ValueError: If *loss* is invalid. + """ + loss: str = self.hparams["loss"] + ignore_index = self.hparams["ignore_index"] + if loss == "ce": + ignore_value = -1000 if ignore_index is None else ignore_index + self.criterion = nn.CrossEntropyLoss( + ignore_index=ignore_value, weight=self.hparams["class_weights"] + ) + elif loss == "jaccard": + self.criterion = smp.losses.JaccardLoss( + mode="multiclass", classes=self.hparams["num_classes"] + ) + elif loss == "focal": + self.criterion = smp.losses.FocalLoss( + "multiclass", ignore_index=ignore_index, normalized=True + ) + elif loss == "focal-jaccard": + self.criterion = FocalJaccardLoss() + elif loss == "ce-jaccard": + self.criterion = XEntJaccardLoss() + else: + raise ValueError( + f"Loss type '{loss}' is not valid. " + "Currently, supports 'ce', 'jaccard' or 'focal' loss." + ) + + def configure_metrics(self) -> None: + """Initialize the performance metrics.""" + num_classes: int = self.hparams["num_classes"] + ignore_index: Optional[int] = self.hparams["ignore_index"] + labels: Optional[List[str]] = self.hparams["labels"] + metrics = MetricCollection( + { + "accuracy": ClasswiseWrapper( + MulticlassAccuracy( + num_classes=num_classes, ignore_index=ignore_index, average=None + ), + labels, + ), + "jaccard": ClasswiseWrapper( + MulticlassJaccardIndex( + num_classes=num_classes, ignore_index=ignore_index, average=None + ), + labels, + ), + "f1": ClasswiseWrapper( + MulticlassF1Score( + num_classes=num_classes, ignore_index=ignore_index, average=None + ), + labels, + ), + } + ) + self.train_metrics = metrics.clone(prefix="train_") + self.val_metrics = metrics.clone(prefix="val_") + self.test_metrics = metrics.clone(prefix="test_") + + def configure_models(self) -> None: + """Initialize the model. + + Raises: + ValueError: If *model* is invalid. + """ + model: str = self.hparams["model"] + backbone: str = self.hparams["backbone"] + weights = self.weights + in_channels: int = self.hparams["in_channels"] + num_classes: int = self.hparams["num_classes"] + + if model == "unet": + self.model = smp.Unet( + encoder_name=backbone, + encoder_weights="imagenet" if weights is True else None, + in_channels=in_channels * 2, # images are concatenated + classes=num_classes, + ) + elif model == "fcsiamdiff": + self.model = FCSiamDiff( + in_channels=in_channels, + classes=num_classes, + encoder_weights="imagenet" if weights is True else None, + ) + elif model == "fcsiamconc": + self.model = FCSiamConc( + in_channels=in_channels, + classes=num_classes, + encoder_weights="imagenet" if weights is True else None, + ) + else: + raise ValueError( + f"Model type '{model}' is not valid. " + "Currently, only supports 'unet', 'fcsiamdiff, and 'fcsiamconc'." + ) + + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) + else: + state_dict = get_weight(weights).get_state_dict(progress=True) + self.model.encoder.load_state_dict(state_dict) + + # Freeze backbone + if self.hparams["freeze_backbone"] and model in ["unet"]: + for param in self.model.encoder.parameters(): + param.requires_grad = False + + # Freeze decoder + if self.hparams["freeze_decoder"] and model in ["unet"]: + for param in self.model.decoder.parameters(): + param.requires_grad = False + + def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: + model: str = self.hparams["model"] + image1, image2, y = batch["image1"], batch["image2"], batch["mask"].float( + ) + if model == "unet": + x = torch.cat([image1, image2], dim=1) + elif model in ["fcsiamdiff", "fcsiamconc"]: + x = torch.stack((image1, image2), dim=1) + y_hat = self(x) + y = y.long() + + loss: Tensor = self.criterion(y_hat, y) + self.log(f"{stage}_loss", loss) + + # Retrieve the correct metrics based on the stage + metrics = getattr(self, f"{stage}_metrics", None) + if metrics: + metrics(y_hat, y) + self.log_dict({f"{k}": v for k, v in metrics.compute().items()}) + + return loss + + def training_step(self, batch: Any, batch_idx: int) -> Tensor: + loss = self._shared_step(batch, batch_idx, "train") + return loss + + def validation_step(self, batch: Any, batch_idx: int) -> None: + self._shared_step(batch, batch_idx, "val") + + def test_step(self, batch: Any, batch_idx: int) -> None: + self._shared_step(batch, batch_idx, "test") + + def predict_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Compute the predicted class. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. + + Returns: + Output predicted class. + """ + model: str = self.hparams["model"] + image1 = batch["image1"] + image2 = batch["image2"] + if model == "unet": + x = torch.cat([image1, image2], dim=1) + elif model in ["fcsiamdiff", "fcsiamconc"]: + x = torch.stack((image1, image2), dim=1) + y_hat: Tensor = self(x) + y_hat_hard = y_hat.argmax(dim=1) + return y_hat_hard From a3926768e6f6ccef7c80055a42891cd809294907 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Tue, 12 Nov 2024 23:20:42 +0000 Subject: [PATCH 02/16] changed from image1, image2 to stacked images. --- torchgeo/trainers/change.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/torchgeo/trainers/change.py b/torchgeo/trainers/change.py index a39455c5caa..eae2168bd1f 100644 --- a/torchgeo/trainers/change.py +++ b/torchgeo/trainers/change.py @@ -225,14 +225,11 @@ def configure_models(self) -> None: def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: model: str = self.hparams["model"] - image1, image2, y = batch["image1"], batch["image2"], batch["mask"].float( - ) + x = batch["image"] + y = batch["mask"] if model == "unet": - x = torch.cat([image1, image2], dim=1) - elif model in ["fcsiamdiff", "fcsiamconc"]: - x = torch.stack((image1, image2), dim=1) + x = x.flatten(start_dim=1, end_dim=2) y_hat = self(x) - y = y.long() loss: Tensor = self.criterion(y_hat, y) self.log(f"{stage}_loss", loss) @@ -269,12 +266,9 @@ def predict_step( Output predicted class. """ model: str = self.hparams["model"] - image1 = batch["image1"] - image2 = batch["image2"] + x = batch["image"] if model == "unet": - x = torch.cat([image1, image2], dim=1) - elif model in ["fcsiamdiff", "fcsiamconc"]: - x = torch.stack((image1, image2), dim=1) + x = x.flatten(start_dim=1, end_dim=2) y_hat: Tensor = self(x) y_hat_hard = y_hat.argmax(dim=1) return y_hat_hard From a8e1f0ad7547a16a2bc4b73223f9e02f9b0d4dd1 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 13 Nov 2024 00:21:34 +0000 Subject: [PATCH 03/16] fixed mypy and ruff issues --- torchgeo/trainers/change.py | 47 ++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/torchgeo/trainers/change.py b/torchgeo/trainers/change.py index eae2168bd1f..47440cda425 100644 --- a/torchgeo/trainers/change.py +++ b/torchgeo/trainers/change.py @@ -5,7 +5,7 @@ import os import warnings -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union, cast import segmentation_models_pytorch as smp import torch @@ -26,24 +26,32 @@ class FocalJaccardLoss(nn.Module): - def __init__(self): + """FocalJaccardLoss.""" + + def __init__(self) -> None: + """Initialize a FocalJaccardLoss instance.""" super().__init__() self.focal_loss = smp.losses.FocalLoss( mode="multiclass", normalized=True) self.jaccard_loss = smp.losses.JaccardLoss(mode="multiclass") def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - return self.focal_loss(preds, targets) + self.jaccard_loss(preds, targets) + """Compute the loss.""" + return cast(torch.Tensor, self.focal_loss(preds, targets) + self.jaccard_loss(preds, targets)) class XEntJaccardLoss(nn.Module): - def __init__(self): + """XEntJaccardLoss.""" + + def __init__(self) -> None: + """Initialize a XEntJaccardLoss instance.""" super().__init__() self.ce_loss = nn.CrossEntropyLoss() self.jaccard_loss = smp.losses.JaccardLoss(mode="multiclass") def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - return self.ce_loss(preds, targets) + self.jaccard_loss(preds, targets) + """Compute the loss.""" + return cast(torch.Tensor, self.ce_loss(preds, targets) + self.jaccard_loss(preds, targets)) class ChangeDetectionTask(BaseTask): @@ -57,7 +65,7 @@ def __init__( in_channels: int = 3, num_classes: int = 2, class_weights: Optional[Tensor] = None, - labels: Optional[List[str]] = None, + labels: Optional[list[str]] = None, loss: str = "ce-jaccard", ignore_index: Optional[int] = None, lr: float = 1e-3, @@ -103,7 +111,7 @@ class and used with 'ce' loss. ) self.weights = weights - super().__init__(ignore="weights") + super().__init__() def configure_losses(self) -> None: """Initialize the loss criterion. @@ -133,14 +141,14 @@ def configure_losses(self) -> None: else: raise ValueError( f"Loss type '{loss}' is not valid. " - "Currently, supports 'ce', 'jaccard' or 'focal' loss." + "Currently, supports 'ce', 'jaccard', 'focal', 'focal-jaccard, or 'ce-jaccard loss." ) def configure_metrics(self) -> None: """Initialize the performance metrics.""" num_classes: int = self.hparams["num_classes"] ignore_index: Optional[int] = self.hparams["ignore_index"] - labels: Optional[List[str]] = self.hparams["labels"] + labels: Optional[list[str]] = self.hparams["labels"] metrics = MetricCollection( { "accuracy": ClasswiseWrapper( @@ -243,13 +251,34 @@ def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: return loss def training_step(self, batch: Any, batch_idx: int) -> Tensor: + """Compute the training loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + + Returns: + The loss tensor. + """ loss = self._shared_step(batch, batch_idx, "train") return loss def validation_step(self, batch: Any, batch_idx: int) -> None: + """Compute the validation loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + """ self._shared_step(batch, batch_idx, "val") def test_step(self, batch: Any, batch_idx: int) -> None: + """Compute the test loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + """ self._shared_step(batch, batch_idx, "test") def predict_step( From 4513e843e78e81dc1e087f526faaff53bbb7e98b Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Thu, 21 Nov 2024 23:03:38 +0000 Subject: [PATCH 04/16] adding tests. some still need work. --- tests/conf/oscd.yaml | 17 +++ tests/trainers/test_change.py | 206 ++++++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+) create mode 100644 tests/conf/oscd.yaml create mode 100644 tests/trainers/test_change.py diff --git a/tests/conf/oscd.yaml b/tests/conf/oscd.yaml new file mode 100644 index 00000000000..264f17334d0 --- /dev/null +++ b/tests/conf/oscd.yaml @@ -0,0 +1,17 @@ +model: + class_path: ChangeDetectionTask + init_args: + loss: 'ce' + model: 'unet' + backbone: 'resnet18' + in_channels: 13 + num_classes: 2 + ignore_index: 0 +data: + class_path: OSCDDataModule + init_args: + batch_size: 2 + patch_size: 16 + val_split_pct: 0.5 + dict_kwargs: + root: 'tests/data/oscd' \ No newline at end of file diff --git a/tests/trainers/test_change.py b/tests/trainers/test_change.py new file mode 100644 index 00000000000..2fbe5d32dc5 --- /dev/null +++ b/tests/trainers/test_change.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from pathlib import Path +from typing import Any, cast + +import pytest +import segmentation_models_pytorch as smp +import timm +import torch +import torch.nn as nn +from lightning.pytorch import Trainer +from pytest import MonkeyPatch +from torch.nn.modules import Module +from torchvision.models._api import WeightsEnum + +from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule +from torchgeo.datasets import RGBBandsMissingError +from torchgeo.main import main +from torchgeo.models import ResNet18_Weights +from torchgeo.trainers import ChangeDetectionTask + + +class ChangeDetectionTestModel(Module): + def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None: + super().__init__() + self.conv1 = nn.Conv2d( + in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return cast(torch.Tensor, self.conv1(x)) + + +def create_model(**kwargs: Any) -> Module: + return ChangeDetectionTestModel(**kwargs) + + +def plot(*args: Any, **kwargs: Any) -> None: + return None + + +def plot_missing_bands(*args: Any, **kwargs: Any) -> None: + raise RGBBandsMissingError() + + +class TestChangeDetectionTask: + @pytest.mark.parametrize('name', ['oscd']) + def test_trainer( + self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool + ) -> None: + config = os.path.join('tests', 'conf', name + '.yaml') + + monkeypatch.setattr(smp, 'Unet', create_model) + monkeypatch.setattr(smp, 'DeepLabV3Plus', create_model) + + args = [ + '--config', + config, + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', + str(fast_dev_run), + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', + ] + + main(['fit', *args]) + try: + main(['test', *args]) + except MisconfigurationException: + pass + try: + main(['predict', *args]) + except MisconfigurationException: + pass + + @pytest.fixture + def weights(self) -> WeightsEnum: + return ResNet18_Weights.SENTINEL2_ALL_MOCO + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + model = timm.create_model( + weights.meta['model'], in_chans=weights.meta['in_chans'] + ) + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + def test_weight_file(self, checkpoint: str) -> None: + ChangeDetectionTask(backbone='resnet18', weights=checkpoint, num_classes=6) + + def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: + ChangeDetectionTask( + backbone=mocked_weights.meta['model'], + weights=mocked_weights, + in_channels=mocked_weights.meta['in_chans'], + ) + + def test_weight_str(self, mocked_weights: WeightsEnum) -> None: + ChangeDetectionTask( + backbone=mocked_weights.meta['model'], + weights=str(mocked_weights), + in_channels=mocked_weights.meta['in_chans'], + ) + + @pytest.mark.slow + def test_weight_enum_download(self, weights: WeightsEnum) -> None: + ChangeDetectionTask( + backbone=weights.meta['model'], + weights=weights, + in_channels=weights.meta['in_chans'], + ) + + @pytest.mark.slow + def test_weight_str_download(self, weights: WeightsEnum) -> None: + ChangeDetectionTask( + backbone=weights.meta['model'], + weights=str(weights), + in_channels=weights.meta['in_chans'], + ) + + def test_invalid_model(self) -> None: + match = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=match): + ChangeDetectionTask(model='invalid_model') + + def test_invalid_loss(self) -> None: + match = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=match): + ChangeDetectionTask(loss='invalid_loss') + + def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: + monkeypatch.setattr(SEN12MSDataModule, 'plot', plot) + datamodule = SEN12MSDataModule( + root='tests/data/sen12ms', batch_size=1, num_workers=0 + ) + model = ChangeDetectionTask(backbone='resnet18', in_channels=15, num_classes=6) + trainer = Trainer( + accelerator='cpu', + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + max_epochs=1, + ) + trainer.validate(model=model, datamodule=datamodule) + + def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: + monkeypatch.setattr(SEN12MSDataModule, 'plot', plot_missing_bands) + datamodule = SEN12MSDataModule( + root='tests/data/sen12ms', batch_size=1, num_workers=0 + ) + model = ChangeDetectionTask(backbone='resnet18', in_channels=15, num_classes=6) + trainer = Trainer( + accelerator='cpu', + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + max_epochs=1, + ) + trainer.validate(model=model, datamodule=datamodule) + + @pytest.mark.parametrize('model_name', ['unet']) + @pytest.mark.parametrize( + 'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0'] + ) + def test_freeze_backbone(self, model_name: str, backbone: str) -> None: + model = ChangeDetectionTask( + model=model_name, backbone=backbone, freeze_backbone=True + ) + assert all( + [param.requires_grad is False for param in model.model.encoder.parameters()] + ) + assert all([param.requires_grad for param in model.model.decoder.parameters()]) + assert all( + [ + param.requires_grad + for param in model.model.segmentation_head.parameters() + ] + ) + + @pytest.mark.parametrize('model_name', ['unet']) + def test_freeze_decoder(self, model_name: str) -> None: + model = ChangeDetectionTask(model=model_name, freeze_decoder=True) + assert all( + [param.requires_grad is False for param in model.model.decoder.parameters()] + ) + assert all([param.requires_grad for param in model.model.encoder.parameters()]) + assert all( + [ + param.requires_grad + for param in model.model.segmentation_head.parameters() + ] + ) From 7e5ba826838c9937da3ba36269e8a206902eba10 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Thu, 21 Nov 2024 23:10:25 +0000 Subject: [PATCH 05/16] making Kornia transforms work with added temporal dimension. --- torchgeo/datamodules/oscd.py | 8 +++++--- torchgeo/datasets/oscd.py | 6 ++++-- torchgeo/transforms/transforms.py | 8 +++++++- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 87ae20cdf40..fea9602adeb 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -86,9 +86,11 @@ def __init__( self.std = torch.tensor([STD[b] for b in self.bands]) self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, batch_size), - data_keys=['image1', 'image2', 'mask'], + K.VideoSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), + ), + data_keys=['image', 'mask'], ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 28f7714a7c6..8a6616534eb 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -150,7 +150,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image1 = self._load_image(files['images1']) image2 = self._load_image(files['images2']) mask = self._load_target(str(files['mask'])) - sample = {'image1': image1, 'image2': image2, 'mask': mask} + image = torch.stack(tensors=[image1, image2], dim=0) + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -169,7 +170,8 @@ def _load_files(self) -> list[dict[str, str | Sequence[str]]]: regions = [] labels_root = os.path.join( self.root, - f'Onera Satellite Change Detection dataset - {self.split.capitalize()} ' + f'Onera Satellite Change Detection dataset - { + self.split.capitalize()} ' + 'Labels', ) images_root = os.path.join( diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index d8f80bdcaac..e04389d8b18 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -102,7 +102,13 @@ def forward(self, batch: dict[str, Any]) -> dict[str, Any]: batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy') # Torchmetrics does not support masks with a channel dimension - if 'mask' in batch and batch['mask'].shape[1] == 1: + # Kornia adds a temporal dimension to mask when passed through VideoSequential. + if 'mask' in batch and batch['mask'].ndim == 5: + if batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () c h w -> b c h w') + if batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') + elif 'mask' in batch and batch['mask'].shape[1] == 1: batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') if 'masks' in batch and batch['masks'].ndim == 4: batch['masks'] = rearrange(batch['masks'], '() c h w -> c h w') From 035b3967225769b804464b8da73cba2ba2345a3b Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Tue, 3 Dec 2024 20:22:39 +0000 Subject: [PATCH 06/16] Support only binary change with two timesteps. Moved loss functions to torchgeo/losses. --- tests/conf/oscd.yaml | 4 +- tests/trainers/test_change.py | 1 - torchgeo/datasets/oscd.py | 2 +- torchgeo/losses/__init__.py | 4 +- torchgeo/losses/focaljaccard.py | 27 +++++ torchgeo/losses/xentjaccard.py | 27 +++++ torchgeo/trainers/change.py | 188 ++++++++++---------------------- 7 files changed, 118 insertions(+), 135 deletions(-) create mode 100644 torchgeo/losses/focaljaccard.py create mode 100644 torchgeo/losses/xentjaccard.py diff --git a/tests/conf/oscd.yaml b/tests/conf/oscd.yaml index 264f17334d0..8f0d3ce2a39 100644 --- a/tests/conf/oscd.yaml +++ b/tests/conf/oscd.yaml @@ -1,12 +1,10 @@ model: class_path: ChangeDetectionTask init_args: - loss: 'ce' + loss: 'bce' model: 'unet' backbone: 'resnet18' in_channels: 13 - num_classes: 2 - ignore_index: 0 data: class_path: OSCDDataModule init_args: diff --git a/tests/trainers/test_change.py b/tests/trainers/test_change.py index 2fbe5d32dc5..0cf1915d47b 100644 --- a/tests/trainers/test_change.py +++ b/tests/trainers/test_change.py @@ -53,7 +53,6 @@ def test_trainer( config = os.path.join('tests', 'conf', name + '.yaml') monkeypatch.setattr(smp, 'Unet', create_model) - monkeypatch.setattr(smp, 'DeepLabV3Plus', create_model) args = [ '--config', diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 8a6616534eb..7a13d03aee9 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -242,7 +242,7 @@ def _load_target(self, path: Path) -> Tensor: array: np.typing.NDArray[np.int_] = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = torch.clamp(tensor, min=0, max=1) - tensor = tensor.to(torch.long) + tensor = tensor.to(torch.float) return tensor def _verify(self) -> None: diff --git a/torchgeo/losses/__init__.py b/torchgeo/losses/__init__.py index d30807a4bd6..cfeb973a276 100644 --- a/torchgeo/losses/__init__.py +++ b/torchgeo/losses/__init__.py @@ -3,6 +3,8 @@ """TorchGeo losses.""" +from .focaljaccard import BinaryFocalJaccardLoss from .qr import QRLoss, RQLoss +from .xentjaccard import BinaryXEntJaccardLoss -__all__ = ('QRLoss', 'RQLoss') +__all__ = ('QRLoss', 'RQLoss', 'BinaryFocalJaccardLoss', 'BinaryXEntJaccardLoss') diff --git a/torchgeo/losses/focaljaccard.py b/torchgeo/losses/focaljaccard.py new file mode 100644 index 00000000000..cceb9ceacf2 --- /dev/null +++ b/torchgeo/losses/focaljaccard.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Focal Jaccard loss functions.""" + +from typing import cast + +import segmentation_models_pytorch as smp +import torch +import torch.nn as nn + + +class BinaryFocalJaccardLoss(nn.Module): + """Binary Focal Jaccard Loss.""" + + def __init__(self) -> None: + """Initialize a BinaryFocalJaccardLoss instance.""" + super().__init__() + self.focal_loss = smp.losses.FocalLoss(mode='binary', normalized=True) + self.jaccard_loss = smp.losses.JaccardLoss(mode='binary') + + def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """Compute the loss.""" + return cast( + torch.Tensor, + self.focal_loss(preds, targets) + self.jaccard_loss(preds, targets), + ) diff --git a/torchgeo/losses/xentjaccard.py b/torchgeo/losses/xentjaccard.py new file mode 100644 index 00000000000..9675e0272b8 --- /dev/null +++ b/torchgeo/losses/xentjaccard.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Cross-Entropy Jaccard loss functions.""" + +from typing import cast + +import segmentation_models_pytorch as smp +import torch +import torch.nn as nn + + +class BinaryXEntJaccardLoss(nn.Module): + """Binary Cross-Entropy Jaccard Loss.""" + + def __init__(self) -> None: + """Initialize a BinaryXEntJaccardLoss instance.""" + super().__init__() + self.bce_loss = nn.BCEWithLogitsLoss() + self.jaccard_loss = smp.losses.JaccardLoss(mode='binary') + + def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """Compute the loss.""" + return cast( + torch.Tensor, + self.bce_loss(preds, targets) + self.jaccard_loss(preds, targets), + ) diff --git a/torchgeo/trainers/change.py b/torchgeo/trainers/change.py index 47440cda425..394cace516f 100644 --- a/torchgeo/trainers/change.py +++ b/torchgeo/trainers/change.py @@ -4,70 +4,36 @@ """Trainers for change detection.""" import os -import warnings -from typing import Any, Optional, Union, cast +from typing import Any import segmentation_models_pytorch as smp -import torch import torch.nn as nn from torch import Tensor from torchmetrics import MetricCollection from torchmetrics.classification import ( - MulticlassAccuracy, - MulticlassF1Score, - MulticlassJaccardIndex, + BinaryAccuracy, + BinaryF1Score, + BinaryJaccardIndex, ) -from torchmetrics.wrappers import ClasswiseWrapper from torchvision.models._api import WeightsEnum +from ..losses import BinaryFocalJaccardLoss, BinaryXEntJaccardLoss from ..models import FCSiamConc, FCSiamDiff, get_weight from . import utils from .base import BaseTask -class FocalJaccardLoss(nn.Module): - """FocalJaccardLoss.""" - - def __init__(self) -> None: - """Initialize a FocalJaccardLoss instance.""" - super().__init__() - self.focal_loss = smp.losses.FocalLoss( - mode="multiclass", normalized=True) - self.jaccard_loss = smp.losses.JaccardLoss(mode="multiclass") - - def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - """Compute the loss.""" - return cast(torch.Tensor, self.focal_loss(preds, targets) + self.jaccard_loss(preds, targets)) - - -class XEntJaccardLoss(nn.Module): - """XEntJaccardLoss.""" - - def __init__(self) -> None: - """Initialize a XEntJaccardLoss instance.""" - super().__init__() - self.ce_loss = nn.CrossEntropyLoss() - self.jaccard_loss = smp.losses.JaccardLoss(mode="multiclass") - - def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - """Compute the loss.""" - return cast(torch.Tensor, self.ce_loss(preds, targets) + self.jaccard_loss(preds, targets)) - - class ChangeDetectionTask(BaseTask): - """Change Detection.""" + """Change Detection. Currently supports binary change between two timesteps.""" def __init__( self, - model: str = "unet", - backbone: str = "resnet50", - weights: Optional[Union[WeightsEnum, str, bool]] = None, + model: str = 'unet', + backbone: str = 'resnet50', + weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, - num_classes: int = 2, - class_weights: Optional[Tensor] = None, - labels: Optional[list[str]] = None, - loss: str = "ce-jaccard", - ignore_index: Optional[int] = None, + pos_weight: Tensor | None = None, + loss: str = 'bce-jaccard', lr: float = 1e-3, patience: int = 10, freeze_backbone: bool = False, @@ -86,15 +52,9 @@ def __init__( model does not support pretrained weights. Pretrained ViT weight enums are not supported yet. in_channels: Number of input channels to model. - num_classes: Number of prediction classes. - class_weights: Optional rescaling weight given to each - class and used with 'ce' loss. - labels: Optional labels to use for classes in metrics - e.g. ["background", "change"] + pos_weight: A weight of positive examples and used with 'bce' loss. loss: Name of the loss function, currently supports - 'ce', 'jaccard', 'focal' or 'focal-jaccard' loss. - ignore_index: Optional integer class index to ignore in the loss and - metrics. + 'bce', 'jaccard', 'focal', 'focal-jaccard' or 'bce-jaccard' loss. lr: Learning rate for optimizer. patience: Patience for learning rate scheduler. freeze_backbone: Freeze the backbone network to fine-tune the @@ -104,12 +64,6 @@ class and used with 'ce' loss. .. versionadded: 0.6 """ - if ignore_index is not None and loss == "jaccard": - warnings.warn( - "ignore_index has no effect on training when loss='jaccard'", - UserWarning, - ) - self.weights = weights super().__init__() @@ -119,61 +73,35 @@ def configure_losses(self) -> None: Raises: ValueError: If *loss* is invalid. """ - loss: str = self.hparams["loss"] - ignore_index = self.hparams["ignore_index"] - if loss == "ce": - ignore_value = -1000 if ignore_index is None else ignore_index - self.criterion = nn.CrossEntropyLoss( - ignore_index=ignore_value, weight=self.hparams["class_weights"] - ) - elif loss == "jaccard": - self.criterion = smp.losses.JaccardLoss( - mode="multiclass", classes=self.hparams["num_classes"] - ) - elif loss == "focal": - self.criterion = smp.losses.FocalLoss( - "multiclass", ignore_index=ignore_index, normalized=True - ) - elif loss == "focal-jaccard": - self.criterion = FocalJaccardLoss() - elif loss == "ce-jaccard": - self.criterion = XEntJaccardLoss() + loss: str = self.hparams['loss'] + if loss == 'bce': + self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.hparams['pos_weight']) + elif loss == 'jaccard': + self.criterion = smp.losses.JaccardLoss(mode='binary') + elif loss == 'focal': + self.criterion = smp.losses.FocalLoss(mode='binary', normalized=True) + elif loss == 'focal-jaccard': + self.criterion = BinaryFocalJaccardLoss() + elif loss == 'bce-jaccard': + self.criterion = BinaryXEntJaccardLoss() else: raise ValueError( f"Loss type '{loss}' is not valid. " - "Currently, supports 'ce', 'jaccard', 'focal', 'focal-jaccard, or 'ce-jaccard loss." + "Currently, supports 'bce', 'jaccard', 'focal', 'focal-jaccard, or 'bce-jaccard loss." ) def configure_metrics(self) -> None: """Initialize the performance metrics.""" - num_classes: int = self.hparams["num_classes"] - ignore_index: Optional[int] = self.hparams["ignore_index"] - labels: Optional[list[str]] = self.hparams["labels"] metrics = MetricCollection( { - "accuracy": ClasswiseWrapper( - MulticlassAccuracy( - num_classes=num_classes, ignore_index=ignore_index, average=None - ), - labels, - ), - "jaccard": ClasswiseWrapper( - MulticlassJaccardIndex( - num_classes=num_classes, ignore_index=ignore_index, average=None - ), - labels, - ), - "f1": ClasswiseWrapper( - MulticlassF1Score( - num_classes=num_classes, ignore_index=ignore_index, average=None - ), - labels, - ), + 'accuracy': BinaryAccuracy(), + 'jaccard': BinaryJaccardIndex(), + 'f1': BinaryF1Score(), } ) - self.train_metrics = metrics.clone(prefix="train_") - self.val_metrics = metrics.clone(prefix="val_") - self.test_metrics = metrics.clone(prefix="test_") + self.train_metrics = metrics.clone(prefix='train_') + self.val_metrics = metrics.clone(prefix='val_') + self.test_metrics = metrics.clone(prefix='test_') def configure_models(self) -> None: """Initialize the model. @@ -181,30 +109,30 @@ def configure_models(self) -> None: Raises: ValueError: If *model* is invalid. """ - model: str = self.hparams["model"] - backbone: str = self.hparams["backbone"] + model: str = self.hparams['model'] + backbone: str = self.hparams['backbone'] weights = self.weights - in_channels: int = self.hparams["in_channels"] - num_classes: int = self.hparams["num_classes"] + in_channels: int = self.hparams['in_channels'] + num_classes = 1 - if model == "unet": + if model == 'unet': self.model = smp.Unet( encoder_name=backbone, - encoder_weights="imagenet" if weights is True else None, + encoder_weights='imagenet' if weights is True else None, in_channels=in_channels * 2, # images are concatenated classes=num_classes, ) - elif model == "fcsiamdiff": + elif model == 'fcsiamdiff': self.model = FCSiamDiff( in_channels=in_channels, classes=num_classes, - encoder_weights="imagenet" if weights is True else None, + encoder_weights='imagenet' if weights is True else None, ) - elif model == "fcsiamconc": + elif model == 'fcsiamconc': self.model = FCSiamConc( in_channels=in_channels, classes=num_classes, - encoder_weights="imagenet" if weights is True else None, + encoder_weights='imagenet' if weights is True else None, ) else: raise ValueError( @@ -222,31 +150,32 @@ def configure_models(self) -> None: self.model.encoder.load_state_dict(state_dict) # Freeze backbone - if self.hparams["freeze_backbone"] and model in ["unet"]: + if self.hparams['freeze_backbone'] and model in ['unet']: for param in self.model.encoder.parameters(): param.requires_grad = False # Freeze decoder - if self.hparams["freeze_decoder"] and model in ["unet"]: + if self.hparams['freeze_decoder'] and model in ['unet']: for param in self.model.decoder.parameters(): param.requires_grad = False def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: - model: str = self.hparams["model"] - x = batch["image"] - y = batch["mask"] - if model == "unet": + model: str = self.hparams['model'] + x = batch['image'] + y = batch['mask'] + y = y.unsqueeze(dim=1) # channel dim for binary loss functions/metrics + if model == 'unet': x = x.flatten(start_dim=1, end_dim=2) y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log(f"{stage}_loss", loss) + self.log(f'{stage}_loss', loss) # Retrieve the correct metrics based on the stage - metrics = getattr(self, f"{stage}_metrics", None) + metrics = getattr(self, f'{stage}_metrics', None) if metrics: metrics(y_hat, y) - self.log_dict({f"{k}": v for k, v in metrics.compute().items()}) + self.log_dict({f'{k}': v for k, v in metrics.compute().items()}) return loss @@ -260,7 +189,7 @@ def training_step(self, batch: Any, batch_idx: int) -> Tensor: Returns: The loss tensor. """ - loss = self._shared_step(batch, batch_idx, "train") + loss = self._shared_step(batch, batch_idx, 'train') return loss def validation_step(self, batch: Any, batch_idx: int) -> None: @@ -270,7 +199,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> None: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. """ - self._shared_step(batch, batch_idx, "val") + self._shared_step(batch, batch_idx, 'val') def test_step(self, batch: Any, batch_idx: int) -> None: """Compute the test loss and additional metrics. @@ -279,7 +208,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. """ - self._shared_step(batch, batch_idx, "test") + self._shared_step(batch, batch_idx, 'test') def predict_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -294,10 +223,11 @@ def predict_step( Returns: Output predicted class. """ - model: str = self.hparams["model"] - x = batch["image"] - if model == "unet": + model: str = self.hparams['model'] + threshold = 0.5 + x = batch['image'] + if model == 'unet': x = x.flatten(start_dim=1, end_dim=2) y_hat: Tensor = self(x) - y_hat_hard = y_hat.argmax(dim=1) + y_hat_hard = (nn.functional.sigmoid(y_hat) > threshold).int() return y_hat_hard From 5546cba1fe766bb54b9215fb47acbd8066e445b4 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Thu, 5 Dec 2024 23:15:55 +0000 Subject: [PATCH 07/16] fixed issues with tests. --- tests/trainers/conftest.py | 7 +++--- tests/trainers/test_change.py | 46 ++++------------------------------- 2 files changed, 9 insertions(+), 44 deletions(-) diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index a3ce098ae7d..920e8ee6abc 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -6,8 +6,8 @@ from pathlib import Path import pytest +import timm import torch -import torchvision from _pytest.fixtures import SubRequest from torch import Tensor from torch.nn.modules import Module @@ -22,8 +22,9 @@ def fast_dev_run(request: SubRequest) -> bool: @pytest.fixture(scope='package') -def model() -> Module: - model: Module = torchvision.models.resnet18(weights=None) +def model(request: SubRequest) -> Module: + in_channels = getattr(request, 'param', 3) + model: Module = timm.create_model('resnet18', in_chans=in_channels) return model diff --git a/tests/trainers/test_change.py b/tests/trainers/test_change.py index 0cf1915d47b..e760d8b6ef7 100644 --- a/tests/trainers/test_change.py +++ b/tests/trainers/test_change.py @@ -10,13 +10,11 @@ import timm import torch import torch.nn as nn -from lightning.pytorch import Trainer from pytest import MonkeyPatch from torch.nn.modules import Module from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule -from torchgeo.datasets import RGBBandsMissingError +from torchgeo.datamodules import MisconfigurationException from torchgeo.main import main from torchgeo.models import ResNet18_Weights from torchgeo.trainers import ChangeDetectionTask @@ -37,14 +35,6 @@ def create_model(**kwargs: Any) -> Module: return ChangeDetectionTestModel(**kwargs) -def plot(*args: Any, **kwargs: Any) -> None: - return None - - -def plot_missing_bands(*args: Any, **kwargs: Any) -> None: - raise RGBBandsMissingError() - - class TestChangeDetectionTask: @pytest.mark.parametrize('name', ['oscd']) def test_trainer( @@ -90,8 +80,9 @@ def mocked_weights( load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' + # multiply in_chans by 2 since images are concatenated model = timm.create_model( - weights.meta['model'], in_chans=weights.meta['in_chans'] + weights.meta['model'], in_chans=weights.meta['in_chans'] * 2 ) torch.save(model.state_dict(), path) try: @@ -100,8 +91,9 @@ def mocked_weights( monkeypatch.setattr(weights, 'url', str(path)) return weights + @pytest.mark.parametrize('model', [6], indirect=True) def test_weight_file(self, checkpoint: str) -> None: - ChangeDetectionTask(backbone='resnet18', weights=checkpoint, num_classes=6) + ChangeDetectionTask(backbone='resnet18', weights=checkpoint) def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: ChangeDetectionTask( @@ -143,34 +135,6 @@ def test_invalid_loss(self) -> None: with pytest.raises(ValueError, match=match): ChangeDetectionTask(loss='invalid_loss') - def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(SEN12MSDataModule, 'plot', plot) - datamodule = SEN12MSDataModule( - root='tests/data/sen12ms', batch_size=1, num_workers=0 - ) - model = ChangeDetectionTask(backbone='resnet18', in_channels=15, num_classes=6) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - - def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(SEN12MSDataModule, 'plot', plot_missing_bands) - datamodule = SEN12MSDataModule( - root='tests/data/sen12ms', batch_size=1, num_workers=0 - ) - model = ChangeDetectionTask(backbone='resnet18', in_channels=15, num_classes=6) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - @pytest.mark.parametrize('model_name', ['unet']) @pytest.mark.parametrize( 'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0'] From b2a30caf41772025b4bcdffcb6c528e3abdd4236 Mon Sep 17 00:00:00 2001 From: Keenan Eves <31701650+keves1@users.noreply.github.com> Date: Tue, 17 Dec 2024 15:21:57 -0700 Subject: [PATCH 08/16] Update versionadded Co-authored-by: Adam J. Stewart --- torchgeo/trainers/change.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/change.py b/torchgeo/trainers/change.py index 394cace516f..4cbec1f18ed 100644 --- a/torchgeo/trainers/change.py +++ b/torchgeo/trainers/change.py @@ -62,7 +62,7 @@ def __init__( freeze_decoder: Freeze the decoder network to linear probe the segmentation head. - .. versionadded: 0.6 + .. versionadded: 0.7 """ self.weights = weights super().__init__() From 77759056c76ca9e67914327ad18bfcd7221a78cb Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Thu, 19 Dec 2024 18:50:31 +0000 Subject: [PATCH 09/16] removed custom loss functions. --- torchgeo/losses/__init__.py | 4 +--- torchgeo/losses/focaljaccard.py | 27 --------------------------- torchgeo/losses/xentjaccard.py | 27 --------------------------- torchgeo/trainers/change.py | 11 +++-------- 4 files changed, 4 insertions(+), 65 deletions(-) delete mode 100644 torchgeo/losses/focaljaccard.py delete mode 100644 torchgeo/losses/xentjaccard.py diff --git a/torchgeo/losses/__init__.py b/torchgeo/losses/__init__.py index cfeb973a276..d30807a4bd6 100644 --- a/torchgeo/losses/__init__.py +++ b/torchgeo/losses/__init__.py @@ -3,8 +3,6 @@ """TorchGeo losses.""" -from .focaljaccard import BinaryFocalJaccardLoss from .qr import QRLoss, RQLoss -from .xentjaccard import BinaryXEntJaccardLoss -__all__ = ('QRLoss', 'RQLoss', 'BinaryFocalJaccardLoss', 'BinaryXEntJaccardLoss') +__all__ = ('QRLoss', 'RQLoss') diff --git a/torchgeo/losses/focaljaccard.py b/torchgeo/losses/focaljaccard.py deleted file mode 100644 index cceb9ceacf2..00000000000 --- a/torchgeo/losses/focaljaccard.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Focal Jaccard loss functions.""" - -from typing import cast - -import segmentation_models_pytorch as smp -import torch -import torch.nn as nn - - -class BinaryFocalJaccardLoss(nn.Module): - """Binary Focal Jaccard Loss.""" - - def __init__(self) -> None: - """Initialize a BinaryFocalJaccardLoss instance.""" - super().__init__() - self.focal_loss = smp.losses.FocalLoss(mode='binary', normalized=True) - self.jaccard_loss = smp.losses.JaccardLoss(mode='binary') - - def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - """Compute the loss.""" - return cast( - torch.Tensor, - self.focal_loss(preds, targets) + self.jaccard_loss(preds, targets), - ) diff --git a/torchgeo/losses/xentjaccard.py b/torchgeo/losses/xentjaccard.py deleted file mode 100644 index 9675e0272b8..00000000000 --- a/torchgeo/losses/xentjaccard.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Cross-Entropy Jaccard loss functions.""" - -from typing import cast - -import segmentation_models_pytorch as smp -import torch -import torch.nn as nn - - -class BinaryXEntJaccardLoss(nn.Module): - """Binary Cross-Entropy Jaccard Loss.""" - - def __init__(self) -> None: - """Initialize a BinaryXEntJaccardLoss instance.""" - super().__init__() - self.bce_loss = nn.BCEWithLogitsLoss() - self.jaccard_loss = smp.losses.JaccardLoss(mode='binary') - - def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - """Compute the loss.""" - return cast( - torch.Tensor, - self.bce_loss(preds, targets) + self.jaccard_loss(preds, targets), - ) diff --git a/torchgeo/trainers/change.py b/torchgeo/trainers/change.py index 4cbec1f18ed..9ccfabaad41 100644 --- a/torchgeo/trainers/change.py +++ b/torchgeo/trainers/change.py @@ -17,7 +17,6 @@ ) from torchvision.models._api import WeightsEnum -from ..losses import BinaryFocalJaccardLoss, BinaryXEntJaccardLoss from ..models import FCSiamConc, FCSiamDiff, get_weight from . import utils from .base import BaseTask @@ -33,7 +32,7 @@ def __init__( weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, pos_weight: Tensor | None = None, - loss: str = 'bce-jaccard', + loss: str = 'bce', lr: float = 1e-3, patience: int = 10, freeze_backbone: bool = False, @@ -54,7 +53,7 @@ def __init__( in_channels: Number of input channels to model. pos_weight: A weight of positive examples and used with 'bce' loss. loss: Name of the loss function, currently supports - 'bce', 'jaccard', 'focal', 'focal-jaccard' or 'bce-jaccard' loss. + 'bce', 'jaccard', or 'focal' loss. lr: Learning rate for optimizer. patience: Patience for learning rate scheduler. freeze_backbone: Freeze the backbone network to fine-tune the @@ -80,14 +79,10 @@ def configure_losses(self) -> None: self.criterion = smp.losses.JaccardLoss(mode='binary') elif loss == 'focal': self.criterion = smp.losses.FocalLoss(mode='binary', normalized=True) - elif loss == 'focal-jaccard': - self.criterion = BinaryFocalJaccardLoss() - elif loss == 'bce-jaccard': - self.criterion = BinaryXEntJaccardLoss() else: raise ValueError( f"Loss type '{loss}' is not valid. " - "Currently, supports 'bce', 'jaccard', 'focal', 'focal-jaccard, or 'bce-jaccard loss." + "Currently, supports 'bce', 'jaccard', or 'focal' loss." ) def configure_metrics(self) -> None: From 7425575b98a0052c592cb7268fc900f48d3c0825 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Thu, 19 Dec 2024 18:55:21 +0000 Subject: [PATCH 10/16] added docstring. --- torchgeo/trainers/change.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchgeo/trainers/change.py b/torchgeo/trainers/change.py index 9ccfabaad41..2b48520a7fa 100644 --- a/torchgeo/trainers/change.py +++ b/torchgeo/trainers/change.py @@ -155,6 +155,16 @@ def configure_models(self) -> None: param.requires_grad = False def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: + """Compute the loss and additional metrics for the given stage. + + Args: + batch: The output of your DataLoader._ + batch_idx: Integer displaying index of this batch._ + stage: The current stage. + + Returns: + The loss tensor. + """ model: str = self.hparams['model'] x = batch['image'] y = batch['mask'] From 1839b8ec2b8d4b060ff9598e0b4604f9015e8c80 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 8 Jan 2025 09:50:51 +0100 Subject: [PATCH 11/16] Fix syntax error in Python 3.10 --- torchgeo/datasets/oscd.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 7a13d03aee9..54ecdc0d4b0 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -170,8 +170,7 @@ def _load_files(self) -> list[dict[str, str | Sequence[str]]]: regions = [] labels_root = os.path.join( self.root, - f'Onera Satellite Change Detection dataset - { - self.split.capitalize()} ' + f'Onera Satellite Change Detection dataset - {self.split.capitalize()} ' + 'Labels', ) images_root = os.path.join( From 97f17ca00a231430523d467b24d1dcf57578c325 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 8 Jan 2025 20:19:47 +0000 Subject: [PATCH 12/16] revert target dtype to long in dataset and change to float in trainer instead. --- torchgeo/datamodules/oscd.py | 4 ++-- torchgeo/datasets/oscd.py | 11 +++++++---- torchgeo/trainers/change.py | 3 ++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 483c7b50c4d..5cd898e0699 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -86,8 +86,8 @@ def __init__( self.aug = K.AugmentationSequential( K.VideoSequential( - K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, batch_size), + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), ), data_keys=None, keepdim=True, diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 54ecdc0d4b0..8978f4e4cd3 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -103,7 +103,8 @@ def __init__( root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[dict[str, Tensor]], + dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -184,7 +185,8 @@ def _load_files(self) -> list[dict[str, str | Sequence[str]]]: def get_image_paths(ind: int) -> list[str]: return sorted( glob.glob( - os.path.join(images_root, region, f'imgs_{ind}_rect', '*.tif') + os.path.join(images_root, region, + f'imgs_{ind}_rect', '*.tif') ), key=sort_sentinel2_bands, ) @@ -223,7 +225,8 @@ def _load_image(self, paths: Sequence[Path]) -> Tensor: for path in paths: with Image.open(path) as img: images.append(np.array(img)) - array: np.typing.NDArray[np.int_] = np.stack(images, axis=0).astype(np.int_) + array: np.typing.NDArray[np.int_] = np.stack( + images, axis=0).astype(np.int_) tensor = torch.from_numpy(array).float() return tensor @@ -241,7 +244,7 @@ def _load_target(self, path: Path) -> Tensor: array: np.typing.NDArray[np.int_] = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = torch.clamp(tensor, min=0, max=1) - tensor = tensor.to(torch.float) + tensor = tensor.to(torch.long) return tensor def _verify(self) -> None: diff --git a/torchgeo/trainers/change.py b/torchgeo/trainers/change.py index 2b48520a7fa..4788dda91c5 100644 --- a/torchgeo/trainers/change.py +++ b/torchgeo/trainers/change.py @@ -7,6 +7,7 @@ from typing import Any import segmentation_models_pytorch as smp +import torch import torch.nn as nn from torch import Tensor from torchmetrics import MetricCollection @@ -173,7 +174,7 @@ def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor: x = x.flatten(start_dim=1, end_dim=2) y_hat = self(x) - loss: Tensor = self.criterion(y_hat, y) + loss: Tensor = self.criterion(y_hat, y.to(torch.float)) self.log(f'{stage}_loss', loss) # Retrieve the correct metrics based on the stage From 59ddd790fc089cd755cca8d29df7128e9d4e8c18 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 8 Jan 2025 20:28:09 +0000 Subject: [PATCH 13/16] ruff format --- torchgeo/datasets/oscd.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 8978f4e4cd3..27359dc4612 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -103,8 +103,7 @@ def __init__( root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, - transforms: Callable[[dict[str, Tensor]], - dict[str, Tensor]] | None = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -185,8 +184,7 @@ def _load_files(self) -> list[dict[str, str | Sequence[str]]]: def get_image_paths(ind: int) -> list[str]: return sorted( glob.glob( - os.path.join(images_root, region, - f'imgs_{ind}_rect', '*.tif') + os.path.join(images_root, region, f'imgs_{ind}_rect', '*.tif') ), key=sort_sentinel2_bands, ) @@ -225,8 +223,7 @@ def _load_image(self, paths: Sequence[Path]) -> Tensor: for path in paths: with Image.open(path) as img: images.append(np.array(img)) - array: np.typing.NDArray[np.int_] = np.stack( - images, axis=0).astype(np.int_) + array: np.typing.NDArray[np.int_] = np.stack(images, axis=0).astype(np.int_) tensor = torch.from_numpy(array).float() return tensor From 10033e97f3ba7d98c5e9b7f71a6ce353207fbd1a Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Fri, 10 Jan 2025 19:22:31 +0000 Subject: [PATCH 14/16] updated OSCD dataset tests --- tests/datasets/test_oscd.py | 12 ++++-------- torchgeo/datasets/oscd.py | 4 ++-- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index 711392f7fc4..5c10508e3bb 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -66,19 +66,15 @@ def dataset( def test_getitem(self, dataset: OSCD) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x['image1'], torch.Tensor) - assert x['image1'].ndim == 3 - assert isinstance(x['image2'], torch.Tensor) - assert x['image2'].ndim == 3 + assert isinstance(x['image'], torch.Tensor) + assert x['image'].ndim == 4 assert isinstance(x['mask'], torch.Tensor) assert x['mask'].ndim == 2 if dataset.bands == OSCD.rgb_bands: - assert x['image1'].shape[0] == 3 - assert x['image2'].shape[0] == 3 + assert x['image'].shape[1] == 3 else: - assert x['image1'].shape[0] == 13 - assert x['image2'].shape[0] == 13 + assert x['image'].shape[1] == 13 def test_len(self, dataset: OSCD) -> None: if dataset.split == 'train': diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 27359dc4612..d261cbbf291 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -325,8 +325,8 @@ def get_masked(img: Tensor) -> 'np.typing.NDArray[np.uint8]': ) return array - image1 = get_masked(sample['image1']) - image2 = get_masked(sample['image2']) + image1 = get_masked(sample['image'][0]) + image2 = get_masked(sample['image'][1]) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) axs[0].axis('off') From 3b798ee7e73cf7a2775ad3af5e7cc3c5b4812580 Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Wed, 15 Jan 2025 23:47:14 +0000 Subject: [PATCH 15/16] removing now that datamodule is tested with ChangeDetectionTask. --- tests/datamodules/test_oscd.py | 82 ---------------------------------- 1 file changed, 82 deletions(-) delete mode 100644 tests/datamodules/test_oscd.py diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py deleted file mode 100644 index e67bd6d5678..00000000000 --- a/tests/datamodules/test_oscd.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os - -import pytest -from _pytest.fixtures import SubRequest -from lightning.pytorch import Trainer - -from torchgeo.datamodules import OSCDDataModule -from torchgeo.datasets import OSCD - - -class TestOSCDDataModule: - @pytest.fixture(params=[OSCD.all_bands, OSCD.rgb_bands]) - def datamodule(self, request: SubRequest) -> OSCDDataModule: - bands = request.param - root = os.path.join('tests', 'data', 'oscd') - dm = OSCDDataModule( - root=root, - download=True, - bands=bands, - batch_size=1, - patch_size=2, - val_split_pct=0.5, - num_workers=0, - ) - dm.prepare_data() - dm.trainer = Trainer(accelerator='cpu', max_epochs=1) - return dm - - def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: - datamodule.setup('fit') - if datamodule.trainer: - datamodule.trainer.training = True - batch = next(iter(datamodule.train_dataloader())) - batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 - assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 - if datamodule.bands == OSCD.all_bands: - assert batch['image1'].shape[1] == 13 - assert batch['image2'].shape[1] == 13 - else: - assert batch['image1'].shape[1] == 3 - assert batch['image2'].shape[1] == 3 - - def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: - datamodule.setup('validate') - if datamodule.trainer: - datamodule.trainer.validating = True - batch = next(iter(datamodule.val_dataloader())) - batch = datamodule.on_after_batch_transfer(batch, 0) - if datamodule.val_split_pct > 0.0: - assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 - assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 - if datamodule.bands == OSCD.all_bands: - assert batch['image1'].shape[1] == 13 - assert batch['image2'].shape[1] == 13 - else: - assert batch['image1'].shape[1] == 3 - assert batch['image2'].shape[1] == 3 - - def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: - datamodule.setup('test') - if datamodule.trainer: - datamodule.trainer.testing = True - batch = next(iter(datamodule.test_dataloader())) - batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 - assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) - assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 - if datamodule.bands == OSCD.all_bands: - assert batch['image1'].shape[1] == 13 - assert batch['image2'].shape[1] == 13 - else: - assert batch['image1'].shape[1] == 3 - assert batch['image2'].shape[1] == 3 From 1fa47c532cab28fa9faca25054e46d0535ab20da Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Fri, 17 Jan 2025 17:56:39 +0000 Subject: [PATCH 16/16] prettier format --- tests/conf/oscd.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conf/oscd.yaml b/tests/conf/oscd.yaml index 8f0d3ce2a39..d44d9c9a459 100644 --- a/tests/conf/oscd.yaml +++ b/tests/conf/oscd.yaml @@ -12,4 +12,4 @@ data: patch_size: 16 val_split_pct: 0.5 dict_kwargs: - root: 'tests/data/oscd' \ No newline at end of file + root: 'tests/data/oscd'