From 5d802ac0f2efe1b801b447b91a7b6e34a1a74671 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Fri, 23 Aug 2024 18:14:25 +0200
Subject: [PATCH 1/3] Add Sample and Batch TypedDicts

---
 tests/datamodules/test_geo.py | 7 ++-
 tests/transforms/test_color.py | 9 ++--
 tests/transforms/test_indices.py | 16 +++---
 torchgeo/datamodules/chesapeake.py | 7 +--
 torchgeo/datamodules/etci2021.py | 7 +--
 torchgeo/datamodules/fair1m.py | 7 ++-
 torchgeo/datamodules/geo.py | 45 ++++++++--------
 torchgeo/datamodules/sen12ms.py | 7 +--
 torchgeo/datamodules/utils.py | 13 +++--
 torchgeo/datasets/__init__.py | 4 ++
 torchgeo/datasets/advance.py | 13 ++---
 torchgeo/datasets/agb_live_woody_density.py | 10 ++--
 torchgeo/datasets/agrifieldnet.py | 15 +++---
 torchgeo/datasets/airphen.py | 9 +---
 torchgeo/datasets/astergdem.py | 10 ++--
 torchgeo/datasets/benin_cashews.py | 10 ++--
 torchgeo/datasets/bigearthnet.py | 13 ++---
 torchgeo/datasets/biomassters.py | 11 ++--
 torchgeo/datasets/cbf.py | 10 ++--
 torchgeo/datasets/cdl.py | 13 ++---
 torchgeo/datasets/chabud.py | 13 ++---
 torchgeo/datasets/chesapeake.py | 23 +++----
 torchgeo/datasets/cloud_cover.py | 13 ++---
 torchgeo/datasets/cms_mangrove_canopy.py | 10 ++--
 torchgeo/datasets/cowc.py | 16 +++---
 torchgeo/datasets/cropharvest.py | 10 ++--
 torchgeo/datasets/cv4a_kenya_crop_type.py | 10 ++--
 torchgeo/datasets/cyclone.py | 14 ++---
 torchgeo/datasets/deepglobelandcover.py | 9 ++--
 torchgeo/datasets/dfc2022.py | 19 ++++---
 torchgeo/datasets/eddmaps.py | 7 ++-
 torchgeo/datasets/enviroatlas.py | 15 +++---
 torchgeo/datasets/esri2020.py | 10 ++--
 torchgeo/datasets/etci2021.py | 13 ++---
 torchgeo/datasets/eudem.py | 11 ++--
 torchgeo/datasets/eurocrops.py | 15 +++---
 torchgeo/datasets/eurosat.py | 20 ++++---
 torchgeo/datasets/fair1m.py | 21 ++++----
 torchgeo/datasets/fire_risk.py | 10 ++--
 torchgeo/datasets/forestdamage.py | 22 ++++----
 torchgeo/datasets/gbif.py | 7 ++-
 torchgeo/datasets/geo.py | 47 ++++++++--------
 torchgeo/datasets/gid15.py | 12 ++---
 torchgeo/datasets/globbiomass.py | 14 +++--
 torchgeo/datasets/idtrees.py | 20 ++++---
 torchgeo/datasets/inaturalist.py | 7 ++-
 torchgeo/datasets/inria.py | 20 +++----
 torchgeo/datasets/iobench.py | 11 ++--
 torchgeo/datasets/l7irish.py | 15 +++---
 torchgeo/datasets/l8biome.py | 14 ++---
 torchgeo/datasets/landcoverai.py | 26 ++++-----
 torchgeo/datasets/landsat.py | 10 ++--
 torchgeo/datasets/levircd.py | 13 ++---
 torchgeo/datasets/loveda.py | 12 ++---
 torchgeo/datasets/mapinwild.py | 12 ++---
 torchgeo/datasets/millionaid.py | 17 +++---
 torchgeo/datasets/naip.py | 8 +--
 torchgeo/datasets/nasa_marine_debris.py | 14 ++---
 torchgeo/datasets/nccm.py | 13 ++---
 torchgeo/datasets/nlcd.py | 13 ++---
 torchgeo/datasets/openbuildings.py | 21 ++++----
 torchgeo/datasets/oscd.py | 9 ++--
 torchgeo/datasets/pastis.py | 20 +++----
 torchgeo/datasets/patternnet.py | 10 ++--
 torchgeo/datasets/potsdam.py | 9 ++--
 torchgeo/datasets/prisma.py | 9 +---
 torchgeo/datasets/quakeset.py | 17 +++---
 torchgeo/datasets/reforestree.py | 19 ++++---
 torchgeo/datasets/resisc45.py | 10 ++--
 torchgeo/datasets/rwanda_field_boundary.py | 11 ++--
 torchgeo/datasets/seasonet.py | 10 ++--
 torchgeo/datasets/seco.py | 13 ++---
 torchgeo/datasets/sen12ms.py | 13 ++---
 torchgeo/datasets/sentinel.py | 17 ++----
 torchgeo/datasets/skippd.py | 11 ++--
 torchgeo/datasets/so2sat.py | 14 ++---
 torchgeo/datasets/south_africa_crop_type.py | 15 +++---
 torchgeo/datasets/south_america_soybean.py | 11 ++--
 torchgeo/datasets/spacenet.py | 12 ++---
 torchgeo/datasets/ssl4eo.py | 25 ++++-----
 torchgeo/datasets/ssl4eo_benchmark.py | 13 ++---
 torchgeo/datasets/sustainbench_crop_yield.py | 12 ++---
 torchgeo/datasets/ucmerced.py | 9 ++--
 torchgeo/datasets/usavars.py | 13 ++---
 torchgeo/datasets/utils.py | 54 +++++++++++++++----
 torchgeo/datasets/vaihingen.py | 9 ++--
 torchgeo/datasets/vhr10.py | 15 +++---
 .../western_usa_live_fuel_moisture.py | 9 ++--
 torchgeo/datasets/xview.py | 9 ++--
 torchgeo/datasets/zuericrop.py | 10 ++--
 torchgeo/trainers/byol.py | 11 ++--
 torchgeo/trainers/classification.py | 19 ++++---
 torchgeo/trainers/detection.py | 15 +++---
 torchgeo/trainers/iobench.py | 13 ++---
 torchgeo/trainers/moco.py | 12 +++--
 torchgeo/trainers/regression.py | 11 ++--
 torchgeo/trainers/segmentation.py | 11 ++--
 torchgeo/trainers/simclr.py | 12 +++--
 torchgeo/transforms/transforms.py | 10 ++--
 99 files changed, 615 insertions(+), 750 deletions(-)

diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py
index 4e5431c684f..b6e97c15164 100644
--- a/tests/datamodules/test_geo.py
+++ b/tests/datamodules/test_geo.py
@@ -10,14 +10,13 @@
 from lightning.pytorch import Trainer
 from matplotlib.figure import Figure
 from rasterio.crs import CRS
-from torch import Tensor
 
 from torchgeo.datamodules import (
     GeoDataModule,
     MisconfigurationException,
     NonGeoDataModule,
 )
-from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset
+from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset, Sample
 from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler
 
 
@@ -30,7 +29,7 @@ def __init__(
             self.index.insert(i, (0, 1, 2, 3, 4, 5))
         self.res = 1
 
-    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
+    def __getitem__(self, query: BoundingBox) -> Sample:
         image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
         return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}
 
@@ -67,7 +66,7 @@ def __init__(
     ) -> None:
         self.length = length
 
-    def __getitem__(self, index: int) -> dict[str, Tensor]:
+    def __getitem__(self, index: int) -> Sample:
         return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)}
 
     def __len__(self) -> int:
diff --git a/tests/transforms/test_color.py b/tests/transforms/test_color.py
index b235f7195f2..0b7c5b3f9a9 100644
--- a/tests/transforms/test_color.py
+++ b/tests/transforms/test_color.py
@@ -6,11 +6,12 @@
 import torch
 from torch import Tensor
 
+from torchgeo.datasets import Batch, Sample
 from torchgeo.transforms import RandomGrayscale
 
 
 @pytest.fixture
-def sample() -> dict[str, Tensor]:
+def sample() -> Sample:
     return {
         'image': torch.arange(3 * 4 * 4, dtype=torch.float).view(3, 4, 4),
         'mask': torch.arange(4 * 4, dtype=torch.long).view(1, 4, 4),
@@ -18,7 +19,7 @@ def sample() -> dict[str, Tensor]:
 
 
 @pytest.fixture
-def batch() -> dict[str, Tensor]:
+def batch() -> Batch:
     return {
         'image': torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4),
         'mask': torch.arange(2 * 4 * 4, dtype=torch.long).view(2, 1, 4, 4),
@@ -33,7 +34,7 @@ def batch() -> dict[str, Tensor]:
         torch.tensor([1.0, 2.0, 3.0]),
     ],
 )
-def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> None:
+def test_random_grayscale_sample(weights: Tensor, sample: Sample) -> None:
     aug = K.AugmentationSequential(
         RandomGrayscale(weights, p=1), keepdim=True, data_keys=None
     )
@@ -51,7 +52,7 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) ->
         torch.tensor([1.0, 2.0, 3.0]),
     ],
 )
-def test_random_grayscale_batch(weights: Tensor, batch: dict[str, Tensor]) -> None:
+def test_random_grayscale_batch(weights: Tensor, batch: Batch) -> None:
     aug = K.AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=None)
     output = aug(batch)
     assert output['image'].shape == batch['image'].shape
diff --git a/tests/transforms/test_indices.py b/tests/transforms/test_indices.py
index 9e6f54e48c4..dda3d09c875 100644
--- a/tests/transforms/test_indices.py
+++ b/tests/transforms/test_indices.py
@@ -4,8 +4,8 @@
 import kornia.augmentation as K
 import pytest
 import torch
-from torch import Tensor
 
+from torchgeo.datasets import Batch, Sample
 from torchgeo.transforms import (
     AppendBNDVI,
     AppendGBNDVI,
@@ -25,7 +25,7 @@
 
 
 @pytest.fixture
-def sample() -> dict[str, Tensor]:
+def sample() -> Sample:
     return {
         'image': torch.arange(3 * 4 * 4, dtype=torch.float).view(3, 4, 4),
         'mask': torch.arange(4 * 4, dtype=torch.long).view(1, 4, 4),
@@ -33,14 +33,14 @@ def sample() -> dict[str, Tensor]:
 
 
 @pytest.fixture
-def batch() -> dict[str, Tensor]:
+def batch() -> Batch:
     return {
         'image': torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4),
         'mask': torch.arange(2 * 4 * 4, dtype=torch.long).view(2, 1, 4, 4),
     }
 
 
-def test_append_index_sample(sample: dict[str, Tensor]) -> None:
+def test_append_index_sample(sample: Sample) -> None:
     c, h, w = sample['image'].shape
     aug = K.AugmentationSequential(
         AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
@@ -49,7 +49,7 @@ def test_append_index_sample(sample: dict[str, Tensor]) -> None:
     assert output['image'].shape == (1, c + 1, h, w)
 
 
-def test_append_index_batch(batch: dict[str, Tensor]) -> None:
+def test_append_index_batch(batch: Batch) -> None:
     b, c, h, w = batch['image'].shape
     aug = K.AugmentationSequential(
         AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
@@ -58,7 +58,7 @@ def test_append_index_batch(batch: dict[str, Tensor]) -> None:
     assert output['image'].shape == (b, c + 1, h, w)
 
 
-def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None:
+def test_append_triband_index_batch(batch: Batch) -> None:
     b, c, h, w = batch['image'].shape
     aug = K.AugmentationSequential(
         AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2),
@@ -83,7 +83,7 @@ def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None:
     ],
 )
 def test_append_normalized_difference_indices(
-    sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex
+    sample: Sample, index: AppendNormalizedDifferenceIndex
 ) -> None:
     c, h, w = sample['image'].shape
     aug = K.AugmentationSequential(index(0, 1), data_keys=None)
@@ -93,7 +93,7 @@
 
 @pytest.mark.parametrize('index', [AppendGBNDVI, AppendGRNDVI, AppendRBNDVI])
 def test_append_tri_band_normalized_difference_indices(
-    sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex
+    sample: Sample, index: AppendTriBandNormalizedDifferenceIndex
 ) -> None:
     c, h, w = sample['image'].shape
     aug = K.AugmentationSequential(index(0, 1, 2), data_keys=None)
diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py
index 41e944e1af5..1a970835a58 100644
--- a/torchgeo/datamodules/chesapeake.py
+++ b/torchgeo/datamodules/chesapeake.py
@@ -7,9 +7,8 @@
 import kornia.augmentation as K
 import torch.nn.functional as F
-from torch import Tensor
 
-from ..datasets import ChesapeakeCVPR
+from ..datasets import ChesapeakeCVPR, Sample
 from ..samplers import GridGeoSampler, RandomBatchGeoSampler
 from .geo import GeoDataModule
 
 
@@ -124,9 +123,7 @@ def setup(self, stage: str) -> None:
                 self.test_dataset, self.original_patch_size, self.original_patch_size
             )
 
-    def on_after_batch_transfer(
-        self, batch: dict[str, Tensor], dataloader_idx: int
-    ) -> dict[str, Tensor]:
+    def on_after_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
         """Apply batch augmentations to the batch after it is transferred to the device.
 
         Args:
diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py
index 233fa43261b..420213014ad 100644
--- a/torchgeo/datamodules/etci2021.py
+++ b/torchgeo/datamodules/etci2021.py
@@ -6,9 +6,8 @@
 from typing import Any
 
 import torch
-from torch import Tensor
 
-from ..datasets import ETCI2021
+from ..datasets import ETCI2021, Sample
 from .geo import NonGeoDataModule
 
 
@@ -62,9 +61,7 @@ def setup(self, stage: str) -> None:
             # Test set masks are not public, use for prediction instead
             self.predict_dataset = ETCI2021(split='test', **self.kwargs)
 
-    def on_after_batch_transfer(
-        self, batch: dict[str, Tensor], dataloader_idx: int
-    ) -> dict[str, Tensor]:
+    def on_after_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
         """Apply batch augmentations to the batch after it is transferred to the device.
 
         Args:
diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py
index 291dd617e04..d6e1c8bdc34 100644
--- a/torchgeo/datamodules/fair1m.py
+++ b/torchgeo/datamodules/fair1m.py
@@ -6,13 +6,12 @@
 from typing import Any
 
 import torch
-from torch import Tensor
 
-from ..datasets import FAIR1M
+from ..datasets import FAIR1M, Sample
 from .geo import NonGeoDataModule
 
 
-def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
+def collate_fn(batch: list[Sample]) -> Sample:
     """Custom object detection collate fn to handle variable boxes.
 
     Args:
@@ -23,7 +22,7 @@ def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
 
     .. versionadded:: 0.5
     """
-    output: dict[str, Any] = {}
+    output: Sample = {}
     output['image'] = torch.stack([sample['image'] for sample in batch])
 
     if 'boxes' in batch[0]:
diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py
index 8721ea6e7f6..4348e456e89 100644
--- a/torchgeo/datamodules/geo.py
+++ b/torchgeo/datamodules/geo.py
@@ -10,10 +10,9 @@
 import torch
 from lightning.pytorch import LightningDataModule
 from matplotlib.figure import Figure
-from torch import Tensor
 from torch.utils.data import DataLoader, Dataset, Subset, default_collate
 
-from ..datasets import GeoDataset, NonGeoDataset, stack_samples
+from ..datasets import Batch, GeoDataset, NonGeoDataset, Sample, stack_samples
 from ..samplers import (
     BatchGeoSampler,
     GeoSampler,
@@ -34,7 +33,7 @@ class BaseDataModule(LightningDataModule):
 
     def __init__(
         self,
-        dataset_class: type[Dataset[dict[str, Tensor]]],
+        dataset_class: type[Dataset[Sample]],
         batch_size: int = 1,
         num_workers: int = 0,
         **kwargs: Any,
@@ -55,11 +54,11 @@ def __init__(
         self.kwargs = kwargs
 
         # Datasets
-        self.dataset: Dataset[dict[str, Tensor]] | None = None
-        self.train_dataset: Dataset[dict[str, Tensor]] | None = None
-        self.val_dataset: Dataset[dict[str, Tensor]] | None = None
-        self.test_dataset: Dataset[dict[str, Tensor]] | None = None
-        self.predict_dataset: Dataset[dict[str, Tensor]] | None = None
+        self.dataset: Dataset[Sample] | None = None
+        self.train_dataset: Dataset[Sample] | None = None
+        self.val_dataset: Dataset[Sample] | None = None
+        self.test_dataset: Dataset[Sample] | None = None
+        self.predict_dataset: Dataset[Sample] | None = None
 
         # Data loaders
         self.train_batch_size: int | None = None
@@ -68,7 +67,7 @@ def __init__(
         self.predict_batch_size: int | None = None
 
         # Data augmentation
-        Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
+        Transform = Callable[[Batch], Batch]
         self.aug: Transform = K.AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
         )
@@ -115,9 +114,7 @@ def _valid_attribute(self, *args: str) -> Any:
         msg = f'{self.__class__.__name__}.setup must define one of {args}.'
         raise MisconfigurationException(msg)
 
-    def on_after_batch_transfer(
-        self, batch: dict[str, Tensor], dataloader_idx: int
-    ) -> dict[str, Tensor]:
+    def on_after_batch_transfer(self, batch: Batch, dataloader_idx: int) -> Batch:
         """Apply batch augmentations to the batch after it is transferred to the device.
 
         Args:
@@ -253,7 +250,7 @@ def setup(self, stage: str) -> None:
                 self.test_dataset, self.patch_size, self.patch_size
             )
 
-    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
+    def _dataloader_factory(self, split: str) -> DataLoader[Batch]:
         """Implement one or more PyTorch DataLoaders.
 
         Args:
@@ -289,7 +286,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
             persistent_workers=self.num_workers > 0,
         )
 
-    def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
+    def train_dataloader(self) -> DataLoader[Batch]:
         """Implement one or more PyTorch DataLoaders for training.
 
         Returns:
@@ -301,7 +298,7 @@ def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
         """
         return self._dataloader_factory('train')
 
-    def val_dataloader(self) -> DataLoader[dict[str, Tensor]]:
+    def val_dataloader(self) -> DataLoader[Batch]:
         """Implement one or more PyTorch DataLoaders for validation.
Returns: @@ -313,7 +310,7 @@ def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: """ return self._dataloader_factory('val') - def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: + def test_dataloader(self) -> DataLoader[Batch]: """Implement one or more PyTorch DataLoaders for testing. Returns: @@ -325,7 +322,7 @@ def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: """ return self._dataloader_factory('test') - def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: + def predict_dataloader(self) -> DataLoader[Batch]: """Implement one or more PyTorch DataLoaders for prediction. Returns: @@ -338,8 +335,8 @@ def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: return self._dataloader_factory('predict') def transfer_batch_to_device( - self, batch: dict[str, Tensor], device: torch.device, dataloader_idx: int - ) -> dict[str, Tensor]: + self, batch: Batch, device: torch.device, dataloader_idx: int + ) -> Batch: """Transfer batch to device. Defines how custom data types are moved to the target device. @@ -409,7 +406,7 @@ def setup(self, stage: str) -> None: split='test', **self.kwargs ) - def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: + def _dataloader_factory(self, split: str) -> DataLoader[Batch]: """Implement one or more PyTorch DataLoaders. Args: @@ -433,7 +430,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: persistent_workers=self.num_workers > 0, ) - def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: + def train_dataloader(self) -> DataLoader[Batch]: """Implement one or more PyTorch DataLoaders for training. Returns: @@ -445,7 +442,7 @@ def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: """ return self._dataloader_factory('train') - def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: + def val_dataloader(self) -> DataLoader[Batch]: """Implement one or more PyTorch DataLoaders for validation. Returns: @@ -457,7 +454,7 @@ def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: """ return self._dataloader_factory('val') - def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: + def test_dataloader(self) -> DataLoader[Batch]: """Implement one or more PyTorch DataLoaders for testing. Returns: @@ -469,7 +466,7 @@ def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: """ return self._dataloader_factory('test') - def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: + def predict_dataloader(self) -> DataLoader[Batch]: """Implement one or more PyTorch DataLoaders for prediction. Returns: diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 2ca50fb10ae..8982c02f1e6 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -6,10 +6,9 @@ from typing import Any import torch -from torch import Tensor from torch.utils.data import Subset -from ..datasets import SEN12MS +from ..datasets import SEN12MS, Sample from .geo import NonGeoDataModule from .utils import group_shuffle_split @@ -96,9 +95,7 @@ def setup(self, stage: str) -> None: if stage in ['test']: self.test_dataset = SEN12MS(split='test', **self.kwargs) - def on_after_batch_transfer( - self, batch: dict[str, Tensor], dataloader_idx: int - ) -> dict[str, Tensor]: + def on_after_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample: """Apply batch augmentations to the batch after it is transferred to the device. 
Args: diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 4c3aab63b61..01300eac2ff 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -10,9 +10,10 @@ import numpy as np import torch from einops import rearrange -from torch import Tensor from torch.nn import Module +from ..datasets import Sample + # Based on lightning_lite.utilities.exceptions class MisconfigurationException(Exception): @@ -25,9 +26,7 @@ class AugPipe(Module): .. versionadded:: 0.6 """ - def __init__( - self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int - ) -> None: + def __init__(self, augs: Callable[[Sample], Sample], batch_size: int) -> None: """Initialize a new AugPipe instance. Args: @@ -38,7 +37,7 @@ def __init__( self.augs = augs self.batch_size = batch_size - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + def forward(self, batch: Sample) -> Sample: """Apply the augmentation. Args: @@ -73,7 +72,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: return batch -def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: +def collate_fn_detection(batch: list[Sample]) -> Sample: """Custom collate fn for object detection and instance segmentation. Args: @@ -84,7 +83,7 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: .. versionadded:: 0.6 """ - output: dict[str, Any] = {} + output: Sample = {} output['image'] = [sample['image'] for sample in batch] output['boxes'] = [sample['boxes'].float() for sample in batch] if 'labels' in batch[0]: diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index f3b88f6396c..4f1cc727519 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -134,7 +134,9 @@ from .ucmerced import UCMerced from .usavars import USAVars from .utils import ( + Batch, BoundingBox, + Sample, concat_samples, merge_samples, stack_samples, @@ -283,6 +285,8 @@ 'UnionDataset', 'VectorDataset', # Utilities + 'Sample', + 'Batch', 'BoundingBox', 'concat_samples', 'merge_samples', diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index c9fcea22a01..9952ef35445 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_and_extract_archive, lazy_import +from .utils import Path, Sample, download_and_extract_archive, lazy_import class ADVANCE(NonGeoDataset): @@ -89,7 +89,7 @@ class ADVANCE(NonGeoDataset): def __init__( self, root: Path = 'data', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -122,7 +122,7 @@ def __init__( self.classes = tuple(sorted({f['cls'] for f in self.files})) self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -136,7 +136,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: audio = self._load_target(files['audio']) cls_label = self.class_to_idx[files['cls']] label = torch.tensor(cls_label, dtype=torch.long) - sample = {'image': image, 'audio': audio, 'label': label} + sample: Sample = {'image': image, 'audio': audio, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -224,10 +224,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py index aaef8db9751..f258dca1420 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -6,7 +6,6 @@ import json import os from collections.abc import Callable, Iterable -from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -14,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import Path, download_url +from .utils import Path, Sample, download_url class AbovegroundLiveWoodyBiomassDensity(RasterDataset): @@ -60,7 +59,7 @@ def __init__( paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, cache: bool = True, ) -> None: @@ -119,10 +118,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index 3624c1e193e..619e02f6168 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -6,7 +6,7 @@ import os import re from collections.abc import Callable, Iterable, Sequence -from typing import Any, ClassVar, cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import torch @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import RasterDataset -from .utils import BoundingBox, Path, which +from .utils import BoundingBox, Path, Sample, which class AgriFieldNet(RasterDataset): @@ -128,7 +128,7 @@ def __init__( crs: CRS | None = None, classes: list[int] = list(cmap.keys()), bands: Sequence[str] = all_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, ) -> None: @@ -171,7 +171,7 @@ def __init__( self.ordinal_map[k] = v self.ordinal_cmap[v] = torch.tensor(self.cmap[k]) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Return an index within the dataset. 
Args: @@ -218,7 +218,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask = self._merge_files(mask_filepaths, query) mask = self.ordinal_map[mask.squeeze().long()] - sample = { + sample: Sample = { 'crs': self.crs, 'bounds': query, 'image': image.float(), @@ -251,10 +251,7 @@ def _download(self) -> None: azcopy('sync', f'{self.url}', self.paths, '--recursive=true') def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/airphen.py b/torchgeo/datasets/airphen.py index 12b8c38141c..ca8c01f3287 100644 --- a/torchgeo/datasets/airphen.py +++ b/torchgeo/datasets/airphen.py @@ -3,14 +3,12 @@ """Airphen dataset.""" -from typing import Any - import matplotlib.pyplot as plt from matplotlib.figure import Figure from .errors import RGBBandsMissingError from .geo import RasterDataset -from .utils import percentile_normalization +from .utils import Sample, percentile_normalization class Airphen(RasterDataset): @@ -44,10 +42,7 @@ class Airphen(RasterDataset): rgb_bands = ('B4', 'B3', 'B1') def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/astergdem.py b/torchgeo/datasets/astergdem.py index c4ef23061b8..5c6827165cb 100644 --- a/torchgeo/datasets/astergdem.py +++ b/torchgeo/datasets/astergdem.py @@ -4,7 +4,6 @@ """Aster Global Digital Elevation Model dataset.""" from collections.abc import Callable -from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -12,7 +11,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import Path +from .utils import Path, Sample class AsterGDEM(RasterDataset): @@ -51,7 +50,7 @@ def __init__( paths: Path | list[Path] = 'data', crs: CRS | None = None, res: float | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -88,10 +87,7 @@ def _verify(self) -> None: raise DatasetNotFoundError(self) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 4dd1ae927de..ab629fbdeb7 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -19,7 +19,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import Path, which +from .utils import Path, Sample, which class BeninSmallHolderCashews(NonGeoDataset): @@ -167,7 +167,7 @@ def __init__( chip_size: int = 256, stride: int = 128, bands: Sequence[str] = all_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, ) -> None: """Initialize a new Benin Smallholder Cashew Plantations Dataset instance. 
@@ -209,7 +209,7 @@ def __init__( ]: self.chips_metadata.append((y, x)) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -226,7 +226,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: img = img[:, :, y : y + self.chip_size, x : x + self.chip_size] labels = labels[y : y + self.chip_size, x : x + self.chip_size] - sample = { + sample: Sample = { 'image': img, 'mask': labels, 'x': torch.tensor(x), @@ -348,7 +348,7 @@ def _download(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, time_step: int = 0, suptitle: str | None = None, diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 38669cd6ff1..32e7dbcb206 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -19,7 +19,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_url, extract_archive, sort_sentinel2_bands +from .utils import Path, Sample, download_url, extract_archive, sort_sentinel2_bands class BigEarthNet(NonGeoDataset): @@ -272,7 +272,7 @@ def __init__( split: str = 'train', bands: str = 'all', num_classes: int = 19, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -305,7 +305,7 @@ def __init__( self._verify() self.folders = self._load_folders() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -316,7 +316,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) label = self._load_target(index) - sample: dict[str, Tensor] = {'image': image, 'label': label} + sample: Sample = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -527,10 +527,7 @@ def _onehot_labels_to_names( return labels def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index 70a53a4220a..2d7db32c97b 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, percentile_normalization +from .utils import Path, Sample, percentile_normalization class BioMassters(NonGeoDataset): @@ -127,7 +127,7 @@ def __init__( self.df['num_index'] = self.df.groupby(['chip_id', 'month']).ngroup() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -147,7 +147,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: ) filepaths = sample_df['filename'].tolist() - sample: dict[str, Tensor] = {} + sample: Sample = {} for sens in self.sensors: sens_filepaths = [fp for fp in filepaths if sens in fp] sample[f'image_{sens}'] = self._load_input(sens_filepaths) @@ -216,10 +216,7 @@ def _verify(self) -> None: raise DatasetNotFoundError(self) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index 3c986eb44c1..58f6ca1f539 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -5,7 +5,6 @@ import os from collections.abc import Callable, Iterable -from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -13,7 +12,7 @@ from .errors import DatasetNotFoundError from .geo import VectorDataset -from .utils import Path, check_integrity, download_and_extract_archive +from .utils import Path, Sample, check_integrity, download_and_extract_archive class CanadianBuildingFootprints(VectorDataset): @@ -65,7 +64,7 @@ def __init__( paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 0.00001, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -125,10 +124,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 0b0f6ac5b3d..b74696faef0 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable, Iterable -from typing import Any, ClassVar +from typing import ClassVar import matplotlib.pyplot as plt import torch @@ -14,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, Path, download_url, extract_archive +from .utils import BoundingBox, Path, Sample, download_url, extract_archive class CDL(RasterDataset): @@ -212,7 +212,7 @@ def __init__( res: float | None = None, years: list[int] = [2023], classes: list[int] = list(cmap.keys()), - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -270,7 +270,7 @@ def __init__( self.ordinal_map[k] = v self.ordinal_cmap[v] = torch.tensor(self.cmap[k]) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve mask and metadata indexed by query. Args: @@ -334,10 +334,7 @@ def _extract(self) -> None: extract_archive(pathname, self.paths) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/chabud.py b/torchgeo/datasets/chabud.py index ba773607a54..6dd3d0e504c 100644 --- a/torchgeo/datasets/chabud.py +++ b/torchgeo/datasets/chabud.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_url, lazy_import, percentile_normalization +from .utils import Path, Sample, download_url, lazy_import, percentile_normalization class ChaBuD(NonGeoDataset): @@ -79,7 +79,7 @@ def __init__( root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -117,7 +117,7 @@ def __init__( self.uuids = self._load_uuids() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -129,7 +129,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = self._load_image(index) mask = self._load_target(index) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -227,10 +227,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index af9b01f54dc..c4504e201ff 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -8,7 +8,7 @@ import sys from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence -from typing import Any, ClassVar, cast +from typing import ClassVar, cast import fiona import matplotlib.pyplot as plt @@ -22,12 +22,11 @@ from matplotlib.colors import ListedColormap from matplotlib.figure import Figure from rasterio.crs import CRS -from torch import Tensor from .errors import DatasetNotFoundError from .geo import GeoDataset, RasterDataset from .nlcd import NLCD -from .utils import BoundingBox, Path, download_url, extract_archive +from .utils import BoundingBox, Path, Sample, download_url, extract_archive class Chesapeake(RasterDataset, ABC): @@ -129,7 +128,7 @@ def __init__( paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -199,10 +198,7 @@ def _extract(self) -> None: extract_archive(file) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
@@ -436,7 +432,7 @@ def __init__( root: Path = 'data', splits: Sequence[str] = ['de-train'], layers: Sequence[str] = ['naip-new', 'lc'], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -514,7 +510,7 @@ def __init__( }, ) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -529,7 +525,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: hits = self.index.intersection(tuple(query), objects=True) filepaths = cast(list[dict[str, str]], [hit.object for hit in hits]) - sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query} + sample: Sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query} if len(filepaths) == 0: raise IndexError( @@ -632,10 +628,7 @@ def _extract(self) -> None: extract_archive(os.path.join(self.root, self.filenames[subdataset])) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index e0ca0045e33..ee654776f39 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import Path, which +from .utils import Path, Sample, which class CloudCoverDetection(NonGeoDataset): @@ -65,7 +65,7 @@ def __init__( root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, ) -> None: """Initiatlize a CloudCoverDetection instance. @@ -104,7 +104,7 @@ def __len__(self) -> int: """ return len(self.metadata) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Returns a sample from dataset. Args: @@ -116,7 +116,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: chip_id = self.metadata.iat[index, 0] image = self._load_image(chip_id) label = self._load_target(chip_id) - sample = {'image': image, 'mask': label} + sample: Sample = {'image': image, 'mask': label} if self.transforms is not None: sample = self.transforms(sample) @@ -174,10 +174,7 @@ def _download(self) -> None: azcopy('sync', url, directory, '--recursive=true') def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index f9db256238d..cd9553d6aee 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -5,7 +5,6 @@ import os from collections.abc import Callable -from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -13,7 +12,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import Path, check_integrity, extract_archive +from .utils import Path, Sample, check_integrity, extract_archive class CMSGlobalMangroveCanopy(RasterDataset): @@ -174,7 +173,7 @@ def __init__( res: float | None = None, measurement: str = 'agb', country: str = all_countries[0], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, checksum: bool = False, ) -> None: @@ -245,10 +244,7 @@ def _extract(self) -> None: extract_archive(pathname) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index fa97fa87037..de8e24e5c4c 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, check_integrity, download_and_extract_archive +from .utils import Path, Sample, check_integrity, download_and_extract_archive class COWC(NonGeoDataset, abc.ABC): @@ -67,7 +67,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -110,7 +110,7 @@ def __init__( self.images.append(row[0]) self.targets.append(row[1]) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -119,7 +119,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and label at that index """ - sample = {'image': self._load_image(index), 'label': self._load_target(index)} + sample: Sample = { + 'image': self._load_image(index), + 'label': self._load_target(index), + } if self.transforms is not None: sample = self.transforms(sample) @@ -191,10 +194,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py index bb3e4b3f3c5..8926b3a56da 100644 --- a/torchgeo/datasets/cropharvest.py +++ b/torchgeo/datasets/cropharvest.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_url, extract_archive, lazy_import +from .utils import Path, Sample, download_url, extract_archive, lazy_import class CropHarvest(NonGeoDataset): @@ -98,7 +98,7 @@ class CropHarvest(NonGeoDataset): def __init__( self, root: Path = 'data', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -130,7 +130,7 @@ def __init__( self.classes = self.classes[self.classes != np.array(None)] self.classes = np.insert(self.classes, 0, ['None', 'Other']) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -143,7 +143,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data = self._load_array(files['chip']) label = self._load_label(files['index'], files['dataset']) - sample = {'array': data, 'label': label} + sample: Sample = {'array': data, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -289,7 +289,7 @@ def _extract(self) -> None: features_path = os.path.join(self.root, self.file_dict['features']['filename']) extract_archive(features_path) - def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + def plot(self, sample: Sample, suptitle: str | None = None) -> Figure: """Plot a sample from the dataset using bands for Agriculture RGB composite. Args: diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index 2248dab4292..2bd2b8305a5 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import Path, which +from .utils import Path, Sample, which class CV4AKenyaCropType(NonGeoDataset): @@ -108,7 +108,7 @@ def __init__( chip_size: int = 256, stride: int = 128, bands: Sequence[str] = all_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, ) -> None: """Initialize a new CV4A Kenya Crop Type Dataset instance. @@ -151,7 +151,7 @@ def __init__( ]: self.chips_metadata.append((tile_index, y, x)) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -170,7 +170,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: labels = labels[y : y + self.chip_size, x : x + self.chip_size] field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size] - sample = { + sample: Sample = { 'image': img, 'mask': labels, 'field_ids': field_ids, @@ -284,7 +284,7 @@ def _download(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, time_step: int = 0, suptitle: str | None = None, diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 2a21832703a..06da7a8cad4 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -6,7 +6,6 @@ import os from collections.abc import Callable from functools import lru_cache -from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -18,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, which +from .utils import Path, Sample, which class TropicalCyclone(NonGeoDataset): @@ -55,7 +54,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, ) -> None: """Initialize a new TropicalCyclone instance. @@ -87,7 +86,7 @@ def __init__( self.features = pd.read_csv(os.path.join(root, f'{self.filename}_features.csv')) self.labels = pd.read_csv(os.path.join(root, f'{self.filename}_labels.csv')) - def __getitem__(self, index: int) -> dict[str, Any]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -96,7 +95,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: Returns: data, labels, field ids, and metadata at that index """ - sample = { + sample: Sample = { 'relative_time': torch.tensor(self.features.iat[index, 2]), 'ocean': torch.tensor(self.features.iat[index, 3]), 'label': torch.tensor(self.labels.iat[index, 1]), @@ -168,10 +167,7 @@ def _download(self) -> None: azcopy('copy', f'{self.url}/{file}', self.root) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py index 47ef6f15ce3..ec7eadfa475 100644 --- a/torchgeo/datasets/deepglobelandcover.py +++ b/torchgeo/datasets/deepglobelandcover.py @@ -17,6 +17,7 @@ from .geo import NonGeoDataset from .utils import ( Path, + Sample, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -103,7 +104,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new DeepGlobeLandCover dataset instance. @@ -147,7 +148,7 @@ def __init__( self.image_fns.append(image_path) self.mask_fns.append(mask_path) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -158,7 +159,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) mask = self._load_target(index) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -228,7 +229,7 @@ def _verify(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, suptitle: str | None = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py index 6886ea6b1d3..86b31b79ae9 100644 --- a/torchgeo/datasets/dfc2022.py +++ b/torchgeo/datasets/dfc2022.py @@ -19,7 +19,13 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, check_integrity, extract_archive, percentile_normalization +from .utils import ( + Path, + Sample, + check_integrity, + extract_archive, + percentile_normalization, +) class DFC2022(NonGeoDataset): @@ -140,7 +146,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new DFC2022 dataset instance. @@ -167,7 +173,7 @@ def __init__( self.class2idx = {c: i for i, c in enumerate(self.classes)} self.files = self._load_files() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -181,7 +187,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: dem = self._load_image(files['dem'], shape=image.shape[1:]) image = torch.cat(tensors=[image, dem], dim=0) - sample = {'image': image} + sample: Sample = {'image': image} if self.split == 'train': mask = self._load_target(files['target']) @@ -289,10 +295,7 @@ def _verify(self) -> None: raise DatasetNotFoundError(self) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/eddmaps.py b/torchgeo/datasets/eddmaps.py index d3a046993a1..29bffebdbf9 100644 --- a/torchgeo/datasets/eddmaps.py +++ b/torchgeo/datasets/eddmaps.py @@ -5,7 +5,6 @@ import os import sys -from typing import Any import numpy as np import pandas as pd @@ -13,7 +12,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, Path, disambiguate_timestamp +from .utils import BoundingBox, Path, Sample, disambiguate_timestamp class EDDMapS(GeoDataset): @@ -80,7 +79,7 @@ def __init__(self, root: Path = 'data') -> None: self.index.insert(i, coords) i += 1 - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve metadata indexed by query. 
Args: @@ -100,6 +99,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: f'query: {query} not found in index with bounds: {self.bounds}' ) - sample = {'crs': self.crs, 'bounds': bboxes} + sample: Sample = {'crs': self.crs, 'bounds': bboxes} return sample diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index b8af4aef70e..348928b1996 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -6,7 +6,7 @@ import os import sys from collections.abc import Callable, Sequence -from typing import Any, ClassVar, cast +from typing import ClassVar, cast import fiona import matplotlib.pyplot as plt @@ -23,7 +23,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, Path, download_url, extract_archive +from .utils import BoundingBox, Path, Sample, download_url, extract_archive class EnviroAtlas(GeoDataset): @@ -257,7 +257,7 @@ def __init__( root: Path = 'data', splits: Sequence[str] = ['pittsburgh_pa-2010_1m-train'], layers: Sequence[str] = ['naip', 'prior'], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, prior_as_input: bool = False, cache: bool = True, download: bool = False, @@ -333,7 +333,7 @@ def __init__( }, ) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -348,7 +348,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: hits = self.index.intersection(tuple(query), objects=True) filepaths = cast(list[dict[str, str]], [hit.object for hit in hits]) - sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query} + sample: Sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query} if len(filepaths) == 0: raise IndexError( @@ -444,10 +444,7 @@ def _extract(self) -> None: extract_archive(os.path.join(self.root, self.filename)) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index 197272d6b48..48753682102 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -6,7 +6,6 @@ import glob import os from collections.abc import Callable, Iterable -from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -14,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import Path, download_url, extract_archive +from .utils import Path, Sample, download_url, extract_archive class Esri2020(RasterDataset): @@ -72,7 +71,7 @@ def __init__( paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -136,10 +135,7 @@ def _extract(self) -> None: extract_archive(os.path.join(self.paths, self.zipfile)) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 7855c8bb3cf..a54e9bfec9c 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_and_extract_archive +from .utils import Path, Sample, download_and_extract_archive class ETCI2021(NonGeoDataset): @@ -84,7 +84,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -117,7 +117,7 @@ def __init__( self.files = self._load_files(self.root, self.split) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -138,7 +138,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: mask = water_mask.unsqueeze(0) image = torch.cat(tensors=[vv, vh], dim=0) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -254,10 +254,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index 5a9af7f6fa3..42350dadcdc 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable, Iterable -from typing import Any, ClassVar +from typing import ClassVar import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -14,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import Path, check_integrity, extract_archive +from .utils import Path, Sample, check_integrity, extract_archive class EUDEM(RasterDataset): @@ -78,7 +78,7 @@ def __init__( paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, checksum: bool = False, ) -> None: @@ -129,10 +129,7 @@ def _verify(self) -> None: raise DatasetNotFoundError(self) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index 5f438143c87..f3caa517a82 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -16,7 +16,13 @@ from .errors import DatasetNotFoundError from .geo import VectorDataset -from .utils import Path, check_integrity, download_and_extract_archive, download_url +from .utils import ( + Path, + Sample, + check_integrity, + download_and_extract_archive, + download_url, +) class EuroCrops(VectorDataset): @@ -89,7 +95,7 @@ def __init__( crs: CRS = CRS.from_epsg(4326), res: float = 0.00001, classes: list[str] | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -219,10 +225,7 @@ def get_label(self, feature: 'fiona.model.Feature') -> int: return 0 def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 26c1c860cf8..d7f0ec5397d 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -15,7 +15,14 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoClassificationDataset -from .utils import Path, check_integrity, download_url, extract_archive, rasterio_loader +from .utils import ( + Path, + Sample, + check_integrity, + download_url, + extract_archive, + rasterio_loader, +) class EuroSAT(NonGeoClassificationDataset): @@ -103,7 +110,7 @@ def __init__( root: Path = 'data', split: str = 'train', bands: Sequence[str] = BAND_SETS['all'], - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -155,7 +162,7 @@ def is_in_split(x: Path) -> bool: is_valid_file=is_in_split, ) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -166,7 +173,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image, label = self._load_image(index) image = torch.index_select(image, dim=0, index=self.band_indices).float() - sample = {'image': image, 'label': label} + sample: Sample = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -243,10 +250,7 @@ def _validate_bands(self, bands: Sequence[str]) -> None: raise ValueError(f"'{band}' is an invalid band name.") def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index d58968eaa19..0cb8de4430e 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable -from typing import Any, ClassVar, cast +from typing import ClassVar, cast from xml.etree.ElementTree import Element, parse import matplotlib.patches as patches @@ -19,10 +19,10 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, check_integrity, download_url, extract_archive +from .utils import Path, Sample, check_integrity, download_url, extract_archive -def parse_pascal_voc(path: Path) -> dict[str, Any]: +def parse_pascal_voc(path: Path) -> Sample: """Read a PASCAL VOC annotation file. Args: @@ -119,7 +119,7 @@ class FAIR1M(NonGeoDataset): .. versionadded:: 0.2 """ - classes: ClassVar[dict[str, dict[str, Any]]] = { + classes: ClassVar[dict[str, Sample]] = { 'Passenger Ship': {'id': 0, 'category': 'Ship'}, 'Motorboat': {'id': 1, 'category': 'Ship'}, 'Fishing Boat': {'id': 2, 'category': 'Ship'}, @@ -232,7 +232,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -264,7 +264,7 @@ def __init__( glob.glob(os.path.join(self.root, self.filename_glob[split])) ) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -276,14 +276,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: path = self.files[index] image = self._load_image(path) - sample = {'image': image} + sample: Sample = {'image': image} if self.split != 'test': label_path = str(path).replace(self.image_root, self.label_root) label_path = label_path.replace('.tif', '.xml') voc = parse_pascal_voc(label_path) boxes, labels = self._load_target(voc['points'], voc['labels']) - sample = {'image': image, 'boxes': boxes, 'label': labels} + sample: Sample = {'image': image, 'boxes': boxes, 'label': labels} if self.transforms is not None: sample = self.transforms(sample) @@ -383,10 +383,7 @@ def _download(self) -> None: extract_archive(filepath) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
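In the FAIR1M hunk above, `sample` receives the `Sample` annotation both on the unconditional assignment and again when it is rebuilt inside the `if` block; some type checkers (mypy in particular) can flag a second annotation of the same local. One pattern that keeps a single annotation is sketched below; `build_sample` is a hypothetical helper for illustration, not part of the patch.

# Hypothetical helper: annotate `sample` once and add the optional
# detection keys afterwards instead of re-annotating the variable.
from torch import Tensor

from torchgeo.datasets.utils import Sample


def build_sample(
    image: Tensor, boxes: Tensor | None = None, labels: Tensor | None = None
) -> Sample:
    sample: Sample = {'image': image}
    if boxes is not None and labels is not None:
        sample['boxes'] = boxes
        sample['label'] = labels
    return sample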
diff --git a/torchgeo/datasets/fire_risk.py b/torchgeo/datasets/fire_risk.py index 9370488f503..0075cffdc1d 100644 --- a/torchgeo/datasets/fire_risk.py +++ b/torchgeo/datasets/fire_risk.py @@ -9,11 +9,10 @@ import matplotlib.pyplot as plt from matplotlib.figure import Figure -from torch import Tensor from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import Path, download_url, extract_archive +from .utils import Path, Sample, download_url, extract_archive class FireRisk(NonGeoClassificationDataset): @@ -70,7 +69,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -135,10 +134,7 @@ def _extract(self) -> None: extract_archive(filepath) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 885ac21a471..dd3e7aa86ee 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -6,7 +6,6 @@ import glob import os from collections.abc import Callable -from typing import Any from xml.etree import ElementTree import matplotlib.patches as patches @@ -19,10 +18,16 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, check_integrity, download_and_extract_archive, extract_archive +from .utils import ( + Path, + Sample, + check_integrity, + download_and_extract_archive, + extract_archive, +) -def parse_pascal_voc(path: Path) -> dict[str, Any]: +def parse_pascal_voc(path: Path) -> Sample: """Read a PASCAL VOC annotation file. Args: @@ -104,7 +109,7 @@ class ForestDamage(NonGeoDataset): def __init__( self, root: Path = 'data', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -131,7 +136,7 @@ def __init__( self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -146,7 +151,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: boxes, labels = self._load_target(parsed['bboxes'], parsed['labels']) - sample = {'image': image, 'boxes': boxes, 'label': labels} + sample: Sample = {'image': image, 'boxes': boxes, 'label': labels} if self.transforms is not None: sample = self.transforms(sample) @@ -250,10 +255,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/gbif.py b/torchgeo/datasets/gbif.py index 3e8cfb6c883..4f965ddd642 100644 --- a/torchgeo/datasets/gbif.py +++ b/torchgeo/datasets/gbif.py @@ -7,7 +7,6 @@ import os import sys from datetime import datetime, timedelta -from typing import Any import numpy as np import pandas as pd @@ -15,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, Path +from .utils import BoundingBox, Path, Sample def _disambiguate_timestamps( @@ -117,7 +116,7 @@ def __init__(self, root: Path = 'data') -> None: self.index.insert(i, coords) i += 1 - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve metadata indexed by query. Args: @@ -137,6 +136,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: f'query: {query} not found in index with bounds: {self.bounds}' ) - sample = {'crs': self.crs, 'bounds': bboxes} + sample: Sample = {'crs': self.crs, 'bounds': bboxes} return sample diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 26a035d427d..07d31efb530 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -36,6 +36,7 @@ from .utils import ( BoundingBox, Path, + Sample, array_to_tensor, concat_samples, disambiguate_timestamp, @@ -44,7 +45,7 @@ ) -class GeoDataset(Dataset[dict[str, Any]], abc.ABC): +class GeoDataset(Dataset[Sample], abc.ABC): """Abstract base class for datasets containing geospatial information. Geospatial information includes things like: @@ -110,9 +111,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): #: Users should instead use the intersection or union operator. __add__ = None # type: ignore[assignment] - def __init__( - self, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None - ) -> None: + def __init__(self, transforms: Callable[[Sample], Sample] | None = None) -> None: """Initialize a new GeoDataset instance. Args: @@ -125,7 +124,7 @@ def __init__( self.index = Index(interleaved=False, properties=Property(dimension=3)) @abc.abstractmethod - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -418,7 +417,7 @@ def __init__( crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, ) -> None: """Initialize a new RasterDataset instance. @@ -507,7 +506,7 @@ def __init__( self._crs = cast(CRS, crs) self._res = cast(float, res) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -548,7 +547,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: else: data = self._merge_files(filepaths, query, self.band_indexes) - sample = {'crs': self.crs, 'bounds': query} + sample: Sample = {'crs': self.crs, 'bounds': query} data = data.to(self.dtype) if self.is_image: @@ -656,7 +655,7 @@ def __init__( paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 0.0001, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, label_name: str | None = None, ) -> None: """Initialize a new VectorDataset instance. 
@@ -719,7 +718,7 @@ def __init__( self._crs = crs self._res = res - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -778,7 +777,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: masks = array_to_tensor(masks) masks = masks.to(self.dtype) - sample = {'mask': masks, 'crs': self.crs, 'bounds': query} + sample: Sample = {'mask': masks, 'crs': self.crs, 'bounds': query} if self.transforms is not None: sample = self.transforms(sample) @@ -801,14 +800,14 @@ def get_label(self, feature: 'fiona.model.Feature') -> int: return 1 -class NonGeoDataset(Dataset[dict[str, Any]], abc.ABC): +class NonGeoDataset(Dataset[Sample], abc.ABC): """Abstract base class for datasets lacking geospatial information. This base class is designed for datasets with pre-defined image chips. """ @abc.abstractmethod - def __getitem__(self, index: int) -> dict[str, Any]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -851,7 +850,7 @@ class NonGeoClassificationDataset(NonGeoDataset, ImageFolder): # type: ignore[m def __init__( self, root: Path = 'data', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, loader: Callable[[Path], Any] | None = pil_loader, is_valid_file: Callable[[Path], bool] | None = None, ) -> None: @@ -879,7 +878,7 @@ def __init__( # Must be set after calling super().__init__() self.transforms = transforms - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -889,7 +888,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ image, label = self._load_image(index) - sample = {'image': image, 'label': label} + sample: Sample = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -948,10 +947,8 @@ def __init__( self, dataset1: GeoDataset, dataset2: GeoDataset, - collate_fn: Callable[ - [Sequence[dict[str, Any]]], dict[str, Any] - ] = concat_samples, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + collate_fn: Callable[[Sequence[Sample]], Sample] = concat_samples, + transforms: Callable[[Sample], Sample] | None = None, ) -> None: """Initialize a new IntersectionDataset instance. @@ -1005,7 +1002,7 @@ def _merge_dataset_indices(self) -> None: if i == 0: raise RuntimeError('Datasets have no spatiotemporal intersection') - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image and metadata indexed by query. Args: @@ -1109,10 +1106,8 @@ def __init__( self, dataset1: GeoDataset, dataset2: GeoDataset, - collate_fn: Callable[ - [Sequence[dict[str, Any]]], dict[str, Any] - ] = merge_samples, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + collate_fn: Callable[[Sequence[Sample]], Sample] = merge_samples, + transforms: Callable[[Sample], Sample] | None = None, ) -> None: """Initialize a new UnionDataset instance. @@ -1157,7 +1152,7 @@ def _merge_dataset_indices(self) -> None: self.index.insert(i, hit.bounds) i += 1 - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image and metadata indexed by query. 
Args: diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index b42e6e58df6..589f3c1115a 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_and_extract_archive +from .utils import Path, Sample, download_and_extract_archive class GID15(NonGeoDataset): @@ -90,7 +90,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -123,7 +123,7 @@ def __init__( self.files = self._load_files(self.root, self.split) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -137,9 +137,9 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: if self.split != 'test': mask = self._load_target(files['mask']) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} else: - sample = {'image': image} + sample: Sample = {'image': image} if self.transforms is not None: sample = self.transforms(sample) @@ -235,7 +235,7 @@ def _download(self) -> None: md5=self.md5 if self.checksum else None, ) - def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + def plot(self, sample: Sample, suptitle: str | None = None) -> Figure: """Plot a sample from the dataset. Args: diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index c214fbba205..b9644b3b491 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable, Iterable -from typing import Any, ClassVar, cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import torch @@ -18,6 +18,7 @@ from .utils import ( BoundingBox, Path, + Sample, check_integrity, disambiguate_timestamp, extract_archive, @@ -141,7 +142,7 @@ def __init__( crs: CRS | None = None, res: float | None = None, measurement: str = 'agb', - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, checksum: bool = False, ) -> None: @@ -178,7 +179,7 @@ def __init__( super().__init__(paths, crs, res, transforms=transforms, cache=cache) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -206,7 +207,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask = torch.cat((mask, std_err_mask), dim=0) - sample = {'mask': mask, 'crs': self.crs, 'bounds': query} + sample: Sample = {'mask': mask, 'crs': self.crs, 'bounds': query} if self.transforms is not None: sample = self.transforms(sample) @@ -233,10 +234,7 @@ def _verify(self) -> None: raise DatasetNotFoundError(self) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
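The geo.py changes above narrow the Intersection/Union collate hooks to Callable[[Sequence[Sample]], Sample]. A custom collate written against the new signature might look like the sketch below ("first dataset wins" on duplicate keys); it is not one of the built-in concat_samples/merge_samples helpers.

# Sketch of a user-defined collate_fn compatible with the new signature.
from collections.abc import Sequence

from torchgeo.datasets.utils import Sample


def first_wins(samples: Sequence[Sample]) -> Sample:
    """Merge samples so that keys from earlier datasets take precedence."""
    merged: Sample = {}
    for sample in reversed(samples):
        merged.update(sample)  # later iterations (earlier samples) overwrite
    return merged

It would be supplied as collate_fn=first_wins when constructing an IntersectionDataset or UnionDataset.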
diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 28e890dc69f..8de50c1ab7b 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -22,7 +22,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_url, extract_archive, lazy_import +from .utils import Path, Sample, download_url, extract_archive, lazy_import class IDTReeS(NonGeoDataset): @@ -158,7 +158,7 @@ def __init__( root: Path = 'data', split: str = 'train', task: str = 'task1', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -195,7 +195,7 @@ def __init__( self._verify() self.images, self.geometries, self.labels = self._load(root) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -209,7 +209,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: hsi = self._load_image(path.replace('RGB', 'HSI')) chm = self._load_image(path.replace('RGB', 'CHM')) las = self._load_las(path.replace('RGB', 'LAS').replace('.tif', '.las')) - sample = {'image': image, 'hsi': hsi, 'chm': chm, 'las': las} + sample: Sample = {'image': image, 'hsi': hsi, 'chm': chm, 'las': las} if self.split == 'test': if self.task == 'task2': @@ -282,7 +282,7 @@ def _load_boxes(self, path: Path) -> Tensor: the bounding boxes """ base_path = os.path.basename(path) - geometries = cast(dict[int, dict[str, Any]], self.geometries) + geometries = cast(dict[int, Sample], self.geometries) # Find object ids and geometries # The train set geometry->image mapping is contained @@ -335,9 +335,7 @@ def _load_target(self, path: Path) -> Tensor: tensor = torch.tensor(labels) return tensor - def _load( - self, root: Path - ) -> tuple[list[str], dict[int, dict[str, Any]] | None, Any]: + def _load(self, root: Path) -> tuple[list[str], dict[int, Sample] | None, Any]: """Load files, geometries, and labels. Args: @@ -383,7 +381,7 @@ def _load_labels(self, directory: Path) -> Any: df.reset_index() return df - def _load_geometries(self, directory: Path) -> dict[int, dict[str, Any]]: + def _load_geometries(self, directory: Path) -> dict[int, Sample]: """Load the shape files containing the geometries. 
Args: @@ -395,7 +393,7 @@ def _load_geometries(self, directory: Path) -> dict[int, dict[str, Any]]: filepaths = glob.glob(os.path.join(directory, 'ITC', '*.shp')) i = 0 - features: dict[int, dict[str, Any]] = {} + features: dict[int, Sample] = {} for path in filepaths: with fiona.open(path) as src: for feature in src: @@ -479,7 +477,7 @@ def _verify(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, suptitle: str | None = None, hsi_indices: tuple[int, int, int] = (0, 1, 2), diff --git a/torchgeo/datasets/inaturalist.py b/torchgeo/datasets/inaturalist.py index bb5cfe3c8df..80f378e1bc5 100644 --- a/torchgeo/datasets/inaturalist.py +++ b/torchgeo/datasets/inaturalist.py @@ -6,14 +6,13 @@ import glob import os import sys -from typing import Any import pandas as pd from rasterio.crs import CRS from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, Path, disambiguate_timestamp +from .utils import BoundingBox, Path, Sample, disambiguate_timestamp class INaturalist(GeoDataset): @@ -87,7 +86,7 @@ def __init__(self, root: Path = 'data') -> None: self.index.insert(i, coords) i += 1 - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve metadata indexed by query. Args: @@ -107,6 +106,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: f'query: {query} not found in index with bounds: {self.bounds}' ) - sample = {'crs': self.crs, 'bounds': bboxes} + sample: Sample = {'crs': self.crs, 'bounds': bboxes} return sample diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 3b2a4348a96..69c4fc1c4c6 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -7,7 +7,6 @@ import os import re from collections.abc import Callable -from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -18,7 +17,13 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, check_integrity, extract_archive, percentile_normalization +from .utils import ( + Path, + Sample, + check_integrity, + extract_archive, + percentile_normalization, +) class InriaAerialImageLabeling(NonGeoDataset): @@ -61,7 +66,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new InriaAerialImageLabeling Dataset instance. @@ -158,7 +163,7 @@ def __len__(self) -> int: """ return len(self.files) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -169,7 +174,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ files = self.files[index] img = self._load_image(files['image']) - sample = {'image': img} + sample: Sample = {'image': img} if files.get('label'): mask = self._load_target(files['label']) sample['mask'] = mask @@ -194,10 +199,7 @@ def _verify(self) -> None: extract_archive(archive_path) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
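GBIF and INaturalist above still return only 'crs' and 'bounds', now typed as a Sample. A usage sketch of the re-typed interface follows; the root path is a placeholder for a local occurrence export.

# Usage sketch; 'data/inaturalist' is a placeholder for a local CSV export.
from torchgeo.datasets import INaturalist

ds = INaturalist(root='data/inaturalist')
sample = ds[ds.bounds]          # Sample containing only 'crs' and 'bounds'
print(sample['crs'], sample['bounds'])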
diff --git a/torchgeo/datasets/iobench.py b/torchgeo/datasets/iobench.py index 608a9ccc17a..b97aa691c81 100644 --- a/torchgeo/datasets/iobench.py +++ b/torchgeo/datasets/iobench.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable, Sequence -from typing import Any, ClassVar +from typing import ClassVar import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import IntersectionDataset from .landsat import Landsat9 -from .utils import Path, download_url, extract_archive +from .utils import Path, Sample, download_url, extract_archive class IOBench(IntersectionDataset): @@ -56,7 +56,7 @@ def __init__( res: float | None = None, bands: Sequence[str] | None = [*Landsat9.default_bands, 'SR_QA_AEROSOL'], classes: list[int] = [0], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -134,10 +134,7 @@ def _extract(self) -> None: extract_archive(os.path.join(self.root, f'{self.split}.tar.gz'), self.root) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index d39f225ed75..ff15bc1dfe5 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -7,20 +7,20 @@ import os import re from collections.abc import Callable, Iterable, Sequence -from typing import Any, ClassVar, cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import torch from matplotlib.figure import Figure from rasterio.crs import CRS from rtree.index import Index, Property -from torch import Tensor from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import IntersectionDataset, RasterDataset from .utils import ( BoundingBox, Path, + Sample, disambiguate_timestamp, download_url, extract_archive, @@ -71,7 +71,7 @@ def __init__( crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, ) -> None: """Initialize a new L7IrishMask instance. @@ -102,7 +102,7 @@ def __init__( index.insert(hit.id, (minx, maxx, miny, maxy, mint, maxt), hit.object) self.index = index - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -179,7 +179,7 @@ def __init__( crs: CRS | None = CRS.from_epsg(3857), res: float | None = None, bands: Sequence[str] = L7IrishImage.all_bands, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -267,10 +267,7 @@ def _extract(self) -> None: extract_archive(tarfile) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py index e53c403b713..64f941b4c29 100644 --- a/torchgeo/datasets/l8biome.py +++ b/torchgeo/datasets/l8biome.py @@ -6,17 +6,16 @@ import glob import os from collections.abc import Callable, Iterable, Sequence -from typing import Any, ClassVar +from typing import ClassVar import matplotlib.pyplot as plt import torch from matplotlib.figure import Figure from rasterio.crs import CRS -from torch import Tensor from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import IntersectionDataset, RasterDataset -from .utils import BoundingBox, Path, download_url, extract_archive +from .utils import BoundingBox, Path, Sample, download_url, extract_archive class L8BiomeImage(RasterDataset): @@ -63,7 +62,7 @@ class L8BiomeMask(RasterDataset): ordinal_map[192] = 3 ordinal_map[255] = 4 - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -136,7 +135,7 @@ def __init__( crs: CRS | None = CRS.from_epsg(3857), res: float | None = None, bands: Sequence[str] = L8BiomeImage.all_bands, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -212,10 +211,7 @@ def _extract(self) -> None: extract_archive(tarfile) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index e0ff04755bf..9182bb6737e 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -23,10 +23,10 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset, RasterDataset -from .utils import BoundingBox, Path, download_url, extract_archive, working_dir +from .utils import BoundingBox, Path, Sample, download_url, extract_archive, working_dir -class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): +class LandCoverAIBase(Dataset[Sample], abc.ABC): r"""Abstract base class for LandCover.ai Geo and NonGeo datasets. The `LandCover.ai `__ (Land Cover from @@ -120,7 +120,7 @@ def _verify(self) -> None: self._extract() @abc.abstractmethod - def __getitem__(self, query: Any) -> dict[str, Any]: + def __getitem__(self, query: Any) -> Sample: """Retrieve image, mask and metadata indexed by index. Args: @@ -146,10 +146,7 @@ def _extract(self) -> None: extract_archive(os.path.join(self.root, self.filename)) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
@@ -208,7 +205,7 @@ def __init__( root: Path = 'data', crs: CRS | None = None, res: float | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -241,7 +238,7 @@ def _verify_data(self) -> bool: masks = glob.glob(mask_query) return len(images) > 0 and len(images) == len(masks) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -266,7 +263,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: img = self._merge_files(img_filepaths, query, self.band_indexes) mask = self._merge_files(mask_filepaths, query, self.band_indexes) - sample = { + sample: Sample = { 'crs': self.crs, 'bounds': query, 'image': img.float(), @@ -298,7 +295,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -325,7 +322,7 @@ def __init__( with open(os.path.join(self.root, split + '.txt')) as f: self.ids = f.readlines() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -335,7 +332,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ id_ = self.ids[index].rstrip() - sample = {'image': self._load_image(id_), 'mask': self._load_target(id_)} + sample: Sample = { + 'image': self._load_image(id_), + 'mask': self._load_target(id_), + } if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 8fb33b7c9cc..22585bb849b 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -5,7 +5,6 @@ import abc from collections.abc import Callable, Iterable, Sequence -from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -13,7 +12,7 @@ from .errors import RGBBandsMissingError from .geo import RasterDataset -from .utils import Path +from .utils import Path, Sample class Landsat(RasterDataset, abc.ABC): @@ -64,7 +63,7 @@ def __init__( crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -92,10 +91,7 @@ def __init__( super().__init__(paths, crs, res, bands, transforms, cache) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
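Because LandCoverAI's __getitem__ above now returns a Sample, which is still a plain dict at runtime, it keeps composing with a standard DataLoader whose default collate stacks each key across the batch. A sketch, assuming the LandCover.ai archive is already available under the placeholder root:

# Usage sketch; 'data/landcoverai' is a placeholder root.
from torch.utils.data import DataLoader

from torchgeo.datasets import LandCoverAI

ds = LandCoverAI(root='data/landcoverai', split='train')
loader = DataLoader(ds, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch['image'].shape, batch['mask'].shape)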
diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index fdff569dc19..1dce218666a 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_and_extract_archive, percentile_normalization +from .utils import Path, Sample, download_and_extract_archive, percentile_normalization class LEVIRCDBase(NonGeoDataset, abc.ABC): @@ -34,7 +34,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -67,7 +67,7 @@ def __init__( self.files = self._load_files(self.root, self.split) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -80,7 +80,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image1 = self._load_image(files['image1']) image2 = self._load_image(files['image2']) mask = self._load_target(files['mask']) - sample = {'image1': image1, 'image2': image2, 'mask': mask} + sample: Sample = {'image1': image1, 'image2': image2, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -130,10 +130,7 @@ def _load_target(self, path: Path) -> Tensor: return tensor def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 93b6b18e455..c2dc0718e96 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_and_extract_archive +from .utils import Path, Sample, download_and_extract_archive class LoveDA(NonGeoDataset): @@ -95,7 +95,7 @@ def __init__( root: Path = 'data', split: str = 'train', scene: Sequence[str] = ['urban', 'rural'], - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -143,7 +143,7 @@ def __init__( self.files = self._load_files(self.scene_paths, self.split) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -158,9 +158,9 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: if self.split != 'test': mask = self._load_target(files['mask']) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} else: - sample = {'image': image} + sample: Sample = {'image': image} if self.transforms is not None: sample = self.transforms(sample) @@ -256,7 +256,7 @@ def _download(self) -> None: md5=self.md5 if self.checksum else None, ) - def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + def plot(self, sample: Sample, suptitle: str | None = None) -> Figure: """Plot a sample from the dataset. 
Args: diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index cd294014318..70c85b4ecc3 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -21,6 +21,7 @@ from .geo import NonGeoDataset from .utils import ( Path, + Sample, check_integrity, download_url, extract_archive, @@ -116,7 +117,7 @@ def __init__( root: Path = 'data', modality: list[str] = ['mask', 'esa_wc', 'viirs', 's2_summer'], split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -173,7 +174,7 @@ def __init__( self.ids = split_dataframe[split].dropna().values.tolist() self.ids = list(map(int, self.ids)) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -195,7 +196,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = torch.cat(list_modalities, dim=0) - sample: dict[str, Tensor] = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -332,10 +333,7 @@ def _convert_to_color( return arr_3d def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 14e0ba7dc12..73d83ded9bc 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable -from typing import Any, ClassVar, cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, check_integrity, extract_archive +from .utils import Path, Sample, check_integrity, extract_archive class MillionAID(NonGeoDataset): @@ -193,7 +193,7 @@ def __init__( root: Path = 'data', task: str = 'multi-class', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new MillionAID dataset instance. @@ -232,7 +232,7 @@ def __len__(self) -> int: """ return len(self.files) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -245,14 +245,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = self._load_image(files['image']) cls_label = [self.class_to_idx[label] for label in files['label']] label = torch.tensor(cls_label, dtype=torch.long) - sample = {'image': image, 'label': label} + sample: Sample = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) return sample - def _load_files(self, root: Path) -> list[dict[str, Any]]: + def _load_files(self, root: Path) -> list[Sample]: """Return the paths of the files in the dataset. 
Args: @@ -331,10 +331,7 @@ def _verify(self) -> None: raise DatasetNotFoundError(self) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index 326dccd6d72..48fd07cb629 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -3,12 +3,11 @@ """National Agriculture Imagery Program (NAIP) dataset.""" -from typing import Any - import matplotlib.pyplot as plt from matplotlib.figure import Figure from .geo import RasterDataset +from .utils import Sample class NAIP(RasterDataset): @@ -49,10 +48,7 @@ class NAIP(RasterDataset): rgb_bands = ('R', 'G', 'B') def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 9c57e290406..4c78ed73e01 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -12,12 +12,11 @@ import rasterio import torch from matplotlib.figure import Figure -from torch import Tensor from torchvision.utils import draw_bounding_boxes from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, which +from .utils import Path, Sample, which class NASAMarineDebris(NonGeoDataset): @@ -59,7 +58,7 @@ class NASAMarineDebris(NonGeoDataset): def __init__( self, root: Path = 'data', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, ) -> None: """Initialize a new NASA Marine Debris Dataset instance. @@ -82,7 +81,7 @@ def __init__( self.source = sorted(glob.glob(os.path.join(self.root, 'source', '*.tif'))) self.labels = sorted(glob.glob(os.path.join(self.root, 'labels', '*.npy'))) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -105,7 +104,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: indices = w_check & h_check boxes = boxes[indices] - sample = {'image': image, 'boxes': boxes} + sample: Sample = {'image': image, 'boxes': boxes} if self.transforms is not None: sample = self.transforms(sample) @@ -142,10 +141,7 @@ def _download(self) -> None: azcopy('sync', self.url, self.root, '--recursive=true') def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/nccm.py b/torchgeo/datasets/nccm.py index 96633b2e35b..2d38e7d3e8d 100644 --- a/torchgeo/datasets/nccm.py +++ b/torchgeo/datasets/nccm.py @@ -4,7 +4,7 @@ """Northeastern China Crop Map Dataset.""" from collections.abc import Callable, Iterable -from typing import Any, ClassVar +from typing import ClassVar import matplotlib.pyplot as plt import torch @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, Path, download_url +from .utils import BoundingBox, Path, Sample, download_url class NCCM(RasterDataset): @@ -87,7 +87,7 @@ def __init__( crs: CRS | None = None, res: float | None = None, years: list[int] = [2019], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -128,7 +128,7 @@ def __init__( self.ordinal_map[k] = i self.ordinal_cmap[i] = torch.tensor(v) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve mask and metadata indexed by query. Args: @@ -168,10 +168,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index ab0e83c96b1..4864531a357 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable, Iterable -from typing import Any, ClassVar +from typing import ClassVar import matplotlib.pyplot as plt import torch @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, Path, download_url, extract_archive +from .utils import BoundingBox, Path, Sample, download_url, extract_archive class NLCD(RasterDataset): @@ -113,7 +113,7 @@ def __init__( res: float | None = None, years: list[int] = [2019], classes: list[int] = list(cmap.keys()), - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -165,7 +165,7 @@ def __init__( self.ordinal_map[k] = v self.ordinal_cmap[v] = torch.tensor(self.cmap[k]) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve mask and metadata indexed by query. Args: @@ -228,10 +228,7 @@ def _extract(self) -> None: extract_archive(glob.glob(pathname, recursive=True)[0], self.paths) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
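NCCM and NLCD above follow the same pattern: __getitem__ now returns a Sample and plot() accepts one. A small end-to-end sketch; the path and the 3 km query window are placeholders.

# Usage sketch; paths and the query window are placeholders.
from torchgeo.datasets import NLCD
from torchgeo.datasets.utils import BoundingBox

ds = NLCD(paths='data/nlcd', years=[2019])
b = ds.bounds
query = BoundingBox(b.minx, b.minx + 3000, b.miny, b.miny + 3000, b.mint, b.maxt)
sample = ds[query]              # Sample with 'mask', 'crs', and 'bounds'
fig = ds.plot(sample, suptitle='NLCD 2019')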
diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index 292dc274c32..42e6511b5e4 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -8,7 +8,7 @@ import os import sys from collections.abc import Callable, Iterable -from typing import Any, ClassVar, cast +from typing import ClassVar, cast import fiona import fiona.transform @@ -24,7 +24,7 @@ from .errors import DatasetNotFoundError from .geo import VectorDataset -from .utils import BoundingBox, Path, check_integrity +from .utils import BoundingBox, Path, Sample, check_integrity class OpenBuildings(VectorDataset): @@ -210,7 +210,7 @@ def __init__( paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 0.0001, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new Dataset instance. @@ -290,7 +290,7 @@ def __init__( self._crs = crs self._source_crs = source_crs - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Retrieve image/mask and metadata indexed by query. Args: @@ -327,7 +327,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: else: masks = torch.zeros(size=(1, round(height), round(width))) - sample = {'mask': masks, 'crs': self.crs, 'bounds': query} + sample: Sample = {'mask': masks, 'crs': self.crs, 'bounds': query} if self.transforms is not None: sample = self.transforms(sample) @@ -336,7 +336,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: def _filter_geometries( self, query: BoundingBox, filepaths: list[str] - ) -> list[dict[str, Any]]: + ) -> list[Sample]: """Filters a df read from the polygon csv file based on query and conf thresh. Args: @@ -369,7 +369,7 @@ def _filter_geometries( return shapes - def _wkt_fiona_geom_transform(self, x: str) -> dict[str, Any]: + def _wkt_fiona_geom_transform(self, x: str) -> Sample: """Function to transform a geometry string into new crs. Args: @@ -389,7 +389,7 @@ def _wkt_fiona_geom_transform(self, x: str) -> dict[str, Any]: geom = fiona.model.Geometry(**x) else: geom = x - transformed: dict[str, Any] = fiona.transform.transform_geom( + transformed: Sample = fiona.transform.transform_geom( self._source_crs.to_dict(), self._crs.to_dict(), geom ) return transformed @@ -412,10 +412,7 @@ def _verify(self) -> None: raise DatasetNotFoundError(self) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 28f7714a7c6..b6e97d00d04 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -19,6 +19,7 @@ from .geo import NonGeoDataset from .utils import ( Path, + Sample, download_url, draw_semantic_segmentation_masks, extract_archive, @@ -103,7 +104,7 @@ def __init__( root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -137,7 +138,7 @@ def __init__( self.files = self._load_files() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -150,7 +151,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image1 = self._load_image(files['images1']) image2 = self._load_image(files['images2']) mask = self._load_target(str(files['mask'])) - sample = {'image1': image1, 'image2': image2, 'mask': mask} + sample: Sample = {'image1': image1, 'image2': image2, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -283,7 +284,7 @@ def _extract(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, suptitle: str | None = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index 06f716a9ffb..0ce455ab079 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, check_integrity, download_url, extract_archive +from .utils import Path, Sample, check_integrity, download_url, extract_archive class PASTIS(NonGeoDataset): @@ -133,7 +133,7 @@ def __init__( folds: Sequence[int] = (1, 2, 3, 4, 5), bands: str = 's2', mode: str = 'semantic', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -179,7 +179,7 @@ def __init__( ) self._cmap = ListedColormap(colors) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -191,10 +191,15 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = self._load_image(index) if self.mode == 'semantic': mask = self._load_semantic_targets(index) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} elif self.mode == 'instance': mask, boxes, labels = self._load_instance_targets(index) - sample = {'image': image, 'mask': mask, 'boxes': boxes, 'label': labels} + sample: Sample = { + 'image': image, + 'mask': mask, + 'boxes': boxes, + 'label': labels, + } if self.transforms is not None: sample = self.transforms(sample) @@ -345,10 +350,7 @@ def _download(self) -> None: extract_archive(os.path.join(self.root, self.filename), self.root) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
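The PASTIS hunk above is one of the few places where the set of Sample keys depends on a constructor argument: 'semantic' mode yields image and mask, while 'instance' mode adds boxes and label. A usage sketch, with a placeholder root:

# Usage sketch; 'data/pastis' is a placeholder root.
from torchgeo.datasets import PASTIS

semantic = PASTIS(root='data/pastis', mode='semantic')
instance = PASTIS(root='data/pastis', mode='instance')
print(sorted(semantic[0]))      # ['image', 'mask']
print(sorted(instance[0]))      # ['boxes', 'image', 'label', 'mask']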
diff --git a/torchgeo/datasets/patternnet.py b/torchgeo/datasets/patternnet.py index a9385049872..4718cdad3ef 100644 --- a/torchgeo/datasets/patternnet.py +++ b/torchgeo/datasets/patternnet.py @@ -9,11 +9,10 @@ import matplotlib.pyplot as plt from matplotlib.figure import Figure -from torch import Tensor from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import Path, download_url, extract_archive +from .utils import Path, Sample, download_url, extract_archive class PatternNet(NonGeoClassificationDataset): @@ -86,7 +85,7 @@ class PatternNet(NonGeoClassificationDataset): def __init__( self, root: Path = 'data', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -144,10 +143,7 @@ def _extract(self) -> None: extract_archive(filepath) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index 51f1ebd0441..a623b98e8a8 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -19,6 +19,7 @@ from .geo import NonGeoDataset from .utils import ( Path, + Sample, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -125,7 +126,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new Potsdam dataset instance. @@ -156,7 +157,7 @@ def __init__( if os.path.exists(image) and os.path.exists(mask): self.files.append(dict(image=image, mask=mask)) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -167,7 +168,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) mask = self._load_target(index) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -240,7 +241,7 @@ def _verify(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, suptitle: str | None = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/prisma.py b/torchgeo/datasets/prisma.py index c2e6e66b598..fde51d3b900 100644 --- a/torchgeo/datasets/prisma.py +++ b/torchgeo/datasets/prisma.py @@ -3,13 +3,11 @@ """PRISMA datasets.""" -from typing import Any - import matplotlib.pyplot as plt from matplotlib.figure import Figure from .geo import RasterDataset -from .utils import percentile_normalization +from .utils import Sample, percentile_normalization class PRISMA(RasterDataset): @@ -78,10 +76,7 @@ class PRISMA(RasterDataset): date_format = '%Y%m%d%H%M%S' def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 811d79cff08..8008e616617 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import Any, ClassVar, cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_url, lazy_import, percentile_normalization +from .utils import Path, Sample, download_url, lazy_import, percentile_normalization class QuakeSet(NonGeoDataset): @@ -72,7 +72,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -104,7 +104,7 @@ def __init__( self._verify() self.data = self._load_data() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -117,7 +117,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: label = torch.tensor(self.data[index]['label']) magnitude = torch.tensor(self.data[index]['magnitude']) - sample = {'image': image, 'label': label, 'magnitude': magnitude} + sample: Sample = {'image': image, 'label': label, 'magnitude': magnitude} if self.transforms is not None: sample = self.transforms(sample) @@ -132,7 +132,7 @@ def __len__(self) -> int: """ return len(self.data) - def _load_data(self) -> list[dict[str, Any]]: + def _load_data(self) -> list[Sample]: """Return the metadata for a given split. Returns: @@ -224,10 +224,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index 1c46c450191..91f03080c08 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -18,7 +18,13 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, check_integrity, download_and_extract_archive, extract_archive +from .utils import ( + Path, + Sample, + check_integrity, + download_and_extract_archive, + extract_archive, +) class ReforesTree(NonGeoDataset): @@ -65,7 +71,7 @@ class ReforesTree(NonGeoDataset): def __init__( self, root: Path = 'data', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -94,7 +100,7 @@ def __init__( self.class2idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -109,7 +115,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: boxes, labels, agb = self._load_target(filepath) - sample = {'image': image, 'boxes': boxes, 'label': labels, 'agb': agb} + sample: Sample = {'image': image, 'boxes': boxes, 'label': labels, 'agb': agb} if self.transforms is not None: sample = self.transforms(sample) @@ -202,10 +208,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index fd33b634fde..5ee6987f23d 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -10,11 +10,10 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.figure import Figure -from torch import Tensor from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import Path, download_url, extract_archive +from .utils import Path, Sample, download_url, extract_archive class RESISC45(NonGeoClassificationDataset): @@ -114,7 +113,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -194,10 +193,7 @@ def _extract(self) -> None: extract_archive(filepath) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py index 48968f007f5..a8fa69b0853 100644 --- a/torchgeo/datasets/rwanda_field_boundary.py +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -15,11 +15,10 @@ import torch from einops import rearrange from matplotlib.figure import Figure -from torch import Tensor from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import Path, which +from .utils import Path, Sample, which class RwandaFieldBoundary(NonGeoDataset): @@ -68,7 +67,7 @@ def __init__( root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, ) -> None: """Initialize a new RwandaFieldBoundary instance. @@ -104,7 +103,7 @@ def __len__(self) -> int: """ return self.splits[self.split] - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -121,7 +120,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: with rasterio.open(os.path.join(path, f'{index:02}_{band}.tif')) as src: patches.append(src.read(1).astype(np.float32)) images.append(patches) - sample = {'image': torch.from_numpy(np.array(images))} + sample: Sample = {'image': torch.from_numpy(np.array(images))} if self.split == 'train': path = os.path.join(self.root, 'labels', self.split) @@ -156,7 +155,7 @@ def _download(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, time_step: int = 0, suptitle: str | None = None, diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index 3e47a8ec491..e4784d95fdc 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -21,7 +21,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import Path, download_url, extract_archive, percentile_normalization +from .utils import Path, Sample, download_url, extract_archive, percentile_normalization class SeasoNet(NonGeoDataset): @@ -219,7 +219,7 @@ def __init__( bands: Iterable[str] = all_bands, grids: Iterable[int] = [1, 2], concat_seasons: int = 1, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -292,7 +292,7 @@ def __init__( else: self.files = csv['Path'] - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -304,7 +304,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) mask = self._load_target(index) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -402,7 +402,7 @@ def _verify(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, show_legend: bool = True, suptitle: str | None = None, diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index c67fecb9c8e..7736102bc8d 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import Path, download_url, extract_archive, percentile_normalization +from .utils import Path, Sample, download_url, extract_archive, percentile_normalization class SeasonalContrastS2(NonGeoDataset): @@ -75,7 +75,7 @@ def __init__( version: str = '100k', seasons: int = 1, bands: Sequence[str] = rgb_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -113,7 +113,7 @@ def __init__( self._verify() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -133,7 +133,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: images = [self._load_patch(root, subdir) for subdir in subdirs] - sample = {'image': torch.cat(images)} + sample: Sample = {'image': torch.cat(images)} if self.transforms is not None: sample = self.transforms(sample) @@ -222,10 +222,7 @@ def _extract(self) -> None: ) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 20183f06421..2322bbe99ac 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import Path, check_integrity, percentile_normalization +from .utils import Path, Sample, check_integrity, percentile_normalization class SEN12MS(NonGeoDataset): @@ -169,7 +169,7 @@ def __init__( root: Path = 'data', split: str = 'train', bands: Sequence[str] = BAND_SETS['all'], - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new SEN12MS dataset instance. @@ -213,7 +213,7 @@ def __init__( with open(os.path.join(self.root, split + '_list.txt')) as f: self.ids = [line.rstrip() for line in f.readlines()] - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -231,7 +231,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = torch.cat(tensors=[s1, s2], dim=0) image = torch.index_select(image, dim=0, index=self.band_indices) - sample: dict[str, Tensor] = {'image': image, 'mask': lc[0]} + sample: Sample = {'image': image, 'mask': lc[0]} if self.transforms is not None: sample = self.transforms(sample) @@ -313,10 +313,7 @@ def _check_integrity(self) -> bool: return True def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 79637931adb..61c79811b54 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -4,7 +4,6 @@ """Sentinel datasets.""" from collections.abc import Callable, Iterable, Sequence -from typing import Any import matplotlib.pyplot as plt import torch @@ -13,7 +12,7 @@ from .errors import RGBBandsMissingError from .geo import RasterDataset -from .utils import Path +from .utils import Path, Sample class Sentinel(RasterDataset): @@ -146,7 +145,7 @@ def __init__( crs: CRS | None = None, res: float = 10, bands: Sequence[str] = ['VV', 'VH'], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -192,10 +191,7 @@ def __init__( super().__init__(paths, crs, res, bands, transforms, cache) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
@@ -302,7 +298,7 @@ def __init__( crs: CRS | None = None, res: float = 10, bands: Sequence[str] | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -331,10 +327,7 @@ def __init__( super().__init__(paths, crs, res, bands, transforms, cache) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 3accf32d2af..4516fb937ad 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import Any, ClassVar +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_url, extract_archive, lazy_import +from .utils import Path, Sample, download_url, extract_archive, lazy_import class SKIPPD(NonGeoDataset): @@ -82,7 +82,7 @@ def __init__( root: Path = 'data', split: str = 'trainval', task: str = 'nowcast', - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -237,10 +237,7 @@ def _extract(self) -> None: extract_archive(zipfile_path, self.root) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 4840a48e468..f6bcf9e990a 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -11,11 +11,10 @@ import numpy as np import torch from matplotlib.figure import Figure -from torch import Tensor from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import Path, check_integrity, lazy_import, percentile_normalization +from .utils import Path, Sample, check_integrity, lazy_import, percentile_normalization class So2Sat(NonGeoDataset): @@ -198,7 +197,7 @@ def __init__( version: str = '2', split: str = 'train', bands: Sequence[str] = BAND_SETS['all'], - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new So2Sat dataset instance. @@ -266,7 +265,7 @@ def __init__( with h5py.File(self.fn, 'r') as f: self.size: int = f['label'].shape[0] - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -291,7 +290,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: s1 = torch.from_numpy(s1) s2 = torch.from_numpy(s2) - sample = {'image': torch.cat([s1, s2]).float(), 'label': label} + sample: Sample = {'image': torch.cat([s1, s2]).float(), 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -335,10 +334,7 @@ def _validate_bands(self, bands: Sequence[str]) -> None: raise ValueError(f"'{band}' is an invalid band name.") def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index a8643873c5b..34ffebf6a84 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -6,7 +6,7 @@ import os import re from collections.abc import Callable, Iterable, Sequence -from typing import Any, ClassVar, cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import torch @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import RasterDataset -from .utils import BoundingBox, Path, which +from .utils import BoundingBox, Path, Sample, which class SouthAfricaCropType(RasterDataset): @@ -114,7 +114,7 @@ def __init__( crs: CRS | None = None, classes: Sequence[int] = list(cmap.keys()), bands: Sequence[str] = s2_bands, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, ) -> None: """Initialize a new South Africa Crop Type dataset instance. @@ -151,7 +151,7 @@ def __init__( self.ordinal_map[k] = v self.ordinal_cmap[v] = torch.tensor(self.cmap[k]) - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: """Return an index within the dataset. Args: @@ -225,7 +225,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask = self._merge_files(mask_filepaths, query).squeeze(0) - sample = { + sample: Sample = { 'crs': self.crs, 'bounds': query, 'image': image.float(), @@ -258,10 +258,7 @@ def _download(self) -> None: azcopy('sync', f'{self.url}', self.paths, '--recursive=true') def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
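Geo datasets such as the South Africa Crop Type hunk above also populate `crs` and `bounds` in the returned `Sample`. A hedged sketch of downstream code relying on those keys (assuming this patch is applied; `describe` is a hypothetical helper, not part of the patch):

from torchgeo.datasets import BoundingBox, Sample


def describe(sample: Sample) -> str:
    # Read the georeferencing keys stored alongside the image tensor.
    bounds: BoundingBox = sample['bounds']
    return (
        f"CRS: {sample['crs']}, "
        f"x: {bounds.minx}-{bounds.maxx}, y: {bounds.miny}-{bounds.maxy}, "
        f"image shape: {tuple(sample['image'].shape)}"
    )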
diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index adbde74d6cb..300a5b1633d 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable, Iterable -from typing import Any, ClassVar +from typing import ClassVar import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import Path, download_url +from .utils import Path, Sample, download_url class SouthAmericaSoybean(RasterDataset): @@ -77,7 +77,7 @@ def __init__( crs: CRS | None = None, res: float | None = None, years: list[int] = [2021], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -132,10 +132,7 @@ def _download(self) -> None: ) def plot( - self, - sample: dict[str, Any], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 21a31657f58..2048e59dd6f 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -28,6 +28,7 @@ from .geo import NonGeoDataset from .utils import ( Path, + Sample, check_integrity, extract_archive, percentile_normalization, @@ -108,7 +109,7 @@ def __init__( aois: list[int] = [], image: str | None = None, mask: str | None = None, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -217,7 +218,7 @@ def _load_mask( return torch.from_numpy(mask) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -229,7 +230,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image_path = self.images[index] img, tfm, raster_crs = self._load_image(image_path) h, w = img.shape[1:] - sample = {'image': img} + sample: Sample = {'image': img} if self.split == 'train': mask_path = self.masks[index] @@ -339,10 +340,7 @@ def _verify(self) -> None: self.masks.extend(masks) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
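Besides the keys filled in by `__getitem__`, `Sample` also declares a `prediction` entry, which many of the `plot` methods touched in this patch optionally display. A hedged, illustrative sketch (the random tensors and the commented plot call are placeholders, assuming this patch is applied):

import torch

from torchgeo.datasets import Sample

sample: Sample = {
    'image': torch.rand(3, 256, 256),
    'mask': torch.zeros(256, 256, dtype=torch.long),
}
# Optionally attach a model output before plotting.
sample['prediction'] = torch.zeros(256, 256, dtype=torch.long)
# dataset.plot(sample, show_titles=True)  # would render image / mask / prediction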
diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py index b2840afb865..dcc6a8a3705 100644 --- a/torchgeo/datasets/ssl4eo.py +++ b/torchgeo/datasets/ssl4eo.py @@ -14,11 +14,10 @@ import rasterio import torch from matplotlib.figure import Figure -from torch import Tensor from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, check_integrity, download_url, extract_archive +from .utils import Path, Sample, check_integrity, download_url, extract_archive class SSL4EO(NonGeoDataset): @@ -165,7 +164,7 @@ def __init__( root: Path = 'data', split: str = 'oli_sr', seasons: int = 1, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -199,7 +198,7 @@ def __init__( self.scenes = sorted(os.listdir(self.subdir)) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -220,7 +219,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = f.read() images.append(torch.from_numpy(image.astype(np.float32))) - sample = {'image': torch.cat(images)} + sample: Sample = {'image': torch.cat(images)} if self.transforms is not None: sample = self.transforms(sample) @@ -284,10 +283,7 @@ def _extract(self) -> None: extract_archive(path) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. @@ -407,7 +403,7 @@ def __init__( root: Path = 'data', split: str = 's2c', seasons: int = 1, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new SSL4EOS12 instance. @@ -439,7 +435,7 @@ def __init__( self._verify() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -461,7 +457,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = f.read(out_shape=(1, self.size, self.size)) images.append(torch.from_numpy(image.astype(np.float32))) - sample = {'image': torch.cat(images)} + sample: Sample = {'image': torch.cat(images)} if self.transforms is not None: sample = self.transforms(sample) @@ -499,10 +495,7 @@ def _extract(self) -> None: extract_archive(os.path.join(self.root, filename)) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index 13c5a8474c4..f105ea71d31 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -19,7 +19,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .nlcd import NLCD -from .utils import Path, download_url, extract_archive +from .utils import Path, Sample, download_url, extract_archive class SSL4EOLBenchmark(NonGeoDataset): @@ -116,7 +116,7 @@ def __init__( product: str = 'cdl', split: str = 'train', classes: list[int] | None = None, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -257,7 +257,7 @@ def _extract(self) -> None: mask_pathname = os.path.join(self.root, f'{self.mask_dir_name}.tar.gz') extract_archive(mask_pathname) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -268,7 +268,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ img_path, mask_path = self.sample_collection[index] - sample = { + sample: Sample = { 'image': self._load_image(img_path), 'mask': self._load_mask(mask_path), } @@ -329,10 +329,7 @@ def _load_mask(self, path: Path) -> Tensor: return mask def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index eec9be57ab3..6dc041540fa 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -5,17 +5,15 @@ import os from collections.abc import Callable -from typing import Any import matplotlib.pyplot as plt import numpy as np import torch from matplotlib.figure import Figure -from torch import Tensor from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_url, extract_archive +from .utils import Path, Sample, download_url, extract_archive class SustainBenchCropYield(NonGeoDataset): @@ -62,7 +60,7 @@ def __init__( root: Path = 'data', split: str = 'train', countries: list[str] = ['usa'], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -139,7 +137,7 @@ def __len__(self) -> int: """ return len(self.images) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -148,7 +146,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and label at that index """ - sample: dict[str, Tensor] = {'image': self.images[index]} + sample: Sample = {'image': self.images[index]} sample.update(self.features[index]) if self.transforms is not None: @@ -194,7 +192,7 @@ def _extract(self) -> None: def plot( self, - sample: dict[str, Any], + sample: Sample, band_idx: int = 0, show_titles: bool = True, suptitle: str | None = None, diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 5527a7ed133..d6250ce3153 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import Path, check_integrity, download_url, extract_archive +from .utils import Path, Sample, check_integrity, download_url, extract_archive class UCMerced(NonGeoClassificationDataset): @@ -88,7 +88,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -192,10 +192,7 @@ def _extract(self) -> None: extract_archive(filepath) def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_titles: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index db13059bfa7..39a494f474f 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, download_url, extract_archive +from .utils import Path, Sample, download_url, extract_archive class USAVars(NonGeoDataset): @@ -90,7 +90,7 @@ def __init__( root: Path = 'data', split: str = 'train', labels: Sequence[str] = ALL_LABELS, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -131,7 +131,7 @@ def __init__( for lab in self.labels } - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -143,7 +143,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: tif_file = self.files[index] id_ = tif_file[5:-4] - sample = { + sample: Sample = { 'labels': Tensor( [self.label_dfs[lab].loc[id_][lab] for lab in self.labels] ), @@ -228,10 +228,7 @@ def _extract(self) -> None: extract_archive(os.path.join(self.root, self.dirname + '.zip')) def plot( - self, - sample: dict[str, Tensor], - show_labels: bool = True, - suptitle: str | None = None, + self, sample: Sample, show_labels: bool = True, suptitle: str | None = None ) -> Figure: """Plot a sample from the dataset. 
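The utils.py hunk that follows is where `Sample` and `Batch` are actually introduced. As a hedged, self-contained approximation of how the two TypedDicts behave under a type checker (key set trimmed to three entries for brevity; the real definitions are in the hunk below):

from typing import TypedDict

import torch
from torch import Tensor


class Sample(TypedDict, total=False):
    image: Tensor
    mask: Tensor
    label: Tensor


class Batch(Sample):
    pass


# total=False makes every key optional, so partial samples still type-check:
s: Sample = {'image': torch.rand(3, 64, 64), 'label': torch.tensor(1)}
m: Sample = {'mask': torch.zeros(64, 64, dtype=torch.long)}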
diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index e668e749840..ee3f934d8c9 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -13,14 +13,15 @@ import shutil import subprocess import sys -from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, TypeAlias, cast, overload +from typing import Any, TypeAlias, TypedDict, cast, overload import numpy as np import rasterio import torch +from rasterio.crs import CRS from torch import Tensor from torchvision.datasets.utils import ( check_integrity, @@ -44,6 +45,39 @@ Path: TypeAlias = str | os.PathLike[str] +class Sample(TypedDict, total=False): + """A single dataset sample. + + .. versionadded:: 0.6 + """ + + # Designed to match kornia.constants.DataKey + image: Tensor + mask: Tensor + bbox: Tensor + bbox_xyxy: Tensor + bbox_xywh: Tensor + keypoints: Tensor + label: Tensor + + # Additional common keys for TorchGeo datasets + prediction: Tensor + bounds: BoundingBox + crs: CRS + + # TODO: remove + boxes: Tensor + + +class Batch(Sample): + """A batch of samples. + + .. versionadded:: 0.6 + """ + + # For now, identical to Sample until we can type check tensor shapes + + @dataclass(frozen=True) class BoundingBox: """Data class for indexing spatiotemporal data.""" @@ -407,7 +441,7 @@ def _dict_list_to_list_dict( return uncollated -def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: +def stack_samples(samples: Iterable[Sample]) -> Batch: """Stack a list of samples along a new axis. Useful for forming a mini-batch of samples to pass to @@ -428,7 +462,7 @@ def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: return collated -def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: +def concat_samples(samples: Iterable[Sample]) -> Batch: """Concatenate a list of samples along an existing axis. Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`. @@ -450,7 +484,7 @@ def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: return collated -def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: +def merge_samples(samples: Iterable[Sample]) -> Batch: """Merge a list of samples. Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`. @@ -475,24 +509,24 @@ def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: return collated -def unbind_samples(sample: MutableMapping[Any, Any]) -> list[dict[Any, Any]]: +def unbind_samples(batch: Batch) -> list[Sample]: """Reverse of :func:`stack_samples`. Useful for turning a mini-batch of samples into a list of samples. These individual samples can then be plotted using a dataset's ``plot`` method. Args: - sample: a mini-batch of samples + batch: a mini-batch of samples Returns: list of samples .. 
versionadded:: 0.2 """ - for key, values in sample.items(): + for key, values in batch.items(): if isinstance(values, Tensor): - sample[key] = torch.unbind(values) - return _dict_list_to_list_dict(sample) + batch[key] = torch.unbind(values) + return _dict_list_to_list_dict(batch) def rasterio_loader(path: Path) -> np.typing.NDArray[np.int_]: diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index 2c671ca27ac..fe466f50476 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -18,6 +18,7 @@ from .geo import NonGeoDataset from .utils import ( Path, + Sample, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -124,7 +125,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new Vaihingen2D dataset instance. @@ -155,7 +156,7 @@ def __init__( if os.path.exists(image) and os.path.exists(mask): self.files.append(dict(image=image, mask=mask)) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -166,7 +167,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) mask = self._load_target(index) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -241,7 +242,7 @@ def _verify(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, suptitle: str | None = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 9adc2f44e9e..e59e97d8545 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import Any, ClassVar +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -19,6 +19,7 @@ from .geo import NonGeoDataset from .utils import ( Path, + Sample, check_integrity, download_and_extract_archive, download_url, @@ -62,7 +63,7 @@ class ConvertCocoAnnotations: https://github.com/pytorch/vision/blob/v0.14.0/references/detection/coco_utils.py """ - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + def __call__(self, sample: Sample) -> Sample: """Converts MS COCO fields (boxes, masks & labels) from list of ints to tensors. Args: @@ -187,7 +188,7 @@ def __init__( self, root: Path = 'data', split: str = 'positive', - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -230,7 +231,7 @@ def __init__( self.coco_convert = ConvertCocoAnnotations() self.ids = list(sorted(self.coco.imgs.keys())) - def __getitem__(self, index: int) -> dict[str, Any]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -241,7 +242,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: """ id_ = index % len(self) + 1 - sample: dict[str, Any] = { + sample: Sample = { 'image': self._load_image(id_), 'label': self._load_target(id_), } @@ -292,7 +293,7 @@ def _load_image(self, id_: int) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, id_: int) -> dict[str, Any]: + def _load_target(self, id_: int) -> Sample: """Load the annotations for a single image. Args: @@ -359,7 +360,7 @@ def _download(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, suptitle: str | None = None, show_feats: str | None = 'both', diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py index fe51f6ade8f..a35b3ae0335 100644 --- a/torchgeo/datasets/western_usa_live_fuel_moisture.py +++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py @@ -7,14 +7,13 @@ import json import os from collections.abc import Callable, Iterable -from typing import Any import pandas as pd import torch from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import Path, which +from .utils import Path, Sample, which class WesternUSALiveFuelMoisture(NonGeoDataset): @@ -199,7 +198,7 @@ def __init__( self, root: Path = 'data', input_features: Iterable[str] = all_variable_names, - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, ) -> None: """Initialize a new Western USA Live Fuel Moisture Dataset. @@ -234,7 +233,7 @@ def __len__(self) -> int: """ return len(self.dataframe) - def __getitem__(self, index: int) -> dict[str, Any]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -245,7 +244,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: """ data = self.dataframe.iloc[index, :] - sample = { + sample: Sample = { 'input': torch.tensor( data.drop([self.label_name]).values, dtype=torch.float32 ), diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index a7f6a36456a..47ca2f14ee5 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -19,6 +19,7 @@ from .geo import NonGeoDataset from .utils import ( Path, + Sample, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -74,7 +75,7 @@ def __init__( self, root: Path = 'data', split: str = 'train', - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, checksum: bool = False, ) -> None: """Initialize a new xView2 dataset instance. @@ -101,7 +102,7 @@ def __init__( self.class2idx = {c: i for i, c in enumerate(self.classes)} self.files = self._load_files(root, split) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -118,7 +119,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = torch.stack(tensors=[image1, image2], dim=0) mask = torch.stack(tensors=[mask1, mask2], dim=0) - sample = {'image': image, 'mask': mask} + sample: Sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -225,7 +226,7 @@ def _verify(self) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, show_titles: bool = True, suptitle: str | None = None, alpha: float = 0.5, diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index 2928dc58a70..734bc8f866f 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import Path, download_url, lazy_import, percentile_normalization +from .utils import Path, Sample, download_url, lazy_import, percentile_normalization class ZueriCrop(NonGeoDataset): @@ -66,7 +66,7 @@ def __init__( self, root: Path = 'data', bands: Sequence[str] = band_names, - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[Sample], Sample] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -100,7 +100,7 @@ def __init__( self._verify() - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. Args: @@ -112,7 +112,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = self._load_image(index) mask, boxes, label = self._load_target(index) - sample = {'image': image, 'mask': mask, 'boxes': boxes, 'label': label} + sample: Sample = {'image': image, 'mask': mask, 'boxes': boxes, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -250,7 +250,7 @@ def _validate_bands(self, bands: Sequence[str]) -> None: def plot( self, - sample: dict[str, Tensor], + sample: Sample, time_step: int = 0, show_titles: bool = True, suptitle: str | None = None, diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 18df10e02f0..573fe3481af 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -14,6 +14,7 @@ from torch import Tensor from torchvision.models._api import WeightsEnum +from ..datasets import Batch from ..models import get_weight from . import utils from .base import BaseTask @@ -349,7 +350,7 @@ def configure_models(self) -> None: self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224)) def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss and additional metrics. 
@@ -395,12 +396,14 @@ def training_step( return loss def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> None: """No-op, does nothing.""" - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def test_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def predict_step( + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 + ) -> None: """No-op, does nothing.""" diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 2e2766419a5..8b9b4551d6c 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -4,7 +4,6 @@ """Trainers for image classification.""" import os -from typing import Any import matplotlib.pyplot as plt import timm @@ -23,7 +22,7 @@ ) from torchvision.models._api import WeightsEnum -from ..datasets import RGBBandsMissingError, unbind_samples +from ..datasets import Batch, RGBBandsMissingError, unbind_samples from ..models import get_weight from . import utils from .base import BaseTask @@ -164,7 +163,7 @@ def configure_metrics(self) -> None: self.test_metrics = metrics.clone(prefix='test_') def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss and additional metrics. @@ -188,7 +187,7 @@ def training_step( return loss def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Compute the validation loss and additional metrics. @@ -233,7 +232,7 @@ def validation_step( ) plt.close() - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def test_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test loss and additional metrics. Args: @@ -251,7 +250,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None self.log_dict(self.test_metrics, batch_size=batch_size) def predict_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the predicted class probabilities. @@ -306,7 +305,7 @@ def configure_metrics(self) -> None: self.test_metrics = metrics.clone(prefix='test_') def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss and additional metrics. @@ -331,7 +330,7 @@ def training_step( return loss def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Compute the validation loss and additional metrics. @@ -376,7 +375,7 @@ def validation_step( f'image/{batch_idx}', fig, global_step=self.global_step ) - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def test_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test loss and additional metrics. 
Args: @@ -395,7 +394,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None self.log_dict(self.test_metrics, batch_size=batch_size) def predict_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the predicted class probabilities. diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 3d970abdae0..7cb95fd2448 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -4,7 +4,6 @@ """Trainers for object detection.""" from functools import partial -from typing import Any import matplotlib.pyplot as plt import torch @@ -19,7 +18,7 @@ from torchvision.models.detection.rpn import AnchorGenerator from torchvision.ops import MultiScaleRoIAlign, feature_pyramid_network, misc -from ..datasets import RGBBandsMissingError, unbind_samples +from ..datasets import Batch, RGBBandsMissingError, unbind_samples from .base import BaseTask BACKBONE_LAT_DIM_MAP = { @@ -224,7 +223,7 @@ def configure_metrics(self) -> None: self.test_metrics = metrics.clone(prefix='test_') def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss. @@ -248,7 +247,7 @@ def training_step( return train_loss def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Compute the validation metrics. @@ -303,7 +302,7 @@ def validation_step( ) plt.close() - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def test_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test metrics. Args: @@ -326,8 +325,8 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None self.log_dict(metrics, batch_size=batch_size) def predict_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 - ) -> list[dict[str, Tensor]]: + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 + ) -> list[Batch]: """Compute the predicted bounding boxes. Args: @@ -339,5 +338,5 @@ def predict_step( Output predicted probabilities. """ x = batch['image'] - y_hat: list[dict[str, Tensor]] = self(x) + y_hat: list[Batch] = self(x) return y_hat diff --git a/torchgeo/trainers/iobench.py b/torchgeo/trainers/iobench.py index c8826a1dce5..54a8a3141b2 100644 --- a/torchgeo/trainers/iobench.py +++ b/torchgeo/trainers/iobench.py @@ -3,13 +3,12 @@ """Trainers for I/O benchmarking.""" -from typing import Any - import lightning import torch from torch import Tensor from torch.optim import SGD +from ..datasets import Batch from .base import BaseTask @@ -34,7 +33,7 @@ def configure_optimizers( return {'optimizer': optimizer} def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """No-op. @@ -49,7 +48,7 @@ def training_step( return torch.tensor(0.0, requires_grad=True) def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> None: """No-op. @@ -59,7 +58,7 @@ def validation_step( dataloader_idx: Index of the current dataloader. 
""" - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def test_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op. Args: @@ -68,7 +67,9 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None dataloader_idx: Index of the current dataloader. """ - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def predict_step( + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 + ) -> None: """No-op. Args: diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index b079543adba..4945ad9c1cd 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -6,7 +6,6 @@ import os import warnings from collections.abc import Sequence -from typing import Any import kornia.augmentation as K import lightning @@ -30,6 +29,7 @@ import torchgeo.transforms as T +from ..datasets import Batch from ..models import get_weight from . import utils from .base import BaseTask @@ -368,7 +368,7 @@ def forward_momentum(self, x: Tensor) -> Tensor: return k def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss and additional metrics. @@ -436,12 +436,14 @@ def training_step( return loss def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> None: """No-op, does nothing.""" - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def test_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def predict_step( + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 + ) -> None: """No-op, does nothing.""" diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 0381316050b..8309837d1d4 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -4,7 +4,6 @@ """Trainers for regression.""" import os -from typing import Any import matplotlib.pyplot as plt import segmentation_models_pytorch as smp @@ -16,7 +15,7 @@ from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection from torchvision.models._api import WeightsEnum -from ..datasets import RGBBandsMissingError, unbind_samples +from ..datasets import Batch, RGBBandsMissingError, unbind_samples from ..models import FCN, get_weight from . import utils from .base import BaseTask @@ -146,7 +145,7 @@ def configure_metrics(self) -> None: self.test_metrics = metrics.clone(prefix='test_') def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss and additional metrics. @@ -173,7 +172,7 @@ def training_step( return loss def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Compute the validation loss and additional metrics. @@ -224,7 +223,7 @@ def validation_step( ) plt.close() - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def test_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test loss and additional metrics. 
Args: @@ -245,7 +244,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None self.log_dict(self.test_metrics, batch_size=batch_size) def predict_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the predicted regression values. diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index f8e519fa493..3c9dd2ab506 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -4,7 +4,6 @@ """Trainers for semantic segmentation.""" import os -from typing import Any import matplotlib.pyplot as plt import segmentation_models_pytorch as smp @@ -15,7 +14,7 @@ from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex from torchvision.models._api import WeightsEnum -from ..datasets import RGBBandsMissingError, unbind_samples +from ..datasets import Batch, RGBBandsMissingError, unbind_samples from ..models import FCN, get_weight from . import utils from .base import BaseTask @@ -213,7 +212,7 @@ def configure_metrics(self) -> None: self.test_metrics = metrics.clone(prefix='test_') def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss and additional metrics. @@ -236,7 +235,7 @@ def training_step( return loss def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Compute the validation loss and additional metrics. @@ -281,7 +280,7 @@ def validation_step( ) plt.close() - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def test_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test loss and additional metrics. Args: @@ -299,7 +298,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None self.log_dict(self.test_metrics, batch_size=batch_size) def predict_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the predicted class probabilities. diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index 1cb05315f60..775f939ca90 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -5,7 +5,6 @@ import os import warnings -from typing import Any import kornia.augmentation as K import lightning @@ -22,6 +21,7 @@ import torchgeo.transforms as T +from ..datasets import Batch from ..models import get_weight from . import utils from .base import BaseTask @@ -221,7 +221,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: return z, h def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss and additional metrics. 
@@ -272,16 +272,18 @@ def training_step( return loss def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 ) -> None: """No-op, does nothing.""" - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def test_step(self, batch: Batch, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" # TODO # v2: add distillation step - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def predict_step( + self, batch: Batch, batch_idx: int, dataloader_idx: int = 0 + ) -> None: """No-op, does nothing.""" def configure_optimizers( diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index d8f80bdcaac..d5a78260b82 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -14,6 +14,8 @@ from torch import Tensor from torch.nn.modules import Module +from ..datasets import Batch + # TODO: contribute these to Kornia and delete this file class AugmentationSequential(Module): @@ -55,7 +57,7 @@ def __init__( self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) - def forward(self, batch: dict[str, Any]) -> dict[str, Any]: + def forward(self, batch: Batch) -> Batch: """Perform augmentations and update data dict. Args: @@ -88,9 +90,7 @@ def forward(self, batch: dict[str, Any]) -> dict[str, Any]: outputs_list = ( outputs_list if isinstance(outputs_list, list) else [outputs_list] ) - outputs: dict[str, Tensor] = { - k: v for k, v in zip(self.data_keys, outputs_list) - } + outputs: Batch = {k: v for k, v in zip(self.data_keys, outputs_list)} batch.update(outputs) # Convert all inputs back to their previous dtype @@ -179,7 +179,7 @@ def __init__(self, size: tuple[int, int] | Tensor, num: int) -> None: def forward( self, batch_shape: tuple[int, ...], same_on_batch: bool = False - ) -> dict[str, Tensor]: + ) -> Batch: """Generate the crops. Args: From bf5b0a57725fb410ef99ee73ba144b5bc66ec327 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 24 Aug 2024 13:01:08 +0200 Subject: [PATCH 2/3] Fix docs --- docs/api/datasets.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 96ca225344a..ec2dfaffa18 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -555,6 +555,8 @@ Utilities --------- .. autoclass:: BoundingBox +.. autoclass:: Sample +.. autoclass:: Batch Collation Functions ^^^^^^^^^^^^^^^^^^^ From 9eb51f55ba3205e0fb08d6a7aa924887afb11c82 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sat, 24 Aug 2024 13:26:51 +0200 Subject: [PATCH 3/3] More fixes --- tests/datasets/test_geo.py | 5 +- tests/datasets/test_splits.py | 4 +- tests/datasets/test_utils.py | 18 +++---- tests/samplers/test_batch.py | 4 +- tests/samplers/test_single.py | 4 +- torchgeo/datasets/enviroatlas.py | 26 ++++++---- torchgeo/datasets/eurocrops.py | 6 +-- torchgeo/datasets/fair1m.py | 2 +- torchgeo/datasets/gid15.py | 2 +- torchgeo/datasets/skippd.py | 8 +-- torchgeo/datasets/sustainbench_crop_yield.py | 4 +- torchgeo/datasets/utils.py | 54 ++++++++++++++++++-- 12 files changed, 93 insertions(+), 44 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 71e07f6928b..15ebc192603 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -26,6 +26,7 @@ NonGeoClassificationDataset, NonGeoDataset, RasterDataset, + Sample, Sentinel2, UnionDataset, VectorDataset, @@ -46,7 +47,7 @@ def __init__( self.res = res self.paths = paths or [] - def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: + def __getitem__(self, query: BoundingBox) -> Sample: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) bounds = BoundingBox(*hit.bounds) @@ -77,7 +78,7 @@ class CustomSentinelDataset(Sentinel2): class CustomNonGeoDataset(NonGeoDataset): - def __getitem__(self, index: int) -> dict[str, int]: + def __getitem__(self, index: int) -> Sample: return {'index': index} def __len__(self) -> int: diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 2977586ddb4..1ed16487d6b 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -3,7 +3,6 @@ from collections.abc import Sequence from math import floor, isclose -from typing import Any import pytest from rasterio.crs import CRS @@ -11,6 +10,7 @@ from torchgeo.datasets import ( BoundingBox, GeoDataset, + Sample, random_bbox_assignment, random_bbox_splitting, random_grid_cell_assignment, @@ -49,7 +49,7 @@ def __init__( self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) return {'content': hit.object} diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 141709ceba8..74a3a7e839d 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -15,7 +15,7 @@ import torch from rasterio.crs import CRS -from torchgeo.datasets import BoundingBox, DependencyNotFoundError +from torchgeo.datasets import BoundingBox, DependencyNotFoundError, Sample from torchgeo.datasets.utils import ( Executable, array_to_tensor, @@ -393,13 +393,13 @@ def test_disambiguate_timestamp( class TestCollateFunctionsMatchingKeys: @pytest.fixture(scope='class') - def samples(self) -> list[dict[str, Any]]: + def samples(self) -> list[Sample]: return [ {'image': torch.tensor([1, 2, 0]), 'crs': CRS.from_epsg(2000)}, {'image': torch.tensor([0, 0, 3]), 'crs': CRS.from_epsg(2001)}, ] - def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: + def test_stack_unbind_samples(self, samples: list[Sample]) -> None: sample = stack_samples(samples) assert sample['image'].size() == torch.Size([2, 3]) assert torch.allclose(sample['image'], torch.tensor([[1, 2, 0], [0, 0, 3]])) @@ -410,13 +410,13 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: assert torch.allclose(samples[i]['image'], new_samples[i]['image']) 
             assert samples[i]['crs'] == new_samples[i]['crs']
 
-    def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
+    def test_concat_samples(self, samples: list[Sample]) -> None:
         sample = concat_samples(samples)
         assert sample['image'].size() == torch.Size([6])
         assert torch.allclose(sample['image'], torch.tensor([1, 2, 0, 0, 0, 3]))
         assert sample['crs'] == CRS.from_epsg(2000)
 
-    def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
+    def test_merge_samples(self, samples: list[Sample]) -> None:
         sample = merge_samples(samples)
         assert sample['image'].size() == torch.Size([3])
         assert torch.allclose(sample['image'], torch.tensor([1, 2, 3]))
@@ -425,13 +425,13 @@ def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
 
 class TestCollateFunctionsDifferingKeys:
     @pytest.fixture(scope='class')
-    def samples(self) -> list[dict[str, Any]]:
+    def samples(self) -> list[Sample]:
        return [
            {'image': torch.tensor([1, 2, 0]), 'crs1': CRS.from_epsg(2000)},
            {'mask': torch.tensor([0, 0, 3]), 'crs2': CRS.from_epsg(2001)},
        ]
 
-    def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
+    def test_stack_unbind_samples(self, samples: list[Sample]) -> None:
         sample = stack_samples(samples)
         assert sample['image'].size() == torch.Size([1, 3])
         assert sample['mask'].size() == torch.Size([1, 3])
@@ -446,7 +446,7 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
         assert torch.allclose(samples[1]['mask'], new_samples[0]['mask'])
         assert samples[1]['crs2'] == new_samples[0]['crs2']
 
-    def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
+    def test_concat_samples(self, samples: list[Sample]) -> None:
         sample = concat_samples(samples)
         assert sample['image'].size() == torch.Size([3])
         assert sample['mask'].size() == torch.Size([3])
@@ -455,7 +455,7 @@ def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
         assert sample['crs1'] == CRS.from_epsg(2000)
         assert sample['crs2'] == CRS.from_epsg(2001)
 
-    def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
+    def test_merge_samples(self, samples: list[Sample]) -> None:
         sample = merge_samples(samples)
         assert sample['image'].size() == torch.Size([3])
         assert sample['mask'].size() == torch.Size([3])
diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py
index 199239a0e79..c4d660284ae 100644
--- a/tests/samplers/test_batch.py
+++ b/tests/samplers/test_batch.py
@@ -11,7 +11,7 @@
 from rasterio.crs import CRS
 from torch.utils.data import DataLoader
 
-from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
+from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples
 from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units
 
 
@@ -33,7 +33,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
         self._crs = crs
         self.res = res
 
-    def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
+    def __getitem__(self, query: BoundingBox) -> Sample:
         return {'index': query}
 
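The same annotations apply to geospatial datasets driven by a sampler. A minimal sketch, modeled loosely on the test fixtures above: ToyGeoDataset, its index bounds, CRS, and resolution are made up for illustration, while the sampler and collation API are torchgeo's.

    import torch
    from rasterio.crs import CRS
    from torch.utils.data import DataLoader

    from torchgeo.datasets import Batch, BoundingBox, GeoDataset, Sample, stack_samples
    from torchgeo.samplers import RandomBatchGeoSampler


    class ToyGeoDataset(GeoDataset):
        def __init__(self) -> None:
            super().__init__()
            # One tile covering x/y in [0, 100] and all timestamps
            self.index.insert(0, (0, 100, 0, 100, 0, 2**31))
            self._crs = CRS.from_epsg(3005)
            self.res = 1

        def __getitem__(self, query: BoundingBox) -> Sample:
            # Return a fake patch for the queried bounding box
            return {'image': torch.rand(3, 32, 32), 'crs': self.crs, 'bounds': query}


    ds = ToyGeoDataset()
    sampler = RandomBatchGeoSampler(ds, size=32, batch_size=4, length=8)
    loader = DataLoader(ds, batch_sampler=sampler, collate_fn=stack_samples)
    batch: Batch = next(iter(loader))  # 'image' stacked, 'crs'/'bounds' collected as lists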
diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py
index e2c829f1b9e..48c14cfc95a 100644
--- a/tests/samplers/test_single.py
+++ b/tests/samplers/test_single.py
@@ -11,7 +11,7 @@
 from rasterio.crs import CRS
 from torch.utils.data import DataLoader
 
-from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
+from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples
 from torchgeo.samplers import (
     GeoSampler,
     GridGeoSampler,
@@ -40,7 +40,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
         self._crs = crs
         self.res = res
 
-    def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
+    def __getitem__(self, query: BoundingBox) -> Sample:
         return {'index': query}
 
diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py
index 348928b1996..aa6603bd842 100644
--- a/torchgeo/datasets/enviroatlas.py
+++ b/torchgeo/datasets/enviroatlas.py
@@ -6,7 +6,7 @@
 import os
 import sys
 from collections.abc import Callable, Sequence
-from typing import ClassVar, cast
+from typing import Any, ClassVar, cast
 
 import fiona
 import matplotlib.pyplot as plt
@@ -347,8 +347,8 @@ def __getitem__(self, query: BoundingBox) -> Sample:
         """
         hits = self.index.intersection(tuple(query), objects=True)
         filepaths = cast(list[dict[str, str]], [hit.object for hit in hits])
-
-        sample: Sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query}
+        images: list[np.typing.NDArray[Any]] = []
+        masks: list[np.typing.NDArray[Any]] = []
 
         if len(filepaths) == 0:
             raise IndexError(
@@ -389,23 +389,27 @@ def __getitem__(self, query: BoundingBox) -> Sample:
                         'waterbodies',
                         'water',
                     ]:
-                        sample['image'].append(data)
+                        images.append(data)
                     elif layer in ['prior', 'prior_no_osm_no_buildings']:
                         if self.prior_as_input:
-                            sample['image'].append(data)
+                            images.append(data)
                         else:
-                            sample['mask'].append(data)
+                            masks.append(data)
                     elif layer in ['lc']:
                         data = self.raw_enviroatlas_to_idx_map[data]
-                        sample['mask'].append(data)
+                        masks.append(data)
         else:
             raise IndexError(f'query: {query} spans multiple tiles which is not valid')
 
-        sample['image'] = np.concatenate(sample['image'], axis=0)
-        sample['mask'] = np.concatenate(sample['mask'], axis=0)
+        image = torch.from_numpy(np.concatenate(images, axis=0))
+        mask = torch.from_numpy(np.concatenate(masks, axis=0))
 
-        sample['image'] = torch.from_numpy(sample['image'])
-        sample['mask'] = torch.from_numpy(sample['mask'])
+        sample: Sample = {
+            'image': image,
+            'mask': mask,
+            'crs': self.crs,
+            'bounds': query,
+        }
 
         if self.transforms is not None:
             sample = self.transforms(sample)
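The EnviroAtlas change above follows a pattern worth noting: accumulate raw arrays in plain lists and build the typed Sample only once the final Tensor values exist, so every key matches the type declared in the TypedDict. A stripped-down sketch, with arbitrary array shapes; only Sample comes from torchgeo.datasets.

    from typing import Any

    import numpy as np
    import torch

    from torchgeo.datasets import Sample

    # Intermediate per-layer arrays stay in ordinary lists
    images: list[np.typing.NDArray[Any]] = [np.zeros((1, 4, 4)), np.ones((1, 4, 4))]
    masks: list[np.typing.NDArray[Any]] = [np.zeros((1, 4, 4))]

    # Keys receive their final (Tensor) values in one place
    sample: Sample = {
        'image': torch.from_numpy(np.concatenate(images, axis=0)),
        'mask': torch.from_numpy(np.concatenate(masks, axis=0)),
    }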
diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py
index f3caa517a82..89563797a4d 100644
--- a/torchgeo/datasets/eurocrops.py
+++ b/torchgeo/datasets/eurocrops.py
@@ -6,13 +6,13 @@
 import csv
 import os
 from collections.abc import Callable, Iterable
-from typing import Any
 
 import fiona
 import matplotlib.pyplot as plt
 import numpy as np
 from matplotlib.figure import Figure
 from rasterio.crs import CRS
+from torch import Tensor
 
 from .errors import DatasetNotFoundError
 from .geo import VectorDataset
@@ -247,9 +247,7 @@ def plot(
 
         fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4))
 
-        def apply_cmap(
-            arr: 'np.typing.NDArray[Any]',
-        ) -> 'np.typing.NDArray[np.float64]':
+        def apply_cmap(arr: Tensor) -> 'np.typing.NDArray[np.float64]':
             # Color 0 as black, while applying default color map for the class indices.
             cmap = plt.get_cmap('viridis')
             im: np.typing.NDArray[np.float64] = cmap(arr / len(self.class_map))
diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py
index 0cb8de4430e..5ceadde0918 100644
--- a/torchgeo/datasets/fair1m.py
+++ b/torchgeo/datasets/fair1m.py
@@ -283,7 +283,7 @@ def __getitem__(self, index: int) -> Sample:
             label_path = label_path.replace('.tif', '.xml')
             voc = parse_pascal_voc(label_path)
             boxes, labels = self._load_target(voc['points'], voc['labels'])
-        sample: Sample = {'image': image, 'boxes': boxes, 'label': labels}
+        sample = {'image': image, 'boxes': boxes, 'label': labels}
 
         if self.transforms is not None:
             sample = self.transforms(sample)
diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py
index 589f3c1115a..89a6dcfa4f7 100644
--- a/torchgeo/datasets/gid15.py
+++ b/torchgeo/datasets/gid15.py
@@ -139,7 +139,7 @@ def __getitem__(self, index: int) -> Sample:
             mask = self._load_target(files['mask'])
             sample: Sample = {'image': image, 'mask': mask}
         else:
-            sample: Sample = {'image': image}
+            sample = {'image': image}
 
         if self.transforms is not None:
             sample = self.transforms(sample)
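The fair1m.py and gid15.py tweaks appear to exist because a variable can carry a type annotation only once per scope; a second `sample: Sample = ...` in another branch trips mypy's no-redef check, while a plain assignment is still checked against the earlier declaration. A hypothetical reduction (make_sample is not a torchgeo function, and it assumes 'image' and 'mask' are declared Sample keys):

    from torch import Tensor

    from torchgeo.datasets import Sample


    def make_sample(image: Tensor, mask: Tensor | None = None) -> Sample:
        if mask is not None:
            sample: Sample = {'image': image, 'mask': mask}
        else:
            sample = {'image': image}  # checked against the annotation above
        return sample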
diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py
index 4516fb937ad..a878960dc70 100644
--- a/torchgeo/datasets/skippd.py
+++ b/torchgeo/datasets/skippd.py
@@ -134,7 +134,7 @@ def __len__(self) -> int:
 
         return num_datapoints
 
-    def __getitem__(self, index: int) -> dict[str, str | Tensor]:
+    def __getitem__(self, index: int) -> Sample:
         """Return an index within the dataset.
 
         Args:
@@ -143,7 +143,7 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]:
         Returns:
             data and label at that index
         """
-        sample: dict[str, str | Tensor] = {'image': self._load_image(index)}
+        sample: Sample = {'image': self._load_image(index)}
         sample.update(self._load_features(index))
 
         if self.transforms is not None:
@@ -176,7 +176,7 @@ def _load_image(self, index: int) -> Tensor:
         tensor = torch.from_numpy(arr).to(torch.float32)
         return tensor
 
-    def _load_features(self, index: int) -> dict[str, str | Tensor]:
+    def _load_features(self, index: int) -> Sample:
         """Load label.
 
         Args:
@@ -194,7 +194,7 @@ def _load_features(self, index: int) -> dict[str, str | Tensor]:
         path = os.path.join(self.root, f'times_{self.split}_{self.task}.npy')
         datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat)
 
-        features: dict[str, str | Tensor] = {
+        features: Sample = {
             'label': torch.tensor(label, dtype=torch.float32),
             'date': datestring,
         }
diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py
index 6dc041540fa..237d7adb8a7 100644
--- a/torchgeo/datasets/sustainbench_crop_yield.py
+++ b/torchgeo/datasets/sustainbench_crop_yield.py
@@ -98,7 +98,7 @@ def __init__(
         self._verify()
 
         self.images = []
-        self.features = []
+        self.features: list[Sample] = []
 
         for country in self.countries:
             image_file_path = os.path.join(
@@ -122,7 +122,7 @@ def __init__(
                 year = year_npz_file[idx]
                 ndvi = ndvi_npz_file[idx]
 
-                features = {
+                features: Sample = {
                     'label': torch.tensor(target).to(torch.float32),
                     'year': torch.tensor(int(year)),
                     'ndvi': torch.from_numpy(ndvi).to(dtype=torch.float32),
diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py
index ee3f934d8c9..ad19bab1ac9 100644
--- a/torchgeo/datasets/utils.py
+++ b/torchgeo/datasets/utils.py
@@ -65,8 +65,54 @@ class Sample(TypedDict, total=False):
     bounds: BoundingBox
     crs: CRS
 
-    # TODO: remove
+    # TODO: Additional dataset-specific keys that should be subclasses
+    images: Tensor
+    input: Tensor
     boxes: Tensor
+    bboxes: Tensor
+    masks: Tensor
+    labels: Tensor
+    prediction_masks: Tensor
+    prediction_boxes: Tensor
+    prediction_labels: Tensor
+    prediction_label: Tensor
+    prediction_scores: Tensor
+    audio: Tensor
+    points: Tensor
+    x: Tensor
+    y: Tensor
+    relative_time: Tensor
+    ocean: Tensor
+    array: Tensor
+    chm: Tensor
+    hsi: Tensor
+    las: Tensor
+    image1: Tensor
+    image2: Tensor
+    crs1: Tensor
+    crs2: Tensor
+    magnitude: Tensor
+    agb: Tensor
+    key: Tensor
+    patch: Tensor
+    geometry: Tensor
+    properties: Tensor
+    id: int
+    centroid_lat: Tensor
+    centroid_lon: Tensor
+    content: Tensor
+    year: Tensor
+    ndvi: Tensor
+    filename: str
+    category: str
+    field_ids: Tensor
+    tile_index: Tensor
+    transform: Tensor
+    src: Tensor
+    dst: Tensor
+    input_size: Tensor
+    output_size: Tensor
+    index: BoundingBox
 
 
 class Batch(Sample):
@@ -455,7 +501,7 @@ def stack_samples(samples: Iterable[Sample]) -> Batch:
 
     .. versionadded:: 0.2
     """
-    collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
+    collated: Batch = _list_dict_to_dict_list(samples)
     for key, value in collated.items():
         if isinstance(value[0], Tensor):
             collated[key] = torch.stack(value)
@@ -475,7 +521,7 @@ def concat_samples(samples: Iterable[Sample]) -> Batch:
 
     .. versionadded:: 0.2
    """
-    collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
+    collated: Batch = _list_dict_to_dict_list(samples)
     for key, value in collated.items():
         if isinstance(value[0], Tensor):
             collated[key] = torch.cat(value)
@@ -497,7 +543,7 @@ def merge_samples(samples: Iterable[Sample]) -> Batch:
 
     .. versionadded:: 0.2
     """
-    collated: dict[Any, Any] = {}
+    collated: Batch = {}
     for sample in samples:
         for key, value in sample.items():
             if key in collated and isinstance(value, Tensor):