diff --git a/src/tld/tests/test_core.py b/src/tld/tests/test_core.py
index d416e36..24c3b61 100644
--- a/src/tld/tests/test_core.py
+++ b/src/tld/tests/test_core.py
@@ -306,6 +306,15 @@ def setUp(self):
                 "tld": "com",
                 "kwargs": {"fail_silently": True},
             },
+            {
+                "url": "sftp://sftp.test.com",
+                "fld": "test.com",
+                "subdomain": "sftp",
+                "domain": "test",
+                "suffix": "com",
+                "tld": "com",
+                "kwargs": {"fail_silently": True, "fix_protocol": True},
+            },
         ]
 
         self.bad_patterns = {
diff --git a/src/tld/utils.py b/src/tld/utils.py
index e49794e..f7f1aea 100644
--- a/src/tld/utils.py
+++ b/src/tld/utils.py
@@ -1,6 +1,7 @@
 from __future__ import unicode_literals
 
 import argparse
+import re
 import sys
 from codecs import open as codecs_open
 from functools import lru_cache
@@ -46,6 +47,9 @@
 )
 
 tld_names: Dict[str, Trie] = {}
+# Matches a leading URL scheme ("sftp://", "HTTP://") or a bare
+# protocol-relative "//" prefix; schemes are case-insensitive (RFC 3986).
+protocol_re = re.compile(r"^(?:[a-z][a-z0-9+.-]*:)?//", re.IGNORECASE)
 
 
 def get_tld_names_container() -> Dict[str, Trie]:
@@ -306,7 +310,7 @@ def process_url(
         )
 
     if not isinstance(url, SplitResult):
-        if fix_protocol and not url.startswith(("//", "http://", "https://")):
+        if fix_protocol and not protocol_re.match(url):
             url = f"https://{url}"
 
         # Get parsed URL as we might need it later