Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EuroSAT: redistribute split files on Hugging Face #2432

Merged
merged 2 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added tests/data/eurosat/EuroSAT100.zip
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-100-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-100-train.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-100-val.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-spatial-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-spatial-train.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
2 changes: 2 additions & 0 deletions tests/data/eurosat/eurosat-spatial-val.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AnnualCrop_1.tif
Forest_1.tif
36 changes: 6 additions & 30 deletions tests/datasets/test_eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,10 @@ def dataset(
) -> EuroSAT:
base_class: type[EuroSAT] = request.param[0]
split: str = request.param[1]
md5 = 'aa051207b0547daba0ac6af57808d68e'
monkeypatch.setattr(base_class, 'md5', md5)
url = os.path.join('tests', 'data', 'eurosat', 'EuroSATallBands.zip')
url = os.path.join('tests', 'data', 'eurosat') + os.sep
monkeypatch.setattr(base_class, 'url', url)
monkeypatch.setattr(base_class, 'filename', 'EuroSATallBands.zip')
monkeypatch.setattr(
base_class,
'split_urls',
{
'train': os.path.join('tests', 'data', 'eurosat', 'eurosat-train.txt'),
'val': os.path.join('tests', 'data', 'eurosat', 'eurosat-val.txt'),
'test': os.path.join('tests', 'data', 'eurosat', 'eurosat-test.txt'),
},
)
monkeypatch.setattr(
base_class,
'split_md5s',
{
'train': '4af60a00fdfdf8500572ae5360694b71',
'val': '4af60a00fdfdf8500572ae5360694b71',
'test': '4af60a00fdfdf8500572ae5360694b71',
},
)
root = tmp_path
transforms = nn.Identity()
return base_class(
root=root, split=split, transforms=transforms, download=True, checksum=True
)
return base_class(tmp_path, split=split, transforms=transforms, download=True)

def test_getitem(self, dataset: EuroSAT) -> None:
x = dataset[0]
Expand All @@ -84,14 +60,14 @@ def test_add(self, dataset: EuroSAT) -> None:
assert len(ds) == 4

def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None:
EuroSAT(root=tmp_path, download=True)
type(dataset)(tmp_path)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing got a bit ugly. We don't actually have to test EuroSATSpatial or EuroSAT100; testing EuroSAT alone would still give us 100% coverage.


def test_already_downloaded_not_extracted(
self, dataset: EuroSAT, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
shutil.copy(dataset.url, tmp_path)
EuroSAT(root=tmp_path, download=False)
shutil.copy(dataset.url + dataset.filename, tmp_path)
type(dataset)(tmp_path)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Expand All @@ -108,7 +84,7 @@ def test_plot(self, dataset: EuroSAT) -> None:
plt.close()

def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None:
dataset = EuroSAT(root=tmp_path, bands=('B03',))
dataset = type(dataset)(tmp_path, bands=('B03',))
with pytest.raises(
RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
):
Expand Down
37 changes: 16 additions & 21 deletions torchgeo/datasets/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset):
* https://ieeexplore.ieee.org/document/8519248
"""

url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip'
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/1ce6f1bfb56db63fd91b6ecc466ea67f2509774c/'
filename = 'EuroSATallBands.zip'
md5 = '5ac12b3b2557aa56e1826e981e8e200e'

Expand All @@ -64,10 +64,10 @@ class EuroSAT(NonGeoClassificationDataset):
)

splits = ('train', 'val', 'test')
split_urls: ClassVar[dict[str, str]] = {
'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt',
'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt',
'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt',
split_filenames: ClassVar[dict[str, str]] = {
'train': 'eurosat-train.txt',
'val': 'eurosat-val.txt',
'test': 'eurosat-test.txt',
}
split_md5s: ClassVar[dict[str, str]] = {
'train': '908f142e73d6acdf3f482c5e80d851b1',
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(
self._verify()

valid_fns = set()
with open(os.path.join(self.root, f'eurosat-{split}.txt')) as f:
with open(os.path.join(self.root, self.split_filenames[split])) as f:
for fn in f:
valid_fns.add(fn.strip().replace('.jpg', '.tif'))

Expand Down Expand Up @@ -207,16 +207,12 @@ def _verify(self) -> None:
def _download(self) -> None:
"""Download the dataset."""
download_url(
self.url,
self.root,
filename=self.filename,
md5=self.md5 if self.checksum else None,
self.url + self.filename, self.root, md5=self.md5 if self.checksum else None
)
for split in self.splits:
download_url(
self.split_urls[split],
self.url + self.split_filenames[split],
self.root,
filename=f'eurosat-{split}.txt',
md5=self.split_md5s[split] if self.checksum else None,
)

Expand Down Expand Up @@ -305,10 +301,10 @@ class EuroSATSpatial(EuroSAT):
.. versionadded:: 0.6
"""

split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt',
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt',
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt',
split_filenames: ClassVar[dict[str, str]] = {
'train': 'eurosat-spatial-train.txt',
'val': 'eurosat-spatial-val.txt',
'test': 'eurosat-spatial-test.txt',
}
split_md5s: ClassVar[dict[str, str]] = {
'train': '7be3254be39f23ce4d4d144290c93292',
Expand All @@ -328,14 +324,13 @@ class EuroSAT100(EuroSAT):
.. versionadded:: 0.5
"""

url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip'
filename = 'EuroSAT100.zip'
md5 = 'c21c649ba747e86eda813407ef17d596'

split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt',
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt',
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt',
split_filenames: ClassVar[dict[str, str]] = {
'train': 'eurosat-100-train.txt',
'val': 'eurosat-100-val.txt',
'test': 'eurosat-100-test.txt',
}
split_md5s: ClassVar[dict[str, str]] = {
'train': '033d0c23e3a75e3fa79618b0e35fe1c7',
Expand Down