|
1 | 1 | #
|
2 |
| -# Copyright 2020 IBM Corp. All Rights Reserved. |
| 2 | +# Copyright 2020--2021 IBM Corp. All Rights Reserved. |
3 | 3 | #
|
4 | 4 | # Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | # you may not use this file except in compliance with the License.
|
|
23 | 23 | from copy import deepcopy
|
24 | 24 | import dataclasses
|
25 | 25 | import functools
|
| 26 | +import hashlib |
26 | 27 | from textwrap import dedent
|
27 | 28 | from typing import Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar, Union, cast
|
| 29 | +import os |
28 | 30 | from packaging.version import parse as version_parser
|
| 31 | +import re |
29 | 32 |
|
30 | 33 | from ._config import Config
|
31 | 34 | from ._dataset import Dataset
|
32 | 35 | from . import typing as typing_
|
33 | 36 | from ._schema import (DatasetSchemaCollection, FormatSchemaCollection, LicenseSchemaCollection,
|
34 | 37 | SchemaDict, SchemaCollectionManager)
|
| 38 | +from ._schema_retrieval import is_url |
35 | 39 |
|
36 | 40 | # Global configurations --------------------------------------------------
|
37 | 41 |
|
@@ -208,6 +212,69 @@ def load_dataset(name: str, *,
|
208 | 212 | f'\nCaused by:\n{e}')
|
209 | 213 |
|
210 | 214 |
|
def load_dataset_from_location(url_or_path: Union[str, typing_.PathLike], *,
                               schema: Optional[SchemaDict] = None,
                               force_redownload: bool = False) -> Dict[str, Any]:
    """ Load the dataset from ``url_or_path``. This function is equivalent to calling :class:`~pardata.Dataset`, where
    ``schema['download_url']`` is set to ``url_or_path``. In the returned :class:`dict` object, keys corresponding to
    empty values are removed (unlike :meth:`~pardata.Dataset.load`).

    :param url_or_path: The URL or path of the dataset archive.
    :param schema: The schema used for loading the dataset. If ``None``, it is set to a default schema that is designed
        to accommodate most common use cases. A caller-supplied schema is never mutated by this function.
    :param force_redownload: Set to ``True`` to force redownloading the dataset.
    :return: A dictionary that holds the dataset. It is structured the same as the return value of
        :func:`load_dataset`.
    """

    if not is_url(str(url_or_path)):
        # Don't use pathlib.Path.resolve because it resolves symlinks
        url_or_path = os.path.abspath(url_or_path)
    url_or_path = cast(str, url_or_path)

    # Name of the data dir: {url_or_path with non-alphanums replaced by dashes}-sha512. The sha512 suffix is there to
    # prevent collision (e.g. "a/b" and "a-b" would otherwise map to the same directory).
    data_dir_name = (f'{re.sub("[^0-9a-zA-Z]+", "-", url_or_path)}-'
                     f'{hashlib.sha512(url_or_path.encode("utf-8")).hexdigest()}')
    data_dir = get_config().DATADIR / '_location_direct' / data_dir_name

    if schema is None:
        # Construct the default schema
        schema = {
            'name': 'Direct from a location',
            'description': 'Loaded directly from a location',
            'subdatasets': {
            }
        }

        # (regex, format id) pairs mapping common file extensions to format schemas. Plain tuples are used here
        # deliberately: this diff does not import collections.namedtuple, and the pairs are simple enough that
        # tuple unpacking in the loop below is clearer anyway.
        regex_format_pairs = (
            (r'.*\.csv', 'table/csv'),
            (r'.*\.wav', 'audio/wav'),
            (r'.*\.(txt|log)', 'text/plain'),
            (r'.*\.(jpg|jpeg)', 'image/jpeg'),
            (r'.*\.png', 'image/png'),
        )

        for regex, format_id in regex_format_pairs:
            schema['subdatasets'][format_id] = {
                'format': {
                    'id': format_id,
                },
                'path': {
                    'type': 'regex',
                    'value': regex
                }
            }
    else:
        # Copy so that assigning 'download_url' below never mutates the caller's dict.
        schema = deepcopy(schema)
    schema['download_url'] = url_or_path

    dataset = Dataset(schema=schema, data_dir=data_dir, mode=Dataset.InitializationMode.LAZY)
    if force_redownload or not dataset.is_downloaded():
        dataset.download(check=False,  # Already checked by `is_downloaded` call above
                         verify_checksum=False)
    dataset.load()

    # Strip subdatasets whose loaded value is empty (unlike Dataset.load, which keeps them).
    return {k: v for k, v in dataset.data.items() if len(v) > 0}
| 276 | + |
| 277 | + |
211 | 278 | @_handle_name_param
|
212 | 279 | @_handle_version_param
|
213 | 280 | def get_dataset_metadata(name: str, *, version: str = 'latest') -> SchemaDict:
|
|
0 commit comments