Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SemanticSegmentationTask: add class-wise metrics #2130

Closed
wants to merge 40 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
23fa1fb
Add average metrics
robmarkcole Jun 19, 2024
b7d8305
Add average metrics
robmarkcole Jun 19, 2024
b1526fa
refactor: Rename metrics in SemanticSegmentationTask
robmarkcole Jun 19, 2024
341e272
Ruff format
robmarkcole Jun 19, 2024
024feda
Use ignore_index
robmarkcole Jun 20, 2024
04cac59
pass on_epoch
robmarkcole Jun 20, 2024
56f20fc
on_epoch to train too
robmarkcole Jun 20, 2024
3d2b309
Disable on_step for train metrics
robmarkcole Jun 20, 2024
9af1493
Merge branch 'main' into update-metrics
robmarkcole Jun 20, 2024
192c496
ruff format
robmarkcole Jun 20, 2024
73b710f
Merge branch 'main' into update-metrics
robmarkcole Jun 21, 2024
8ce8c30
Merge branch 'main' into update-metrics
robmarkcole Jun 23, 2024
e4ed9fd
Merge branch 'main' into update-metrics
robmarkcole Jul 2, 2024
d9c2688
Merge branch 'main' into update-metrics
robmarkcole Jul 8, 2024
400fae3
Merge branch 'main' into update-metrics
robmarkcole Jul 21, 2024
f4c793e
Merge branch 'main' into update-metrics
robmarkcole Aug 1, 2024
3b629ea
Merge branch 'main' into update-metrics
robmarkcole Aug 5, 2024
e6abadd
Merge branch 'main' into update-metrics
robmarkcole Aug 6, 2024
da887fe
Bump min torchmetrics
robmarkcole Aug 6, 2024
5138ccb
Merge branch 'update-metrics' of https://github.com/robmarkcole/torch…
robmarkcole Aug 6, 2024
479c7e3
Raise torchmetrics min
robmarkcole Aug 6, 2024
50b7d29
remo on_epoch etc
robmarkcole Aug 7, 2024
1cd436f
Remove on_epoch
robmarkcole Aug 7, 2024
9e985e2
try torchmetrics==1.1.0
robmarkcole Aug 7, 2024
c773322
try torchmetrics==1.1.1
robmarkcole Aug 7, 2024
9d8c8e4
Merge branch 'main' into update-metrics
robmarkcole Aug 7, 2024
e2640f5
Use loop to generate metrics
robmarkcole Aug 8, 2024
19187a9
Update
robmarkcole Aug 8, 2024
a3f7ffe
Fix jaccard
robmarkcole Aug 8, 2024
9a66442
fix dependencies delta
robmarkcole Aug 8, 2024
8381cb7
fix pyproject
robmarkcole Aug 8, 2024
b5050ad
Merge branch 'main' into update-metrics
robmarkcole Sep 5, 2024
07e7c4d
Merge branch 'main' into update-metrics
robmarkcole Dec 31, 2024
ff761f2
Address merge conflicts
robmarkcole Dec 31, 2024
59ba3c8
Specify on_epoch
robmarkcole Jan 2, 2025
1184647
Ruff format
robmarkcole Jan 2, 2025
82eecdc
Merge branch 'main' into update-metrics
robmarkcole Feb 1, 2025
bcada43
merge main
robmarkcole Feb 1, 2025
bc3bb3c
Merge branch 'microsoft:main' into main
robmarkcole Feb 1, 2025
ae7f061
merge main
robmarkcole Feb 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
- pytest>=6.1.2
- scikit-image>=0.22.0
- torch>=2.6
- torchmetrics>=0.10
- torchmetrics>=1.1.1
- torchvision>=0.18
exclude: (build|data|dist|logo|logs|output)/
- repo: https://github.com/pre-commit/mirrors-prettier
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ dependencies = [
"timm>=0.4.12",
# torch 1.13+ required by torchvision
"torch>=1.13",
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
"torchmetrics>=0.10",
# torchmetrics 1.1.1+ required for average argument to MeanAveragePrecision
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to Lightning-AI/torchmetrics@63c7bbe the argument didn't exist until 1.2

"torchmetrics>=1.1.1",
# torchvision 0.14+ required for torchvision.models.swin_v2_b
"torchvision>=0.14",
]
Expand Down
2 changes: 1 addition & 1 deletion requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ segmentation-models-pytorch==0.2.0
shapely==1.8.0
timm==0.4.12
torch==1.13.0
torchmetrics==0.10.0
torchmetrics==1.1.1
torchvision==0.14.0

# datasets
Expand Down
16 changes: 4 additions & 12 deletions tests/datamodules/test_fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import os

import matplotlib.pyplot as plt
import pytest

from torchgeo.datamodules import FAIR1MDataModule
Expand All @@ -26,17 +25,10 @@ def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('validate')
next(iter(datamodule.val_dataloader()))

def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('test')
next(iter(datamodule.test_dataloader()))

def test_predict_dataloader(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('predict')
next(iter(datamodule.predict_dataloader()))

def test_plot(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('validate')
batch = next(iter(datamodule.val_dataloader()))
sample = {
'image': batch['image'][0],
'boxes': batch['boxes'][0],
'label': batch['label'][0],
}
datamodule.plot(sample)
plt.close()
22 changes: 2 additions & 20 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
from _pytest.fixtures import SubRequest
from lightning.pytorch import Trainer
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor

Expand All @@ -31,12 +29,9 @@ def __init__(
self.res = 1

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
image = torch.arange(3 * 2 * 2).view(3, 2, 2)
return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}

def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()


class CustomGeoDataModule(GeoDataModule):
def __init__(self) -> None:
Expand Down Expand Up @@ -68,14 +63,11 @@ def __init__(
self.length = length

def __getitem__(self, index: int) -> dict[str, Tensor]:
return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)}
return {'image': torch.arange(3 * 2 * 2).view(3, 2, 2)}

def __len__(self) -> int:
return self.length

def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()


class CustomNonGeoDataModule(NonGeoDataModule):
def __init__(self) -> None:
Expand Down Expand Up @@ -133,11 +125,6 @@ def test_predict(self, datamodule: CustomGeoDataModule) -> None:
batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_plot(self, datamodule: CustomGeoDataModule) -> None:
datamodule.setup('validate')
datamodule.plot()
plt.close()

def test_no_datasets(self) -> None:
dm = CustomGeoDataModule()
msg = r'CustomGeoDataModule\.setup must define one of '
Expand Down Expand Up @@ -235,11 +222,6 @@ def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
batch = next(iter(datamodule.predict_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_plot(self, datamodule: CustomNonGeoDataModule) -> None:
datamodule.setup('validate')
datamodule.plot()
plt.close()

def test_no_datasets(self) -> None:
dm = CustomNonGeoDataModule()
msg = r'CustomNonGeoDataModule\.setup must define one of '
Expand Down
9 changes: 0 additions & 9 deletions tests/datamodules/test_usavars.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

import os

import matplotlib.pyplot as plt
import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import USAVarsDataModule
from torchgeo.datasets import unbind_samples


class TestUSAVarsDataModule:
Expand Down Expand Up @@ -41,10 +39,3 @@ def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.test_dataloader()) == 1
batch = next(iter(datamodule.test_dataloader()))
assert batch['image'].shape[0] == datamodule.batch_size

def test_plot(self, datamodule: USAVarsDataModule) -> None:
datamodule.setup('validate')
batch = next(iter(datamodule.val_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
37 changes: 37 additions & 0 deletions tests/datamodules/test_xview2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest

from torchgeo.datamodules import XView2DataModule


class TestXView2DataModule:
    """Smoke tests: each trainer stage of XView2DataModule yields a batch."""

    @pytest.fixture
    def datamodule(self) -> XView2DataModule:
        """Build a datamodule backed by the small on-disk xView2 test fixture."""
        module = XView2DataModule(
            root=os.path.join('tests', 'data', 'xview2'),
            batch_size=1,
            num_workers=0,
            val_split_pct=0.5,
        )
        module.prepare_data()
        return module

    def test_train_dataloader(self, datamodule: XView2DataModule) -> None:
        """A training batch can be drawn after setting up the 'fit' stage."""
        datamodule.setup('fit')
        next(iter(datamodule.train_dataloader()))

    def test_val_dataloader(self, datamodule: XView2DataModule) -> None:
        """A validation batch can be drawn after setting up the 'validate' stage."""
        datamodule.setup('validate')
        next(iter(datamodule.val_dataloader()))

    def test_test_dataloader(self, datamodule: XView2DataModule) -> None:
        """A test batch can be drawn after setting up the 'test' stage."""
        datamodule.setup('test')
        next(iter(datamodule.test_dataloader()))

    def test_predict_dataloader(self, datamodule: XView2DataModule) -> None:
        """A prediction batch can be drawn after setting up the 'predict' stage."""
        datamodule.setup('predict')
        next(iter(datamodule.predict_dataloader()))
93 changes: 14 additions & 79 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import timm
import torch
import torch.nn as nn
import torchvision
from lightning.pytorch import Trainer
from pytest import MonkeyPatch
from torch.nn.modules import Module
Expand All @@ -19,7 +20,7 @@
EuroSATDataModule,
MisconfigurationException,
)
from torchgeo.datasets import BigEarthNet, EuroSAT, RGBBandsMissingError
from torchgeo.datasets import BigEarthNet, EuroSAT
from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
Expand Down Expand Up @@ -55,12 +56,9 @@
return ClassificationTestModel(**kwargs)


def plot(*args: Any, **kwargs: Any) -> None:
return None


def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
raise RGBBandsMissingError()
def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestClassificationTask:
Expand Down Expand Up @@ -103,13 +101,13 @@
'1',
]

main(['fit', *args])
main(['fit'] + args)

Check failure on line 104 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:104:14: RUF005 Consider `['fit', *args]` instead of concatenation
try:
main(['test', *args])
main(['test'] + args)

Check failure on line 106 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:106:18: RUF005 Consider `['test', *args]` instead of concatenation
except MisconfigurationException:
pass
try:
main(['predict', *args])
main(['predict'] + args)

Check failure on line 110 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:110:18: RUF005 Consider `['predict', *args]` instead of concatenation
except MisconfigurationException:
pass

Expand All @@ -119,11 +117,7 @@

@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model( # type: ignore[attr-defined]
Expand All @@ -134,6 +128,7 @@
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
Expand Down Expand Up @@ -183,34 +178,6 @@
with pytest.raises(ValueError, match=match):
ClassificationTask(model='resnet18', loss='invalid_loss')

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(EuroSATDataModule, 'plot', plot)
datamodule = EuroSATDataModule(
root='tests/data/eurosat', batch_size=1, num_workers=0
)
model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(EuroSATDataModule, 'plot', plot_missing_bands)
datamodule = EuroSATDataModule(
root='tests/data/eurosat', batch_size=1, num_workers=0
)
model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictClassificationDataModule(
root='tests/data/eurosat', batch_size=1, num_workers=0
Expand All @@ -237,7 +204,7 @@

class TestMultiLabelClassificationTask:
@pytest.mark.parametrize(
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2', 'treesatai']
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2']
)
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
Expand All @@ -259,13 +226,13 @@
'1',
]

main(['fit', *args])
main(['fit'] + args)

Check failure on line 229 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:229:14: RUF005 Consider `['fit', *args]` instead of concatenation
try:
main(['test', *args])
main(['test'] + args)

Check failure on line 231 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:231:18: RUF005 Consider `['test', *args]` instead of concatenation
except MisconfigurationException:
pass
try:
main(['predict', *args])
main(['predict'] + args)

Check failure on line 235 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:235:18: RUF005 Consider `['predict', *args]` instead of concatenation
except MisconfigurationException:
pass

Expand All @@ -274,38 +241,6 @@
with pytest.raises(ValueError, match=match):
MultiLabelClassificationTask(model='resnet18', loss='invalid_loss')

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot_missing_bands)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictMultiLabelClassificationDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
Expand Down
Loading
Loading