diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index abb0c6eaefa..cab21f872d4 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -124,6 +124,11 @@ GID-15 .. autoclass:: GID15DataModule +HySpecNet-11k +^^^^^^^^^^^^^ + +.. autoclass:: HySpecNet11kDataModule + Inria Aerial Image Labeling ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 96ca225344a..99cc4d75427 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -312,6 +312,11 @@ GID-15 .. autoclass:: GID15 +HySpecNet-11k +^^^^^^^^^^^^^ + +.. autoclass:: HySpecNet11k + IDTReeS ^^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index f000bc1d8da..2f9b64227fa 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -21,6 +21,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB `GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM" `GID-15`_,S,Gaofen-2,-,150,15,"6,800x7,200",3,RGB +`HySpecNet-11k`_,-,EnMAP,CC0-1.0,11k,-,128,30,HSI `IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB `Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB `LandCover.ai`_,S,Aerial,"CC-BY-NC-SA-4.0","10,674",5,512x512,0.25--0.5,RGB diff --git a/tests/conf/hyspecnet_byol.yaml b/tests/conf/hyspecnet_byol.yaml new file mode 100644 index 00000000000..5c0fa31d609 --- /dev/null +++ b/tests/conf/hyspecnet_byol.yaml @@ -0,0 +1,11 @@ +model: + class_path: BYOLTask + init_args: + model: 'resnet18' + in_channels: 202 +data: + class_path: HySpecNet11kDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/hyspecnet' diff --git a/tests/conf/hyspecnet_moco.yaml b/tests/conf/hyspecnet_moco.yaml new file mode 100644 index 00000000000..732b83912c1 --- /dev/null +++ b/tests/conf/hyspecnet_moco.yaml @@ -0,0 +1,11 @@ +model: + class_path: MoCoTask + init_args: + model: 'resnet18' + in_channels: 202 +data: + class_path: HySpecNet11kDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/hyspecnet' diff --git a/tests/conf/hyspecnet_simclr.yaml b/tests/conf/hyspecnet_simclr.yaml new file mode 100644 index 00000000000..d16e8209326 --- /dev/null +++ b/tests/conf/hyspecnet_simclr.yaml @@ -0,0 +1,11 @@ +model: + class_path: SimCLRTask + init_args: + model: 'resnet18' + in_channels: 202 +data: + class_path: HySpecNet11kDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/hyspecnet' diff --git a/tests/data/hyspecnet/data.py b/tests/data/hyspecnet/data.py new file mode 100755 index 00000000000..3b4b701106e --- /dev/null +++ b/tests/data/hyspecnet/data.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil + +import numpy as np +import rasterio +from rasterio import Affine +from rasterio.crs import CRS + +SIZE = 32 +DTYPE = 'int16' + +np.random.seed(0) + +# Tile name purposefully shortened to avoid Windows git filename length limit. +tiles = ['ENMAP01_20221103T162438Z'] +patches = ['Y01460273_X05670694', 'Y01460273_X06950822'] + +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'nodata': -32768.0, + 'width': SIZE, + 'height': SIZE, + 'count': 224, + 'crs': CRS.from_epsg(32618), + 'transform': Affine(30.0, 0.0, 691845.0, 0.0, -30.0, 4561935.0), + 'blockysize': 3, + 'tiled': False, + 'compress': 'deflate', + 'interleave': 'band', +} + +root = 'hyspecnet-11k' +path = os.path.join(root, 'splits', 'easy') +os.makedirs(path, exist_ok=True) +for tile in tiles: + for patch in patches: + # Split CSV + path = os.path.join(tile, f'{tile}-{patch}', f'{tile}-{patch}-DATA.npy') + for split in ['train', 'val', 'test']: + with open(os.path.join(root, 'splits', 'easy', f'{split}.csv'), 'a+') as f: + f.write(f'{path}\n') + + # Spectral image + path = os.path.join(root, 'patches', path) + os.makedirs(os.path.dirname(path), exist_ok=True) + path = path.replace('DATA.npy', 'SPECTRAL_IMAGE.TIF') + Z = np.random.randint( + np.iinfo(DTYPE).min, np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE + ) + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): + src.write(Z, i) + +shutil.make_archive(f'{root}-01', 'gztar', '.', os.path.join(root, 'patches')) +shutil.make_archive(f'{root}-splits', 'gztar', '.', os.path.join(root, 'splits')) diff --git a/tests/data/hyspecnet/hyspecnet-11k-01.tar.gz b/tests/data/hyspecnet/hyspecnet-11k-01.tar.gz new file mode 100644 index 00000000000..b5a5ec766a5 Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k-01.tar.gz differ diff --git a/tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz b/tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz new file mode 100644 index 00000000000..152f71c040f Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz differ diff --git a/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-SPECTRAL_IMAGE.TIF b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-SPECTRAL_IMAGE.TIF new file mode 100644 index 00000000000..498bf304fa1 Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-SPECTRAL_IMAGE.TIF differ diff --git a/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-SPECTRAL_IMAGE.TIF b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-SPECTRAL_IMAGE.TIF new file mode 100644 index 00000000000..5142ff4fbcf Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-SPECTRAL_IMAGE.TIF differ diff --git a/tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv new file mode 100644 index 00000000000..14393bce82a --- /dev/null +++ b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv @@ -0,0 +1,2 @@ +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy diff --git a/tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv new file mode 100644 index 00000000000..14393bce82a --- /dev/null +++ b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv @@ -0,0 +1,2 @@ +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy diff --git a/tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv new file mode 100644 index 00000000000..14393bce82a --- /dev/null +++ b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv @@ -0,0 +1,2 @@ +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy diff --git a/tests/datasets/test_hyspecnet.py b/tests/datasets/test_hyspecnet.py new file mode 100644 index 00000000000..1e5a646cee6 --- /dev/null +++ b/tests/datasets/test_hyspecnet.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import glob +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch.nn as nn +from pytest import MonkeyPatch +from torch import Tensor + +from torchgeo.datasets import DatasetNotFoundError, HySpecNet11k, RGBBandsMissingError + +root = os.path.join('tests', 'data', 'hyspecnet') +md5s = {'hyspecnet-11k-01.tar.gz': '', 'hyspecnet-11k-splits.tar.gz': ''} + + +class TestHySpecNet11k: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch) -> HySpecNet11k: + monkeypatch.setattr(HySpecNet11k, 'url', root + os.sep) + monkeypatch.setattr(HySpecNet11k, 'md5s', md5s) + transforms = nn.Identity() + return HySpecNet11k(root, transforms=transforms) + + def test_getitem(self, dataset: HySpecNet11k) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], Tensor) + + def test_len(self, dataset: HySpecNet11k) -> None: + assert len(dataset) == 2 + + def test_download(self, dataset: HySpecNet11k, tmp_path: Path) -> None: + HySpecNet11k(tmp_path, download=True) + + def test_extract(self, dataset: HySpecNet11k, tmp_path: Path) -> None: + for file in glob.iglob(os.path.join(root, '*.tar.gz')): + shutil.copy(file, tmp_path) + HySpecNet11k(tmp_path) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + HySpecNet11k(tmp_path) + + def test_plot(self, dataset: HySpecNet11k) -> None: + x = dataset[0] + dataset.plot(x, suptitle='Test') + plt.close() + + def test_plot_rgb(self, dataset: HySpecNet11k) -> None: + dataset = HySpecNet11k(root=dataset.root, bands=(1, 2, 3)) + match = 'Dataset does not contain some of the RGB bands' + with pytest.raises(RGBBandsMissingError, match=match): + dataset.plot(dataset[0]) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index b0c13b6075b..808bf937220 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -41,6 +41,7 @@ class TestBYOLTask: 'name', [ 'chesapeake_cvpr_prior_byol', + 'hyspecnet_byol', 'seco_byol_1', 'seco_byol_2', 'ssl4eo_l_byol_1', diff --git a/tests/trainers/test_moco.py b/tests/trainers/test_moco.py index 32c002dc573..002944b929e 100644 --- a/tests/trainers/test_moco.py +++ b/tests/trainers/test_moco.py @@ -29,6 +29,7 @@ class TestMoCoTask: 'name', [ 'chesapeake_cvpr_prior_moco', + 'hyspecnet_moco', 'seco_moco_1', 'seco_moco_2', 'ssl4eo_l_moco_1', diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py index 7e1292ab7c0..3924b6e3785 100644 --- a/tests/trainers/test_simclr.py +++ b/tests/trainers/test_simclr.py @@ -29,6 +29,7 @@ class TestSimCLRTask: 'name', [ 'chesapeake_cvpr_prior_simclr', + 'hyspecnet_simclr', 'seco_simclr_1', 'seco_simclr_2', 'ssl4eo_l_simclr_1', diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index dc9513a6524..8f1cee47720 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -21,6 +21,7 @@ from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule from .geonrw import GeoNRWDataModule from .gid15 import GID15DataModule +from .hyspecnet import HySpecNet11kDataModule from .inria import InriaAerialImageLabelingDataModule from .iobench import IOBenchDataModule from .l7irish import L7IrishDataModule @@ -75,6 +76,7 @@ 'GID15DataModule', 'GeoDataModule', 'GeoNRWDataModule', + 'HySpecNet11kDataModule', 'IOBenchDataModule', 'InriaAerialImageLabelingDataModule', 'L7IrishDataModule', diff --git a/torchgeo/datamodules/hyspecnet.py b/torchgeo/datamodules/hyspecnet.py new file mode 100644 index 00000000000..3e508ef11a7 --- /dev/null +++ b/torchgeo/datamodules/hyspecnet.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""HySpecNet datamodule.""" + +from typing import Any + +import torch + +from ..datasets import HySpecNet11k +from .geo import NonGeoDataModule + + +class HySpecNet11kDataModule(NonGeoDataModule): + """LightningDataModule implementation for the HySpecNet11k dataset. + + .. versionadded:: 0.7 + """ + + # https://git.tu-berlin.de/rsim/hyspecnet-tools/-/blob/main/tif_to_npy.ipynb + mean = torch.tensor(0) + std = torch.tensor(10000) + + def __init__( + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a new HySpecNet11kDataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.HySpecNet11k`. + """ + super().__init__(HySpecNet11k, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 3016c3af7e2..d77e13475a0 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -61,6 +61,7 @@ from .geonrw import GeoNRW from .gid15 import GID15 from .globbiomass import GlobBiomass +from .hyspecnet import HySpecNet11k from .idtrees import IDTReeS from .inaturalist import INaturalist from .inria import InriaAerialImageLabeling @@ -214,6 +215,7 @@ 'GeoDataset', 'GeoNRW', 'GlobBiomass', + 'HySpecNet11k', 'IDTReeS', 'INaturalist', 'IOBench', diff --git a/torchgeo/datasets/hyspecnet.py b/torchgeo/datasets/hyspecnet.py new file mode 100644 index 00000000000..412ea504b24 --- /dev/null +++ b/torchgeo/datasets/hyspecnet.py @@ -0,0 +1,229 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""HySpecNet dataset.""" + +import os +from collections.abc import Callable, Sequence +from typing import ClassVar + +import rasterio as rio +import torch +from einops import rearrange +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError, RGBBandsMissingError +from .geo import NonGeoDataset +from .utils import Path, download_url, extract_archive, percentile_normalization + +# https://git.tu-berlin.de/rsim/hyspecnet-tools/-/blob/main/tif_to_npy.ipynb +invalid_channels = [ + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 160, + 161, + 162, + 163, + 164, + 165, + 166, +] +valid_channels_ids = [c + 1 for c in range(224) if c not in invalid_channels] + + +class HySpecNet11k(NonGeoDataset): + """HySpecNet-11k dataset. + + `HySpecNet-11k `__ is a large-scale + benchmark dataset for hyperspectral image compression and self-supervised learning. + It is made up of 11,483 nonoverlapping image patches acquired by the + `EnMAP satellite `_. Each patch is a portion of 128 x 128 + pixels with 224 spectral bands and with a ground sample distance of 30 m. + + To construct HySpecNet-11k, a total of 250 EnMAP tiles acquired during the routine + operation phase between 2 November 2022 and 9 November 2022 were considered. The + considered tiles are associated with less than 10% cloud and snow cover. The tiles + were radiometrically, geometrically and atmospherically corrected (L2A water & land + product). Then, the tiles were divided into nonoverlapping image patches. The + cropped patches at the borders of the tiles were eliminated. As a result, more than + 45 patches per tile are obtained, resulting in 11,483 patches for the full dataset. + + We provide predefined splits obtained by randomly dividing HySpecNet into: + + #. a training set that includes 70% of the patches, + #. a validation set that includes 20% of the patches, and + #. a test set that includes 10% of the patches. + + Depending on the way that we used for splitting the dataset, we define two + different splits: + + #. an easy split, where patches from the same tile can be present in different sets + (patchwise splitting); and + #. a hard split, where all patches from one tile belong to the same set + (tilewise splitting). + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/2306.00385 + + .. versionadded:: 0.7 + """ + + url = 'https://hf.co/datasets/torchgeo/hyspecnet/resolve/13e110422a6925cbac0f11edff610219b9399227/' + md5s: ClassVar[dict[str, str]] = { + 'hyspecnet-11k-01.tar.gz': '974aae9197006727b42ec81796049efe', + 'hyspecnet-11k-02.tar.gz': 'f80574485f835b8a263b6c64076c0c62', + 'hyspecnet-11k-03.tar.gz': '6bc1de573f97fa4a75b79719b9270cb3', + 'hyspecnet-11k-04.tar.gz': '2463dc10653cb8be10d44951307c5e7d', + 'hyspecnet-11k-05.tar.gz': '16c1bd9e684673e741c0849bd015c988', + 'hyspecnet-11k-06.tar.gz': '8eef16b67d71af6eb4bc836d294fe3c4', + 'hyspecnet-11k-07.tar.gz': 'f61f0e7d6b05c861e69026b09130a5d6', + 'hyspecnet-11k-08.tar.gz': '19d390bc9e61b85e7d765f3077984976', + 'hyspecnet-11k-09.tar.gz': '197ff47befe5b9de88be5e1321c5ce5d', + 'hyspecnet-11k-10.tar.gz': '9e674cca126a9d139d6584be148d4bac', + 'hyspecnet-11k-splits.tar.gz': '94fad9e3c979c612c29a045406247d6c', + } + + all_bands = valid_channels_ids + rgb_bands = (43, 28, 10) + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + strategy: str = 'easy', + bands: Sequence[int] = all_bands, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new HySpecNet11k instance. + + Args: + root: Root directory where dataset can be found. + split: One of 'train', 'val', or 'test'. + strategy: Either 'easy' for patchwise splitting or 'hard' for tilewise + splitting. + bands: Bands to return. + transforms: A function/transform that takes input sample and its target as + entry and returns a transformed version. + download: If True, download dataset and store it in the root directory. + checksum: If True, check the MD5 of the downloaded files (may be slow). + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + self.root = root + self.split = split + self.strategy = strategy + self.bands = bands + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + path = os.path.join(root, 'hyspecnet-11k', 'splits', strategy, f'{split}.csv') + with open(path) as f: + self.files = f.read().strip().split('\n') + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + Length of the dataset. + """ + return len(self.files) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: Index to return. + + Returns: + Data and label at that index. + """ + file = self.files[index].replace('DATA.npy', 'SPECTRAL_IMAGE.TIF') + with rio.open(os.path.join(self.root, 'hyspecnet-11k', 'patches', file)) as src: + sample = {'image': torch.tensor(src.read(self.bands).astype('float32'))} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the extracted files already exist + exists = [] + for directory in ['patches', 'splits']: + path = os.path.join(self.root, 'hyspecnet-11k', directory) + exists.append(os.path.isdir(path)) + + if all(exists): + return + + for file, md5 in self.md5s.items(): + # Check if the file has already been downloaded + path = os.path.join(self.root, file) + if os.path.isfile(path): + extract_archive(path) + continue + + # Check if the user requested to download the dataset + if self.download: + url = self.url + file + download_url(url, self.root, md5=md5 if self.checksum else None) + extract_archive(path) + continue + + raise DatasetNotFoundError(self) + + def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by :meth:`__getitem__`. + suptitle: optional string to use as a suptitle + + Returns: + A matplotlib Figure with the rendered sample. + + Raises: + RGBBandsMissingError: If *bands* does not include all RGB bands. + """ + rgb_indices = [] + for band in self.rgb_bands: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise RGBBandsMissingError() + + image = sample['image'][rgb_indices].cpu().numpy() + image = rearrange(image, 'c h w -> h w c') + image = percentile_normalization(image) + + fig, ax = plt.subplots() + ax.imshow(image) + ax.axis('off') + + if suptitle: + fig.suptitle(suptitle) + + return fig