Commit b3b20a6

Add a convenient function to load a dataset from a particular location (#270)
Close #262
1 parent fbe2362 commit b3b20a6
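
A minimal usage sketch of the new API, based on the docstring in the diff below; the archive URL and local path are placeholders, not real datasets:

    from pardata import load_dataset_from_location

    # With the default schema, files in the archive are grouped into subdatasets
    # by extension: CSV, WAV, TXT/LOG, JPEG, and PNG are recognized.
    data = load_dataset_from_location('https://example.com/my-dataset.zip')

    # Keys are format IDs such as 'text/plain' or 'table/csv'; subdatasets that
    # matched no files are stripped from the result.
    print(data.keys())

    # A local path works too; the cached copy is reused unless redownload is forced.
    data = load_dataset_from_location('/tmp/my-dataset.zip', force_redownload=True)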

File tree

3 files changed: +100 -2 lines


pardata/__init__.py (+1)

@@ -25,5 +25,6 @@
                           init,
                           list_all_datasets,
                           load_dataset,
+                          load_dataset_from_location,
                           load_schema_collections)
 from ._version import version as __version__

pardata/_high_level.py (+68 -1)

@@ -1,5 +1,5 @@
 #
-# Copyright 2020 IBM Corp. All Rights Reserved.
+# Copyright 2020--2021 IBM Corp. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,15 +23,19 @@
 from copy import deepcopy
 import dataclasses
 import functools
+import hashlib
 from textwrap import dedent
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar, Union, cast
+import os
 from packaging.version import parse as version_parser
+import re

 from ._config import Config
 from ._dataset import Dataset
 from . import typing as typing_
 from ._schema import (DatasetSchemaCollection, FormatSchemaCollection, LicenseSchemaCollection,
                       SchemaDict, SchemaCollectionManager)
+from ._schema_retrieval import is_url

 # Global configurations --------------------------------------------------

@@ -208,6 +212,69 @@ def load_dataset(name: str, *,
                          f'\nCaused by:\n{e}')


+def load_dataset_from_location(url_or_path: Union[str, typing_.PathLike], *,
+                               schema: Optional[SchemaDict] = None,
+                               force_redownload: bool = False) -> Dict[str, Any]:
+    """Load the dataset from ``url_or_path``. This function is equivalent to calling :class:`~pardata.Dataset`, where
+    ``schema['download_url']`` is set to ``url_or_path``. In the returned :class:`dict` object, keys corresponding to
+    empty values are removed (unlike :meth:`~pardata.Dataset.load`).
+
+    :param url_or_path: The URL or path of the dataset archive.
+    :param schema: The schema used for loading the dataset. If ``None``, it is set to a default schema that is
+        designed to accommodate most common use cases.
+    :param force_redownload: Whether to force redownloading the dataset.
+    :return: A dictionary that holds the dataset. It is structured the same as the return value of :func:`load_dataset`.
+    """
+
+    if not is_url(str(url_or_path)):
+        url_or_path = os.path.abspath(url_or_path)  # Don't use pathlib.Path.resolve because it resolves symlinks
+    url_or_path = cast(str, url_or_path)
+
+    # Name of the data dir: {url_or_path with non-alphanums replaced by dashes}-sha512. The sha512 suffix is there to
+    # prevent collisions.
+    data_dir_name = (f'{re.sub("[^0-9a-zA-Z]+", "-", url_or_path)}-'
+                     f'{hashlib.sha512(url_or_path.encode("utf-8")).hexdigest()}')
+    data_dir = get_config().DATADIR / '_location_direct' / data_dir_name
+    if schema is None:
+        # Construct the default schema
+        schema = {
+            'name': 'Direct from a location',
+            'description': 'Loaded directly from a location',
+            'subdatasets': {
+            }
+        }
+
+        RegexFormatPair = namedtuple('RegexFormatPair', ['regex', 'format'])
+        regex_format_pairs = (
+            RegexFormatPair(regex=r'.*\.csv', format='table/csv'),
+            RegexFormatPair(regex=r'.*\.wav', format='audio/wav'),
+            RegexFormatPair(regex=r'.*\.(txt|log)', format='text/plain'),
+            RegexFormatPair(regex=r'.*\.(jpg|jpeg)', format='image/jpeg'),
+            RegexFormatPair(regex=r'.*\.png', format='image/png'),
+        )
+
+        for regex_format_pair in regex_format_pairs:
+            schema['subdatasets'][regex_format_pair.format] = {
+                'format': {
+                    'id': regex_format_pair.format,
+                },
+                'path': {
+                    'type': 'regex',
+                    'value': regex_format_pair.regex
+                }
+            }
+    schema['download_url'] = url_or_path
+
+    dataset = Dataset(schema=schema, data_dir=data_dir, mode=Dataset.InitializationMode.LAZY)
+    if force_redownload or not dataset.is_downloaded():
+        dataset.download(check=False,  # Already checked by the `is_downloaded` call above
+                         verify_checksum=False)
+    dataset.load()
+
+    # Strip empty values
+    return {k: v for k, v in dataset.data.items() if len(v) > 0}
+
+
 @_handle_name_param
 @_handle_version_param
 def get_dataset_metadata(name: str, *, version: str = 'latest') -> SchemaDict:
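
For reference, the cache-directory naming scheme above can be reproduced standalone with only the standard library. This is an illustrative sketch; cache_dir_name is a hypothetical helper, not part of pardata:

    import hashlib
    import re

    def cache_dir_name(url_or_path: str) -> str:
        """Mirror the naming scheme in load_dataset_from_location: runs of
        non-alphanumeric characters become dashes, and the SHA-512 hex digest
        of the full string is appended, so two locations whose sanitized
        prefixes coincide still map to distinct directories."""
        sanitized = re.sub('[^0-9a-zA-Z]+', '-', url_or_path)
        digest = hashlib.sha512(url_or_path.encode('utf-8')).hexdigest()
        return f'{sanitized}-{digest}'

    # 'https://example.com/data.zip' and 'https:%%example.com%data.zip' both
    # sanitize to 'https-example-com-data-zip'; only the digest tells them apart.
    print(cache_dir_name('https://example.com/data.zip'))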

tests/test_high_level.py (+31 -1)

@@ -24,7 +24,7 @@
 from pydantic import ValidationError

 from pardata import (describe_dataset, export_schema_collections, get_config, get_dataset_metadata, init,
-                     list_all_datasets, load_dataset, load_schema_collections)
+                     list_all_datasets, load_dataset, load_dataset_from_location, load_schema_collections)
 from pardata.dataset import Dataset
 from pardata._config import Config
 from pardata._high_level import _get_schema_collections
@@ -204,6 +204,36 @@ def test_loading_undownloaded(self, tmp_path):
             '(by calling this function with `download=True` for at least once)?') in str(e.value)


+class TestLoadDatasetFromLocation:
+    "Test ``load_dataset_from_location``."
+
+    def test_loading_dataset_from_path(self, downloaded_gmb_dataset, dataset_dir):
+        for force_redownload in (False, False, True):
+            data = load_dataset_from_location(dataset_dir / 'gmb-1.0.2.zip', force_redownload=force_redownload)
+            assert frozenset(data.keys()) == frozenset(('text/plain',))
+            assert frozenset(data['text/plain'].keys()) == frozenset((
+                'groningen_meaning_bank_modified/gmb_subset_full.txt',
+                'groningen_meaning_bank_modified/LICENSE.txt',
+                'groningen_meaning_bank_modified/README.txt'
+            ))
+
+    def test_loading_dataset_from_url(self, gmb_schema):
+        for force_redownload in (False, False, True):
+            data = load_dataset_from_location(gmb_schema['download_url'], force_redownload=force_redownload)
+            assert frozenset(data.keys()) == frozenset(('text/plain',))
+            assert frozenset(data['text/plain'].keys()) == frozenset((
+                'groningen_meaning_bank_modified/gmb_subset_full.txt',
+                'groningen_meaning_bank_modified/LICENSE.txt',
+                'groningen_meaning_bank_modified/README.txt'
+            ))
+
+    def test_custom_schema(self, gmb_schema):
+        data = load_dataset_from_location(gmb_schema['download_url'], schema=gmb_schema)
+        assert frozenset(data.keys()) == frozenset(('gmb_subset_full',))
+        assert data['gmb_subset_full'].startswith('Masked VBN O\n')
+        assert data['gmb_subset_full'].endswith('. . O\n\n')
+
+
 def test_get_dataset_metadata():
     "Test ``get_dataset_metadata``."
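
A sketch of the custom-schema path exercised by test_custom_schema above, with a hypothetical URL and subdataset name; only declared subdatasets are extracted, and load_dataset_from_location fills in schema['download_url'] itself:

    from pardata import load_dataset_from_location

    # 'notes' and the URL are placeholders; the nested structure mirrors the
    # default schema constructed in pardata/_high_level.py above.
    custom_schema = {
        'name': 'my-dataset',
        'description': 'Loaded with a hand-written schema',
        'subdatasets': {
            'notes': {
                'format': {'id': 'text/plain'},
                'path': {'type': 'regex', 'value': r'.*\.txt'},
            },
        },
    }

    data = load_dataset_from_location('https://example.com/my-dataset.zip',
                                      schema=custom_schema)
    assert set(data.keys()) <= {'notes'}  # empty subdatasets are stripped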
