diff --git a/.github/workflows/cibuild.yml b/.github/workflows/cibuild.yml index 0e44c4b..9b81e30 100644 --- a/.github/workflows/cibuild.yml +++ b/.github/workflows/cibuild.yml @@ -33,6 +33,15 @@ jobs: - run: pip install tox - run: tox -e flake8 + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - uses: actions/setup-python@v2 + - run: python -m pip install --upgrade pip + - run: pip install tox + - run: tox -e mypy + unit_tests: strategy: fail-fast: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2708aa6..e596b1e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,9 @@ repos: args: [--config=.flake8] language: system files: \.py$ + - id: mypy + name: mypy + entry: mypy + stages: [commit] + language: system + files: \.py$ diff --git a/pyproject.toml b/pyproject.toml index 15a19d1..1c56279 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,3 +18,8 @@ line_length = 119 multi_line_output = 3 use_parentheses = true include_trailing_comma = true + + +[tool.mypy] +exclude = "thirdparty" + diff --git a/seleniumwire/__main__.py b/seleniumwire/__main__.py index 82c17e0..9144442 100644 --- a/seleniumwire/__main__.py +++ b/seleniumwire/__main__.py @@ -2,6 +2,7 @@ import logging import signal from argparse import RawDescriptionHelpFormatter +from typing import Callable, Dict from seleniumwire import backend, utils @@ -23,10 +24,13 @@ def standalone_proxy(port=0, addr='127.0.0.1'): signal.signal(signal.SIGINT, lambda *_: b.shutdown()) +# Mapping of command names to the command callables +COMMANDS: Dict[str, Callable] = {'extractcert': utils.extract_cert, 'standaloneproxy': standalone_proxy} + + if __name__ == '__main__': - commands = {'extractcert': utils.extract_cert, 'standaloneproxy': standalone_proxy} parser = argparse.ArgumentParser( - description='\n\nsupported commands: \n %s' % '\n '.join(sorted(commands)), + description='\n\nsupported commands: \n %s' % '\n '.join(sorted(COMMANDS)), formatter_class=RawDescriptionHelpFormatter, usage='python -m seleniumwire ', ) @@ -40,10 +44,10 @@ def standalone_proxy(port=0, addr='127.0.0.1'): args = parser.parse_args() pargs = [arg for arg in args.args if '=' not in arg and arg is not args.command] - kwargs = dict([tuple(arg.split('=')) for arg in args.args if '=' in arg]) + kwargs: Dict[str, str] = dict([arg.split('=') for arg in args.args if '=' in arg]) try: - commands[args.command](*pargs, **kwargs) + COMMANDS[args.command](*pargs, **kwargs) except KeyError: print("Unsupported command '{}' (use --help for list of commands)".format(args.command)) except TypeError as e: diff --git a/seleniumwire/inspect.py b/seleniumwire/inspect.py index 7e03f89..6dc3dcc 100644 --- a/seleniumwire/inspect.py +++ b/seleniumwire/inspect.py @@ -1,8 +1,8 @@ import inspect import time -from typing import Iterator, List, Optional, Union +from typing import Callable, Iterator, List, Optional, Union -from selenium.common.exceptions import TimeoutException +from selenium.common.exceptions import TimeoutException # type: ignore from seleniumwire import har from seleniumwire.request import Request @@ -23,18 +23,18 @@ def requests(self) -> List[Request]: A list of Request instances representing the requests made between the browser and server. """ - return self.backend.storage.load_requests() + return self.backend.storage.load_requests() # type: ignore @requests.deleter def requests(self): - self.backend.storage.clear_requests() + self.backend.storage.clear_requests() # type: ignore def iter_requests(self) -> Iterator[Request]: """Return an iterator of requests. Returns: An iterator. """ - yield from self.backend.storage.iter_requests() + yield from self.backend.storage.iter_requests() # type: ignore @property def last_request(self) -> Optional[Request]: @@ -46,7 +46,7 @@ def last_request(self) -> Optional[Request]: A Request instance representing the last request made, or None if no requests have been made. """ - return self.backend.storage.load_last_request() + return self.backend.storage.load_last_request() # type: ignore def wait_for_request(self, pat: str, timeout: Union[int, float] = 10) -> Request: """Wait up to the timeout period for a request matching the specified @@ -73,7 +73,7 @@ def wait_for_request(self, pat: str, timeout: Union[int, float] = 10) -> Request start = time.time() while time.time() - start < timeout: - request = self.backend.storage.find(pat) + request = self.backend.storage.find(pat) # type: ignore if request is None: time.sleep(1 / 5) @@ -91,7 +91,7 @@ def har(self) -> str: Returns: A JSON string of HAR data. """ - return har.generate_har(self.backend.storage.load_har_entries()) + return har.generate_har(self.backend.storage.load_har_entries()) # type: ignore @property def header_overrides(self): @@ -119,7 +119,7 @@ def header_overrides(self): ('*.somewhere-else.com.*', {'User-Agent': 'Chrome'}) ] """ - return self.backend.modifier.headers + return self.backend.modifier.headers # type: ignore @header_overrides.setter def header_overrides(self, headers): @@ -129,16 +129,16 @@ def header_overrides(self, headers): else: self._validate_headers(headers) - self.backend.modifier.headers = headers + self.backend.modifier.headers = headers # type: ignore def _validate_headers(self, headers): for v in headers.values(): if v is not None: assert isinstance(v, str), 'Header values must be strings' - @header_overrides.deleter + @header_overrides.deleter # type: ignore def header_overrides(self): - del self.backend.modifier.headers + del self.backend.modifier.headers # type: ignore @property def param_overrides(self): @@ -164,15 +164,15 @@ def param_overrides(self): ('*.somewhere-else.com.*', {'x': 'y'}), ] """ - return self.backend.modifier.params + return self.backend.modifier.params # type: ignore @param_overrides.setter def param_overrides(self, params): - self.backend.modifier.params = params + self.backend.modifier.params = params # type: ignore @param_overrides.deleter def param_overrides(self): - del self.backend.modifier.params + del self.backend.modifier.params # type: ignore @property def body_overrides(self): @@ -194,15 +194,15 @@ def body_overrides(self): ('*.somewhere-else.com.*', '{"x":"y"}'), ] """ - return self.backend.modifier.bodies + return self.backend.modifier.bodies # type: ignore @body_overrides.setter def body_overrides(self, bodies): - self.backend.modifier.bodies = bodies + self.backend.modifier.bodies = bodies # type: ignore @body_overrides.deleter def body_overrides(self): - del self.backend.modifier.bodies + del self.backend.modifier.bodies # type: ignore @property def querystring_overrides(self): @@ -223,15 +223,15 @@ def querystring_overrides(self): ('*.somewhere-else.com.*', 'a=b&c=d'), ] """ - return self.backend.modifier.querystring + return self.backend.modifier.querystring # type: ignore @querystring_overrides.setter def querystring_overrides(self, querystrings): - self.backend.modifier.querystring = querystrings + self.backend.modifier.querystring = querystrings # type: ignore @querystring_overrides.deleter def querystring_overrides(self): - del self.backend.modifier.querystring + del self.backend.modifier.querystring # type: ignore @property def rewrite_rules(self): @@ -248,15 +248,15 @@ def rewrite_rules(self): (r'https://docs.python.org/2/', r'https://docs.python.org/3/'), ] """ - return self.backend.modifier.rewrite_rules + return self.backend.modifier.rewrite_rules # type: ignore @rewrite_rules.setter def rewrite_rules(self, rewrite_rules): - self.backend.modifier.rewrite_rules = rewrite_rules + self.backend.modifier.rewrite_rules = rewrite_rules # type: ignore @rewrite_rules.deleter def rewrite_rules(self): - del self.backend.modifier.rewrite_rules + del self.backend.modifier.rewrite_rules # type: ignore @property def scopes(self) -> List[str]: @@ -271,48 +271,48 @@ def scopes(self) -> List[str]: '.*github.*' ] """ - return self.backend.scopes + return self.backend.scopes # type: ignore @scopes.setter def scopes(self, scopes: List[str]): - self.backend.scopes = scopes + self.backend.scopes = scopes # type: ignore @scopes.deleter def scopes(self): - self.backend.scopes = [] + self.backend.scopes = [] # type: ignore @property - def request_interceptor(self) -> callable: + def request_interceptor(self) -> Callable: """A callable that will be used to intercept/modify requests. The callable must accept a single argument for the request being intercepted. """ - return self.backend.request_interceptor + return self.backend.request_interceptor # type: ignore @request_interceptor.setter - def request_interceptor(self, interceptor: callable): - self.backend.request_interceptor = interceptor + def request_interceptor(self, interceptor: Callable): + self.backend.request_interceptor = interceptor # type: ignore @request_interceptor.deleter def request_interceptor(self): - self.backend.request_interceptor = None + self.backend.request_interceptor = None # type: ignore @property - def response_interceptor(self) -> callable: + def response_interceptor(self) -> Callable: """A callable that will be used to intercept/modify responses. The callable must accept two arguments: the response being intercepted and the originating request. """ - return self.backend.response_interceptor + return self.backend.response_interceptor # type: ignore @response_interceptor.setter - def response_interceptor(self, interceptor: callable): + def response_interceptor(self, interceptor: Callable): if len(inspect.signature(interceptor).parameters) != 2: raise RuntimeError('A response interceptor takes two parameters: the request and response') - self.backend.response_interceptor = interceptor + self.backend.response_interceptor = interceptor # type: ignore @response_interceptor.deleter def response_interceptor(self): - self.backend.response_interceptor = None + self.backend.response_interceptor = None # type: ignore diff --git a/seleniumwire/request.py b/seleniumwire/request.py index 7d640a7..84a8510 100644 --- a/seleniumwire/request.py +++ b/seleniumwire/request.py @@ -19,6 +19,8 @@ def __repr__(self): class Request: """Represents an HTTP request.""" + _body: bytes + def __init__(self, *, method: str, url: str, headers: Iterable[Tuple[str, str]], body: bytes = b''): """Initialise a new Request object. @@ -119,7 +121,7 @@ def host(self) -> str: """ return urlsplit(self.url).netloc - @path.setter + @path.setter # type: ignore def path(self, p: str): parts = list(urlsplit(self.url)) parts[2] = p diff --git a/seleniumwire/undetected_chromedriver/__init__.py b/seleniumwire/undetected_chromedriver/__init__.py index a470664..dce2efd 100644 --- a/seleniumwire/undetected_chromedriver/__init__.py +++ b/seleniumwire/undetected_chromedriver/__init__.py @@ -1,5 +1,5 @@ try: - import undetected_chromedriver as uc + import undetected_chromedriver as uc # type: ignore except ImportError as e: raise ImportError( 'undetected_chromedriver not found. ' 'Install it with `pip install undetected_chromedriver`.' @@ -8,5 +8,5 @@ from seleniumwire.webdriver import Chrome uc._Chrome = Chrome -Chrome = uc.Chrome +Chrome = uc.Chrome # type: ignore ChromeOptions = uc.ChromeOptions # noqa: F811 diff --git a/seleniumwire/undetected_chromedriver/v2.py b/seleniumwire/undetected_chromedriver/v2.py index 24bfdaa..6604a5d 100644 --- a/seleniumwire/undetected_chromedriver/v2.py +++ b/seleniumwire/undetected_chromedriver/v2.py @@ -1,7 +1,7 @@ import logging -import undetected_chromedriver.v2 as uc -from selenium.webdriver import DesiredCapabilities +import undetected_chromedriver.v2 as uc # type: ignore +from selenium.webdriver import DesiredCapabilities # type: ignore from seleniumwire import backend from seleniumwire.inspect import InspectRequestsMixin diff --git a/seleniumwire/utils.py b/seleniumwire/utils.py index aa091b7..ac3ec44 100644 --- a/seleniumwire/utils.py +++ b/seleniumwire/utils.py @@ -2,10 +2,9 @@ import logging import os import pkgutil -from collections import namedtuple from pathlib import Path -from typing import Dict, NamedTuple -from urllib.request import _parse_proxy +from typing import Dict, List, Sequence, Union +from urllib.request import _parse_proxy # type: ignore from seleniumwire.thirdparty.mitmproxy.net.http import encoding as decoder @@ -16,7 +15,7 @@ COMBINED_CERT = 'seleniumwire-ca.pem' -def get_upstream_proxy(options): +def get_upstream_proxy(options: Dict[str, Union[str, Dict]]) -> Dict[str, str]: """Get the upstream proxy configuration from the options dictionary. This will be overridden with any configuration found in the environment variables HTTP_PROXY, HTTPS_PROXY, NO_PROXY @@ -34,13 +33,16 @@ def get_upstream_proxy(options): options: The selenium wire options. Returns: A dictionary. """ - proxy_options = (options or {}).pop('proxy', {}) + try: + proxy_options: Dict[str, str] = options['proxy'] # type: ignore + except KeyError: + proxy_options = {} http_proxy = os.environ.get('HTTP_PROXY') https_proxy = os.environ.get('HTTPS_PROXY') no_proxy = os.environ.get('NO_PROXY') - merged = {} + merged: Dict[str, str] = {} if http_proxy: merged['http'] = http_proxy @@ -51,22 +53,10 @@ def get_upstream_proxy(options): merged.update(proxy_options) - no_proxy = merged.get('no_proxy') - if isinstance(no_proxy, str): - merged['no_proxy'] = [h.strip() for h in no_proxy.split(',')] - - conf = namedtuple('ProxyConf', 'scheme username password hostport') - - for proxy_type in ('http', 'https'): - # Parse the upstream proxy URL into (scheme, username, password, hostport) - # for ease of access. - if merged.get(proxy_type) is not None: - merged[proxy_type] = conf(*_parse_proxy(merged[proxy_type])) - return merged -def build_proxy_args(proxy_config: Dict[str, NamedTuple]) -> Dict[str, str]: +def build_proxy_args(proxy_config: Dict[str, str]) -> Dict[str, Union[str, List[str]]]: """Build the arguments needed to pass an upstream proxy to mitmproxy. Args: @@ -77,21 +67,15 @@ def build_proxy_args(proxy_config: Dict[str, NamedTuple]) -> Dict[str, str]: https_proxy = proxy_config.get('https') conf = None - if http_proxy and https_proxy: - if http_proxy.hostport != https_proxy.hostport: # noqa - # We only support a single upstream proxy server - raise ValueError('Different settings for http and https proxy servers not supported') - + if https_proxy: conf = https_proxy elif http_proxy: conf = http_proxy - elif https_proxy: - conf = https_proxy - args = {} + args: Dict[str, Union[str, List[str]]] = {} if conf: - scheme, username, password, hostport = conf + scheme, username, password, hostport = _parse_proxy(conf) args['mode'] = 'upstream:{}://{}'.format(scheme, hostport) @@ -106,17 +90,17 @@ def build_proxy_args(proxy_config: Dict[str, NamedTuple]) -> Dict[str, str]: no_proxy = proxy_config.get('no_proxy') if no_proxy: - args['no_proxy'] = no_proxy + args['no_proxy'] = [h.strip() for h in no_proxy.split(',')] return args -def extract_cert(cert_name='ca.crt'): +def extract_cert(cert_name: str = 'ca.crt') -> None: """Extracts the root certificate to the current working directory.""" - try: - cert = pkgutil.get_data(__package__, cert_name) - except FileNotFoundError: + cert = pkgutil.get_data(__package__, cert_name) + + if cert is None: log.error("Invalid certificate '{}'".format(cert_name)) else: with open(Path(os.getcwd(), cert_name), 'wb') as out: @@ -124,7 +108,7 @@ def extract_cert(cert_name='ca.crt'): log.info('{} extracted. You can now import this into a browser.'.format(cert_name)) -def extract_cert_and_key(dest_folder, check_exists=True): +def extract_cert_and_key(dest_folder: Union[str, Path], check_exists: bool = True) -> None: """Extracts the root certificate and key and combines them into a single file called seleniumwire-ca.pem in the specified destination folder. @@ -143,11 +127,14 @@ def extract_cert_and_key(dest_folder, check_exists=True): root_cert = pkgutil.get_data(__package__, ROOT_CERT) root_key = pkgutil.get_data(__package__, ROOT_KEY) - with open(combined_path, 'wb') as f_out: - f_out.write(root_cert + root_key) + if root_cert is None or root_key is None: + log.error('Root certificate and/or key missing') + else: + with open(combined_path, 'wb') as f_out: + f_out.write(root_cert + root_key) -def is_list_alike(container): +def is_list_alike(container: Union[str, Sequence]) -> bool: return isinstance(container, collections.abc.Sequence) and not isinstance(container, str) @@ -168,7 +155,7 @@ def urlsafe_address(address): return addr, port -def decode(data: bytes, encoding: str) -> bytes: +def decode(data: bytes, encoding: str) -> Union[None, str, bytes]: """Attempt to decode data based on the supplied encoding. If decoding fails a ValueError is raised. diff --git a/seleniumwire/webdriver.py b/seleniumwire/webdriver.py index 77be642..b53972b 100644 --- a/seleniumwire/webdriver.py +++ b/seleniumwire/webdriver.py @@ -54,25 +54,25 @@ def proxy(self) -> Dict[str, Any]: """Get the proxy configuration for the driver.""" conf = {} - mode = getattr(self.backend.master.options, 'mode') + mode = getattr(self.backend.master.options, 'mode') # type: ignore if mode and mode.startswith('upstream'): upstream = mode.split('upstream:')[1] scheme, *rest = upstream.split('://') - auth = getattr(self.backend.master.options, 'upstream_auth') + auth = getattr(self.backend.master.options, 'upstream_auth') # type: ignore if auth: conf[scheme] = f'{scheme}://{auth}@{rest[0]}' else: conf[scheme] = f'{scheme}://{rest[0]}' - no_proxy = getattr(self.backend.master.options, 'no_proxy') + no_proxy = getattr(self.backend.master.options, 'no_proxy') # type: ignore if no_proxy: conf['no_proxy'] = ','.join(no_proxy) - custom_auth = getattr(self.backend.master.options, 'upstream_custom_auth') + custom_auth = getattr(self.backend.master.options, 'upstream_custom_auth') # type: ignore if custom_auth: conf['custom_authorization'] = custom_auth @@ -93,7 +93,9 @@ def proxy(self, proxy_conf: Dict[str, Any]): Args: proxy_conf: The proxy configuration. """ - self.backend.master.options.update(**build_proxy_args(get_upstream_proxy({'proxy': proxy_conf}))) + self.backend.master.options.update( # type: ignore + **build_proxy_args(get_upstream_proxy({'proxy': proxy_conf})) + ) class Firefox(InspectRequestsMixin, DriverCommonMixin, _Firefox): diff --git a/tests/seleniumwire/test_utils.py b/tests/seleniumwire/test_utils.py index 927466f..60aa6f1 100644 --- a/tests/seleniumwire/test_utils.py +++ b/tests/seleniumwire/test_utils.py @@ -241,9 +241,10 @@ def test_extract_cert(self, mock_pkgutil, mock_getcwd): m_open.assert_called_once_with(Path('cwd', 'ca.crt'), 'wb') m_open.return_value.write.assert_called_once_with(b'cert_data') + @patch('seleniumwire.utils.log') @patch('seleniumwire.utils.pkgutil') - def test_extract_cert_not_found(self, mock_pkgutil): - mock_pkgutil.get_data.side_effect = FileNotFoundError + def test_extract_cert_not_found(self, mock_pkgutil, mock_log): + mock_pkgutil.get_data.return_value = None m_open = mock_open() with patch('seleniumwire.utils.open', m_open): @@ -251,6 +252,7 @@ def test_extract_cert_not_found(self, mock_pkgutil): mock_pkgutil.get_data.assert_called_once_with('seleniumwire', 'foo.crt') m_open.assert_not_called() + mock_log.error.assert_called_once() @patch('seleniumwire.utils.os') @patch('seleniumwire.utils.pkgutil') @@ -291,6 +293,21 @@ def test_extract_cert_and_key_no_check(self, mock_path, mock_os): m_open.assert_called_once() + @patch('seleniumwire.utils.log') + @patch('seleniumwire.utils.os') + @patch('seleniumwire.utils.pkgutil') + @patch('seleniumwire.utils.Path') + def test_extract_cert_and_key_not_found(self, mock_path, mock_pkgutil, mock_os, mock_log): + mock_path.return_value.exists.return_value = False + mock_pkgutil.get_data.side_effect = (None, None) + m_open = mock_open() + + with patch('seleniumwire.utils.open', m_open): + extract_cert_and_key(Path('some', 'path')) + + m_open.assert_not_called() + mock_log.error.assert_called_once() + def test_urlsafe_address_ipv4(): assert urlsafe_address(('192.168.0.1', 9999)) == ('192.168.0.1', 9999) diff --git a/tox.ini b/tox.ini index 33cdef8..0e39c49 100644 --- a/tox.ini +++ b/tox.ini @@ -7,6 +7,7 @@ envlist = isort black flake8 + mypy [testenv] setenv = @@ -34,6 +35,12 @@ deps = commands = flake8 seleniumwire +[testenv:mypy] +deps = + mypy +commands = + mypy seleniumwire tests + [testenv:e2e] commands = pytest -s -vv --tb=native tests/end2end