Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Sample and Batch TypedDicts #2249

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,8 @@ Utilities
---------

.. autoclass:: BoundingBox
.. autoclass:: Sample
.. autoclass:: Batch

Collation Functions
^^^^^^^^^^^^^^^^^^^
Expand Down
7 changes: 3 additions & 4 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
from lightning.pytorch import Trainer
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor

from torchgeo.datamodules import (
GeoDataModule,
MisconfigurationException,
NonGeoDataModule,
)
from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset
from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset, Sample
from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler


Expand All @@ -30,7 +29,7 @@ def __init__(
self.index.insert(i, (0, 1, 2, 3, 4, 5))
self.res = 1

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
def __getitem__(self, query: BoundingBox) -> Sample:
image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}

Expand Down Expand Up @@ -67,7 +66,7 @@ def __init__(
) -> None:
self.length = length

def __getitem__(self, index: int) -> dict[str, Tensor]:
def __getitem__(self, index: int) -> Sample:
return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)}

def __len__(self) -> int:
Expand Down
5 changes: 3 additions & 2 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
NonGeoClassificationDataset,
NonGeoDataset,
RasterDataset,
Sample,
Sentinel2,
UnionDataset,
VectorDataset,
Expand All @@ -46,7 +47,7 @@ def __init__(
self.res = res
self.paths = paths or []

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
def __getitem__(self, query: BoundingBox) -> Sample:
hits = self.index.intersection(tuple(query), objects=True)
hit = next(iter(hits))
bounds = BoundingBox(*hit.bounds)
Expand Down Expand Up @@ -77,7 +78,7 @@ class CustomSentinelDataset(Sentinel2):


class CustomNonGeoDataset(NonGeoDataset):
def __getitem__(self, index: int) -> dict[str, int]:
def __getitem__(self, index: int) -> Sample:
return {'index': index}

def __len__(self) -> int:
Expand Down
4 changes: 2 additions & 2 deletions tests/datasets/test_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

from collections.abc import Sequence
from math import floor, isclose
from typing import Any

import pytest
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
GeoDataset,
Sample,
random_bbox_assignment,
random_bbox_splitting,
random_grid_cell_assignment,
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(
self._crs = crs
self.res = res

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
def __getitem__(self, query: BoundingBox) -> Sample:
hits = self.index.intersection(tuple(query), objects=True)
hit = next(iter(hits))
return {'content': hit.object}
Expand Down
18 changes: 9 additions & 9 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
from rasterio.crs import CRS

from torchgeo.datasets import BoundingBox, DependencyNotFoundError
from torchgeo.datasets import BoundingBox, DependencyNotFoundError, Sample
from torchgeo.datasets.utils import (
Executable,
array_to_tensor,
Expand Down Expand Up @@ -393,13 +393,13 @@ def test_disambiguate_timestamp(

class TestCollateFunctionsMatchingKeys:
@pytest.fixture(scope='class')
def samples(self) -> list[dict[str, Any]]:
def samples(self) -> list[Sample]:
return [
{'image': torch.tensor([1, 2, 0]), 'crs': CRS.from_epsg(2000)},
{'image': torch.tensor([0, 0, 3]), 'crs': CRS.from_epsg(2001)},
]

def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
def test_stack_unbind_samples(self, samples: list[Sample]) -> None:
sample = stack_samples(samples)
assert sample['image'].size() == torch.Size([2, 3])
assert torch.allclose(sample['image'], torch.tensor([[1, 2, 0], [0, 0, 3]]))
Expand All @@ -410,13 +410,13 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
assert torch.allclose(samples[i]['image'], new_samples[i]['image'])
assert samples[i]['crs'] == new_samples[i]['crs']

def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
def test_concat_samples(self, samples: list[Sample]) -> None:
sample = concat_samples(samples)
assert sample['image'].size() == torch.Size([6])
assert torch.allclose(sample['image'], torch.tensor([1, 2, 0, 0, 0, 3]))
assert sample['crs'] == CRS.from_epsg(2000)

def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
def test_merge_samples(self, samples: list[Sample]) -> None:
sample = merge_samples(samples)
assert sample['image'].size() == torch.Size([3])
assert torch.allclose(sample['image'], torch.tensor([1, 2, 3]))
Expand All @@ -425,13 +425,13 @@ def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:

class TestCollateFunctionsDifferingKeys:
@pytest.fixture(scope='class')
def samples(self) -> list[dict[str, Any]]:
def samples(self) -> list[Sample]:
return [
{'image': torch.tensor([1, 2, 0]), 'crs1': CRS.from_epsg(2000)},
{'mask': torch.tensor([0, 0, 3]), 'crs2': CRS.from_epsg(2001)},
]

def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
def test_stack_unbind_samples(self, samples: list[Sample]) -> None:
sample = stack_samples(samples)
assert sample['image'].size() == torch.Size([1, 3])
assert sample['mask'].size() == torch.Size([1, 3])
Expand All @@ -446,7 +446,7 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
assert torch.allclose(samples[1]['mask'], new_samples[0]['mask'])
assert samples[1]['crs2'] == new_samples[0]['crs2']

def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
def test_concat_samples(self, samples: list[Sample]) -> None:
sample = concat_samples(samples)
assert sample['image'].size() == torch.Size([3])
assert sample['mask'].size() == torch.Size([3])
Expand All @@ -455,7 +455,7 @@ def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
assert sample['crs1'] == CRS.from_epsg(2000)
assert sample['crs2'] == CRS.from_epsg(2001)

def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
def test_merge_samples(self, samples: list[Sample]) -> None:
sample = merge_samples(samples)
assert sample['image'].size() == torch.Size([3])
assert sample['mask'].size() == torch.Size([3])
Expand Down
4 changes: 2 additions & 2 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rasterio.crs import CRS
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units


Expand All @@ -33,7 +33,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
self._crs = crs
self.res = res

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
def __getitem__(self, query: BoundingBox) -> Sample:
return {'index': query}


Expand Down
4 changes: 2 additions & 2 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rasterio.crs import CRS
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples
from torchgeo.samplers import (
GeoSampler,
GridGeoSampler,
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
self._crs = crs
self.res = res

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
def __getitem__(self, query: BoundingBox) -> Sample:
return {'index': query}


Expand Down
9 changes: 5 additions & 4 deletions tests/transforms/test_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
import torch
from torch import Tensor

from torchgeo.datasets import Batch, Sample
from torchgeo.transforms import RandomGrayscale


@pytest.fixture
def sample() -> dict[str, Tensor]:
def sample() -> Sample:
return {
'image': torch.arange(3 * 4 * 4, dtype=torch.float).view(3, 4, 4),
'mask': torch.arange(4 * 4, dtype=torch.long).view(1, 4, 4),
}


@pytest.fixture
def batch() -> dict[str, Tensor]:
def batch() -> Batch:
return {
'image': torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4),
'mask': torch.arange(2 * 4 * 4, dtype=torch.long).view(2, 1, 4, 4),
Expand All @@ -33,7 +34,7 @@ def batch() -> dict[str, Tensor]:
torch.tensor([1.0, 2.0, 3.0]),
],
)
def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> None:
def test_random_grayscale_sample(weights: Tensor, sample: Sample) -> None:
aug = K.AugmentationSequential(
RandomGrayscale(weights, p=1), keepdim=True, data_keys=None
)
Expand All @@ -51,7 +52,7 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) ->
torch.tensor([1.0, 2.0, 3.0]),
],
)
def test_random_grayscale_batch(weights: Tensor, batch: dict[str, Tensor]) -> None:
def test_random_grayscale_batch(weights: Tensor, batch: Batch) -> None:
aug = K.AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=None)
output = aug(batch)
assert output['image'].shape == batch['image'].shape
Expand Down
16 changes: 8 additions & 8 deletions tests/transforms/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import kornia.augmentation as K
import pytest
import torch
from torch import Tensor

from torchgeo.datasets import Batch, Sample
from torchgeo.transforms import (
AppendBNDVI,
AppendGBNDVI,
Expand All @@ -25,22 +25,22 @@


@pytest.fixture
def sample() -> dict[str, Tensor]:
def sample() -> Sample:
return {
'image': torch.arange(3 * 4 * 4, dtype=torch.float).view(3, 4, 4),
'mask': torch.arange(4 * 4, dtype=torch.long).view(1, 4, 4),
}


@pytest.fixture
def batch() -> dict[str, Tensor]:
def batch() -> Batch:
return {
'image': torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4),
'mask': torch.arange(2 * 4 * 4, dtype=torch.long).view(2, 1, 4, 4),
}


def test_append_index_sample(sample: dict[str, Tensor]) -> None:
def test_append_index_sample(sample: Sample) -> None:
c, h, w = sample['image'].shape
aug = K.AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
Expand All @@ -49,7 +49,7 @@ def test_append_index_sample(sample: dict[str, Tensor]) -> None:
assert output['image'].shape == (1, c + 1, h, w)


def test_append_index_batch(batch: dict[str, Tensor]) -> None:
def test_append_index_batch(batch: Batch) -> None:
b, c, h, w = batch['image'].shape
aug = K.AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
Expand All @@ -58,7 +58,7 @@ def test_append_index_batch(batch: dict[str, Tensor]) -> None:
assert output['image'].shape == (b, c + 1, h, w)


def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None:
def test_append_triband_index_batch(batch: Batch) -> None:
b, c, h, w = batch['image'].shape
aug = K.AugmentationSequential(
AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2),
Expand All @@ -83,7 +83,7 @@ def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None:
],
)
def test_append_normalized_difference_indices(
sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex
sample: Sample, index: AppendNormalizedDifferenceIndex
) -> None:
c, h, w = sample['image'].shape
aug = K.AugmentationSequential(index(0, 1), data_keys=None)
Expand All @@ -93,7 +93,7 @@ def test_append_normalized_difference_indices(

@pytest.mark.parametrize('index', [AppendGBNDVI, AppendGRNDVI, AppendRBNDVI])
def test_append_tri_band_normalized_difference_indices(
sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex
sample: Sample, index: AppendTriBandNormalizedDifferenceIndex
) -> None:
c, h, w = sample['image'].shape
aug = K.AugmentationSequential(index(0, 1, 2), data_keys=None)
Expand Down
7 changes: 2 additions & 5 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

import kornia.augmentation as K
import torch.nn.functional as F
from torch import Tensor

from ..datasets import ChesapeakeCVPR
from ..datasets import ChesapeakeCVPR, Sample
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from .geo import GeoDataModule

Expand Down Expand Up @@ -124,9 +123,7 @@ def setup(self, stage: str) -> None:
self.test_dataset, self.original_patch_size, self.original_patch_size
)

def on_after_batch_transfer(
self, batch: dict[str, Tensor], dataloader_idx: int
) -> dict[str, Tensor]:
def on_after_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
"""Apply batch augmentations to the batch after it is transferred to the device.

Args:
Expand Down
7 changes: 2 additions & 5 deletions torchgeo/datamodules/etci2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from typing import Any

import torch
from torch import Tensor

from ..datasets import ETCI2021
from ..datasets import ETCI2021, Sample
from .geo import NonGeoDataModule


Expand Down Expand Up @@ -62,9 +61,7 @@ def setup(self, stage: str) -> None:
# Test set masks are not public, use for prediction instead
self.predict_dataset = ETCI2021(split='test', **self.kwargs)

def on_after_batch_transfer(
self, batch: dict[str, Tensor], dataloader_idx: int
) -> dict[str, Tensor]:
def on_after_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
"""Apply batch augmentations to the batch after it is transferred to the device.

Args:
Expand Down
7 changes: 3 additions & 4 deletions torchgeo/datamodules/fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
from typing import Any

import torch
from torch import Tensor

from ..datasets import FAIR1M
from ..datasets import FAIR1M, Sample
from .geo import NonGeoDataModule


def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
def collate_fn(batch: list[Sample]) -> Sample:
"""Custom object detection collate fn to handle variable boxes.

Args:
Expand All @@ -23,7 +22,7 @@ def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:

.. versionadded:: 0.5
"""
output: dict[str, Any] = {}
output: Sample = {}
output['image'] = torch.stack([sample['image'] for sample in batch])

if 'boxes' in batch[0]:
Expand Down
Loading
Loading