Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HySpecNet-11k: add new dataset #2410

Merged
merged 7 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ GID-15

.. autoclass:: GID15DataModule

HySpecNet-11k
^^^^^^^^^^^^^

.. autoclass:: HySpecNet11kDataModule

Inria Aerial Image Labeling
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ GID-15

.. autoclass:: GID15

HySpecNet-11k
^^^^^^^^^^^^^

.. autoclass:: HySpecNet11k

IDTReeS
^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB
`GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM"
`GID-15`_,S,Gaofen-2,-,150,15,"6,800x7,200",3,RGB
`HySpecNet-11k`_,-,EnMAP,CC0-1.0,11k,-,128,30,HSI
`IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB
`Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB
`LandCover.ai`_,S,Aerial,"CC-BY-NC-SA-4.0","10,674",5,512x512,0.25--0.5,RGB
Expand Down
11 changes: 11 additions & 0 deletions tests/conf/hyspecnet_byol.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
model:
class_path: BYOLTask
init_args:
model: 'resnet18'
in_channels: 202
data:
class_path: HySpecNet11kDataModule
init_args:
batch_size: 2
dict_kwargs:
root: 'tests/data/hyspecnet'
11 changes: 11 additions & 0 deletions tests/conf/hyspecnet_moco.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
model:
class_path: MoCoTask
init_args:
model: 'resnet18'
in_channels: 202
data:
class_path: HySpecNet11kDataModule
init_args:
batch_size: 2
dict_kwargs:
root: 'tests/data/hyspecnet'
11 changes: 11 additions & 0 deletions tests/conf/hyspecnet_simclr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
model:
class_path: SimCLRTask
init_args:
model: 'resnet18'
in_channels: 202
data:
class_path: HySpecNet11kDataModule
init_args:
batch_size: 2
dict_kwargs:
root: 'tests/data/hyspecnet'
61 changes: 61 additions & 0 deletions tests/data/hyspecnet/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil

import numpy as np
import rasterio
from rasterio import Affine
from rasterio.crs import CRS

SIZE = 32
DTYPE = 'int16'

np.random.seed(0)

# Tile name purposefully shortened to avoid Windows git filename length limit.
tiles = ['ENMAP01_20221103T162438Z']
patches = ['Y01460273_X05670694', 'Y01460273_X06950822']

profile = {
'driver': 'GTiff',
'dtype': DTYPE,
'nodata': -32768.0,
'width': SIZE,
'height': SIZE,
'count': 224,
'crs': CRS.from_epsg(32618),
'transform': Affine(30.0, 0.0, 691845.0, 0.0, -30.0, 4561935.0),
'blockysize': 3,
'tiled': False,
'compress': 'deflate',
'interleave': 'band',
}

root = 'hyspecnet-11k'
path = os.path.join(root, 'splits', 'easy')
os.makedirs(path, exist_ok=True)
for tile in tiles:
for patch in patches:
# Split CSV
path = os.path.join(tile, f'{tile}-{patch}', f'{tile}-{patch}-DATA.npy')
for split in ['train', 'val', 'test']:
with open(os.path.join(root, 'splits', 'easy', f'{split}.csv'), 'a+') as f:
f.write(f'{path}\n')

# Spectral image
path = os.path.join(root, 'patches', path)
os.makedirs(os.path.dirname(path), exist_ok=True)
path = path.replace('DATA.npy', 'SPECTRAL_IMAGE.TIF')
Z = np.random.randint(
np.iinfo(DTYPE).min, np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE
)
with rasterio.open(path, 'w', **profile) as src:
for i in range(1, profile['count'] + 1):
src.write(Z, i)

shutil.make_archive(f'{root}-01', 'gztar', '.', os.path.join(root, 'patches'))
shutil.make_archive(f'{root}-splits', 'gztar', '.', os.path.join(root, 'splits'))
Binary file added tests/data/hyspecnet/hyspecnet-11k-01.tar.gz
Binary file not shown.
Binary file added tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy
2 changes: 2 additions & 0 deletions tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy
2 changes: 2 additions & 0 deletions tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy
58 changes: 58 additions & 0 deletions tests/datasets/test_hyspecnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch.nn as nn
from pytest import MonkeyPatch
from torch import Tensor

from torchgeo.datasets import DatasetNotFoundError, HySpecNet11k, RGBBandsMissingError

root = os.path.join('tests', 'data', 'hyspecnet')
md5s = {'hyspecnet-11k-01.tar.gz': '', 'hyspecnet-11k-splits.tar.gz': ''}


class TestHySpecNet11k:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch) -> HySpecNet11k:
monkeypatch.setattr(HySpecNet11k, 'url', root + os.sep)
monkeypatch.setattr(HySpecNet11k, 'md5s', md5s)
transforms = nn.Identity()
return HySpecNet11k(root, transforms=transforms)

def test_getitem(self, dataset: HySpecNet11k) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], Tensor)

def test_len(self, dataset: HySpecNet11k) -> None:
assert len(dataset) == 2

def test_download(self, dataset: HySpecNet11k, tmp_path: Path) -> None:
HySpecNet11k(tmp_path, download=True)

def test_extract(self, dataset: HySpecNet11k, tmp_path: Path) -> None:
for file in glob.iglob(os.path.join(root, '*.tar.gz')):
shutil.copy(file, tmp_path)
HySpecNet11k(tmp_path)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
HySpecNet11k(tmp_path)

def test_plot(self, dataset: HySpecNet11k) -> None:
x = dataset[0]
dataset.plot(x, suptitle='Test')
plt.close()

def test_plot_rgb(self, dataset: HySpecNet11k) -> None:
dataset = HySpecNet11k(root=dataset.root, bands=(1, 2, 3))
match = 'Dataset does not contain some of the RGB bands'
with pytest.raises(RGBBandsMissingError, match=match):
dataset.plot(dataset[0])
1 change: 1 addition & 0 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class TestBYOLTask:
'name',
[
'chesapeake_cvpr_prior_byol',
'hyspecnet_byol',
'seco_byol_1',
'seco_byol_2',
'ssl4eo_l_byol_1',
Expand Down
1 change: 1 addition & 0 deletions tests/trainers/test_moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TestMoCoTask:
'name',
[
'chesapeake_cvpr_prior_moco',
'hyspecnet_moco',
'seco_moco_1',
'seco_moco_2',
'ssl4eo_l_moco_1',
Expand Down
1 change: 1 addition & 0 deletions tests/trainers/test_simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TestSimCLRTask:
'name',
[
'chesapeake_cvpr_prior_simclr',
'hyspecnet_simclr',
'seco_simclr_1',
'seco_simclr_2',
'ssl4eo_l_simclr_1',
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .geonrw import GeoNRWDataModule
from .gid15 import GID15DataModule
from .hyspecnet import HySpecNet11kDataModule
from .inria import InriaAerialImageLabelingDataModule
from .iobench import IOBenchDataModule
from .l7irish import L7IrishDataModule
Expand Down Expand Up @@ -75,6 +76,7 @@
'GID15DataModule',
'GeoDataModule',
'GeoNRWDataModule',
'HySpecNet11kDataModule',
'IOBenchDataModule',
'InriaAerialImageLabelingDataModule',
'L7IrishDataModule',
Expand Down
35 changes: 35 additions & 0 deletions torchgeo/datamodules/hyspecnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""HySpecNet datamodule."""

from typing import Any

import torch

from ..datasets import HySpecNet11k
from .geo import NonGeoDataModule


class HySpecNet11kDataModule(NonGeoDataModule):
"""LightningDataModule implementation for the HySpecNet11k dataset.

.. versionadded:: 0.7
"""

# https://git.tu-berlin.de/rsim/hyspecnet-tools/-/blob/main/tif_to_npy.ipynb
mean = torch.tensor(0)
std = torch.tensor(10000)

def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a new HySpecNet11kDataModule instance.

Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.HySpecNet11k`.
"""
super().__init__(HySpecNet11k, batch_size, num_workers, **kwargs)
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from .geonrw import GeoNRW
from .gid15 import GID15
from .globbiomass import GlobBiomass
from .hyspecnet import HySpecNet11k
from .idtrees import IDTReeS
from .inaturalist import INaturalist
from .inria import InriaAerialImageLabeling
Expand Down Expand Up @@ -214,6 +215,7 @@
'GeoDataset',
'GeoNRW',
'GlobBiomass',
'HySpecNet11k',
'IDTReeS',
'INaturalist',
'IOBench',
Expand Down
Loading