|
15 | 15 | import json
|
16 | 16 | import logging
|
17 | 17 | import os
|
| 18 | +import re |
18 | 19 | import shutil
|
19 | 20 | import sys
|
20 | 21 | import tarfile
|
|
30 | 31 | from monai.config.type_definitions import PathLike
|
31 | 32 | from monai.utils import look_up_option, min_version, optional_import
|
32 | 33 |
|
| 34 | +requests, has_requests = optional_import("requests") |
33 | 35 | gdown, has_gdown = optional_import("gdown", "4.7.3")
|
| 36 | +BeautifulSoup, has_bs4 = optional_import("bs4", name="BeautifulSoup") |
34 | 37 |
|
35 | 38 | if TYPE_CHECKING:
|
36 | 39 | from tqdm import tqdm
|
@@ -298,6 +301,29 @@ def extractall(
|
298 | 301 | )
|
299 | 302 |
|
300 | 303 |
|
def get_filename_from_url(data_url: str) -> str:
    """
    Infer the download filename for ``data_url``.

    Resolution order:
        1. the ``Content-Disposition`` header of a redirect-following HEAD request;
        2. for Google Drive links, the filename shown on the HTML download page;
        3. fallback: the basename of the URL path.

    Args:
        data_url: the URL to inspect.

    Returns:
        the inferred filename (including any extension suffixes).

    Raises:
        Exception: wraps any network or parsing error encountered while
            resolving the filename (original error chained as ``__cause__``).
    """
    try:
        # follow redirects so the final response carries the real filename header;
        # timeout prevents an unresponsive host from hanging the caller forever
        response = requests.head(data_url, allow_redirects=True, timeout=30)
        content_disposition = response.headers.get("Content-Disposition")
        if content_disposition:
            # the header value may be quoted or bare: filename="x.zip" or filename=x.zip
            filename = re.findall(r'filename="?([^";]+)"?', content_disposition)
            if filename:
                return str(filename[0])
        if "drive.google.com" in data_url:
            response = requests.get(data_url, timeout=30)
            if "text/html" in response.headers.get("Content-Type", ""):
                soup = BeautifulSoup(response.text, "html.parser")
                # the Drive interstitial page shows the filename inside this span
                filename_div = soup.find("span", {"class": "uc-name-size"})
                if filename_div:
                    anchor = filename_div.find("a")
                    # guard: fall through to the URL basename if the page layout
                    # changed and the span no longer contains an anchor
                    if anchor is not None:
                        return str(anchor.text)
        return _basename(data_url)
    except Exception as e:
        raise Exception(f"Error processing URL: {e}") from e
| 325 | + |
| 326 | + |
301 | 327 | def download_and_extract(
|
302 | 328 | url: str,
|
303 | 329 | filepath: PathLike = "",
|
@@ -327,7 +353,18 @@ def download_and_extract(
|
327 | 353 | be False.
|
328 | 354 | progress: whether to display progress bar.
|
329 | 355 | """
|
| 356 | + url_filename_ext = "".join(Path(get_filename_from_url(url)).suffixes) |
| 357 | + filepath_ext = "".join(Path(_basename(filepath)).suffixes) |
| 358 | + if filepath not in ["", "."]: |
| 359 | + if filepath_ext == "": |
| 360 | + new_filepath = Path(filepath).with_suffix(url_filename_ext) |
| 361 | + logger.warning( |
| 362 | + f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}" |
| 363 | + ) |
| 364 | + filepath = new_filepath |
| 365 | + if filepath_ext and filepath_ext != url_filename_ext: |
| 366 | + raise ValueError(f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}") |
330 | 367 | with tempfile.TemporaryDirectory() as tmp_dir:
|
331 |
| - filename = filepath or Path(tmp_dir, _basename(url)).resolve() |
| 368 | + filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve() |
332 | 369 | download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
|
333 | 370 | extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
|
0 commit comments