Skip to content

fix(core): mypy #810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions core/testcontainers/compose/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# flake8: noqa: F401
from testcontainers.compose.compose import (
ComposeContainer,
ContainerIsNotRunning,
DockerCompose,
NoSuchPortExposed,
PublishedPort,
PublishedPortModel,
)
from testcontainers.core.exceptions import ContainerIsNotRunning, NoSuchPortExposed

__all__ = [
"ComposeContainer",
"ContainerIsNotRunning",
"DockerCompose",
"NoSuchPortExposed",
"PublishedPortModel",
]
70 changes: 41 additions & 29 deletions core/testcontainers/compose/compose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import asdict, dataclass, field, fields
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from functools import cached_property
from json import loads
from logging import warning
Expand All @@ -7,6 +7,7 @@
from re import split
from subprocess import CompletedProcess
from subprocess import run as subprocess_run
from types import TracebackType
from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast
from urllib.error import HTTPError, URLError
from urllib.request import urlopen
Expand All @@ -18,35 +19,37 @@
_WARNINGS = {"DOCKER_COMPOSE_GET_CONFIG": "get_config is experimental, see testcontainers/testcontainers-python#669"}


def _ignore_properties(cls: type[_IPT], dict_: any) -> _IPT:
def _ignore_properties(cls: type[_IPT], dict_: Any) -> _IPT:
"""omits extra fields like @JsonIgnoreProperties(ignoreUnknown = true)

https://gist.github.com/alexanderankin/2a4549ac03554a31bef6eaaf2eaf7fd5"""
if isinstance(dict_, cls):
return dict_
if not is_dataclass(cls):
raise TypeError(f"Expected a dataclass type, got {cls}")
class_fields = {f.name for f in fields(cls)}
filtered = {k: v for k, v in dict_.items() if k in class_fields}
return cls(**filtered)
return cast("_IPT", cls(**filtered))


@dataclass
class PublishedPort:
class PublishedPortModel:
"""
Class that represents the response we get from compose when inquiring status
via `DockerCompose.get_running_containers()`.
"""

URL: Optional[str] = None
TargetPort: Optional[str] = None
PublishedPort: Optional[str] = None
TargetPort: Optional[int] = None
PublishedPort: Optional[int] = None
Protocol: Optional[str] = None

def normalize(self):
def normalize(self) -> "PublishedPortModel":
url_not_usable = system() == "Windows" and self.URL == "0.0.0.0"
if url_not_usable:
self_dict = asdict(self)
self_dict.update({"URL": "127.0.0.1"})
return PublishedPort(**self_dict)
return PublishedPortModel(**self_dict)
return self


Expand Down Expand Up @@ -75,19 +78,19 @@ class ComposeContainer:
Service: Optional[str] = None
State: Optional[str] = None
Health: Optional[str] = None
ExitCode: Optional[str] = None
Publishers: list[PublishedPort] = field(default_factory=list)
ExitCode: Optional[int] = None
Publishers: list[PublishedPortModel] = field(default_factory=list)

def __post_init__(self):
def __post_init__(self) -> None:
if self.Publishers:
self.Publishers = [_ignore_properties(PublishedPort, p) for p in self.Publishers]
self.Publishers = [_ignore_properties(PublishedPortModel, p) for p in self.Publishers]

def get_publisher(
self,
by_port: Optional[int] = None,
by_host: Optional[str] = None,
prefer_ip_version: Literal["IPV4", "IPv6"] = "IPv4",
) -> PublishedPort:
prefer_ip_version: Literal["IPv4", "IPv6"] = "IPv4",
) -> PublishedPortModel:
remaining_publishers = self.Publishers

remaining_publishers = [r for r in remaining_publishers if self._matches_protocol(prefer_ip_version, r)]
Expand All @@ -109,8 +112,9 @@ def get_publisher(
)

@staticmethod
def _matches_protocol(prefer_ip_version, r):
return (":" in r.URL) is (prefer_ip_version == "IPv6")
def _matches_protocol(prefer_ip_version: str, r: PublishedPortModel) -> bool:
r_url = r.URL
return (r_url is not None and ":" in r_url) is (prefer_ip_version == "IPv6")


@dataclass
Expand Down Expand Up @@ -164,7 +168,7 @@ class DockerCompose:
image: "hello-world"
"""

context: Union[str, PathLike]
context: Union[str, PathLike[str]]
compose_file_name: Optional[Union[str, list[str]]] = None
pull: bool = False
build: bool = False
Expand All @@ -175,15 +179,17 @@ class DockerCompose:
docker_command_path: Optional[str] = None
profiles: Optional[list[str]] = None

def __post_init__(self):
def __post_init__(self) -> None:
if isinstance(self.compose_file_name, str):
self.compose_file_name = [self.compose_file_name]

def __enter__(self) -> "DockerCompose":
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
def __exit__(
self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> None:
self.stop(not self.keep_volumes)

def docker_compose_command(self) -> list[str]:
Expand Down Expand Up @@ -235,7 +241,7 @@ def start(self) -> None:

self._run_command(cmd=up_cmd)

def stop(self, down=True) -> None:
def stop(self, down: bool = True) -> None:
"""
Stops the docker compose environment.
"""
Expand Down Expand Up @@ -295,7 +301,7 @@ def get_config(
cmd_output = self._run_command(cmd=config_cmd).stdout
return cast(dict[str, Any], loads(cmd_output)) # noqa: TC006

def get_containers(self, include_all=False) -> list[ComposeContainer]:
def get_containers(self, include_all: bool = False) -> list[ComposeContainer]:
"""
Fetch information about running containers via `docker compose ps --format json`.
Available only in V2 of compose.
Expand Down Expand Up @@ -370,17 +376,18 @@ def exec_in_container(
"""
if not service_name:
service_name = self.get_container().Service
exec_cmd = [*self.compose_command_property, "exec", "-T", service_name, *command]
assert service_name
exec_cmd: list[str] = [*self.compose_command_property, "exec", "-T", service_name, *command]
result = self._run_command(cmd=exec_cmd)

return (result.stdout.decode("utf-8"), result.stderr.decode("utf-8"), result.returncode)
return result.stdout.decode("utf-8"), result.stderr.decode("utf-8"), result.returncode

def _run_command(
self,
cmd: Union[str, list[str]],
context: Optional[str] = None,
) -> CompletedProcess[bytes]:
context = context or self.context
context = context or str(self.context)
return subprocess_run(
cmd,
capture_output=True,
Expand All @@ -392,7 +399,7 @@ def get_service_port(
self,
service_name: Optional[str] = None,
port: Optional[int] = None,
):
) -> Optional[int]:
"""
Returns the mapped port for one of the services.

Expand All @@ -408,13 +415,14 @@ def get_service_port(
str:
The mapped port on the host
"""
return self.get_container(service_name).get_publisher(by_port=port).normalize().PublishedPort
normalize: PublishedPortModel = self.get_container(service_name).get_publisher(by_port=port).normalize()
return normalize.PublishedPort

def get_service_host(
self,
service_name: Optional[str] = None,
port: Optional[int] = None,
):
) -> Optional[str]:
"""
Returns the host for one of the services.

Expand All @@ -430,13 +438,17 @@ def get_service_host(
str:
The hostname for the service
"""
return self.get_container(service_name).get_publisher(by_port=port).normalize().URL
container: ComposeContainer = self.get_container(service_name)
publisher: PublishedPortModel = container.get_publisher(by_port=port)
normalize: PublishedPortModel = publisher.normalize()
url: Optional[str] = normalize.URL
return url

def get_service_host_and_port(
self,
service_name: Optional[str] = None,
port: Optional[int] = None,
):
) -> tuple[Optional[str], Optional[int]]:
publisher = self.get_container(service_name).get_publisher(by_port=port).normalize()
return publisher.URL, publisher.PublishedPort

Expand Down
4 changes: 3 additions & 1 deletion core/testcontainers/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from os import environ
from os.path import exists
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Union, cast

import docker

Expand Down Expand Up @@ -36,6 +36,7 @@ def get_docker_socket() -> str:
client = docker.from_env()
try:
socket_path = client.api.get_adapter(client.api.base_url).socket_path
socket_path = cast("str", socket_path)
# return the normalized path as string
return str(Path(socket_path).absolute())
except AttributeError:
Expand Down Expand Up @@ -145,5 +146,6 @@ def timeout(self) -> int:
"SLEEP_TIME",
"TIMEOUT",
# Public API of this module:
"ConnectionMode",
"testcontainers_config",
]
Loading