Skip to content

refactor: first iteration of fixing mypy complaints #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
9 changes: 5 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ jobs:
if: always()
run: |
pixi run --environment dev lint --diff
# - name: Mypy
# if: always()
# run: |
# pixi run --environment dev type-check

- name: Mypy
if: always()
run: |
pixi run --environment dev type-check

- name: Collect QC
run: echo "All quality control checks passed"
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ platforms = ["osx-arm64", "linux-64"]

[tool.pixi.pypi-dependencies]


[tool.pixi.environments]
dev = { features = ["dev"] }
publish = { features = ["publish"] }
Expand All @@ -52,12 +51,15 @@ lint.ignore = ["E721"]
disallow_untyped_defs = true
warn_no_return = true


[[tool.mypy.overrides]]
# TODO:: figure out expected types for the TestRegistryBase class
module = "snakemake_interface_common.plugin_registry.tests"
ignore_errors = true

[[tool.mypy.overrides]]
module = "argparse_dataclass"
ignore_missing_imports = true

[tool.pixi.feature.dev.tasks.test]
cmd = [
"pytest",
Expand Down
17 changes: 10 additions & 7 deletions src/snakemake_interface_common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
__email__ = "[email protected]"
__license__ = "MIT"

from pathlib import Path
import sys
import textwrap
from typing import Optional
from typing import Optional, Any

from snakemake_interface_common.rules import RuleInterface

Expand All @@ -16,7 +15,11 @@ class ApiError(Exception):


class WorkflowError(Exception):
def format_arg(self, arg):
lineno: Optional[int]
snakefile: Optional[str]
rule: Optional[RuleInterface]

def format_arg(self, arg: object) -> str:
if isinstance(arg, str):
return arg
elif isinstance(arg, WorkflowError):
Expand All @@ -38,9 +41,9 @@ def format_arg(self, arg):

def __init__(
self,
*args,
*args: Any,
lineno: Optional[int] = None,
snakefile: Optional[Path] = None,
snakefile: Optional[str] = None,
rule: Optional[RuleInterface] = None,
):
if rule is not None:
Expand All @@ -55,12 +58,12 @@ def __init__(
if args and isinstance(args[0], str):
spec = self._get_spec(self)
if spec:
args = [f"{args[0]} ({spec})"] + list(args[1:])
args = tuple([f"{args[0]} ({spec})"] + list(args[1:]))

super().__init__("\n".join(self.format_arg(arg) for arg in args))

@classmethod
def _get_spec(cls, exc):
def _get_spec(cls, exc: "WorkflowError") -> str:
spec = ""
if exc.rule is not None:
spec += f"rule {exc.rule.name}"
Expand Down
8 changes: 6 additions & 2 deletions src/snakemake_interface_common/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
__copyright__ = "Copyright 2023, Johannes Köster"
__email__ = "[email protected]"
__license__ = "MIT"
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import logging

def get_logger():

def get_logger() -> "logging.Logger":
"""Retrieve the logger singleton from snakemake."""
from snakemake.logging import logger
from snakemake.logging import logger # type: ignore

return logger
26 changes: 17 additions & 9 deletions src/snakemake_interface_common/plugin_registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,32 @@
import types
import pkgutil
import importlib
from typing import List, Mapping
from typing import List, Mapping, TYPE_CHECKING, TypeVar, Generic

from snakemake_interface_common.exceptions import InvalidPluginException
from snakemake_interface_common.plugin_registry.plugin import PluginBase
from snakemake_interface_common.plugin_registry.attribute_types import AttributeType

if TYPE_CHECKING:
from argparse import ArgumentParser

class PluginRegistryBase(ABC):
TPlugin = TypeVar("TPlugin", bound=PluginBase, covariant=True)


class PluginRegistryBase(ABC, Generic[TPlugin]):
"""This class is a singleton that holds all registered executor plugins."""

_instance = None
plugins: dict[str, TPlugin]

def __new__(cls):
def __new__(
cls: type["PluginRegistryBase[TPlugin]"],
) -> "PluginRegistryBase[TPlugin]":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __init__(self):
def __init__(self) -> None:
if hasattr(self, "plugins"):
# init has been called before
return
Expand Down Expand Up @@ -53,13 +61,13 @@ def get_plugin_package_name(self, plugin_name: str) -> str:
"""Get the package name of a plugin by name."""
return f"{self.module_prefix.replace('_', '-')}{plugin_name}"

def register_cli_args(self, argparser):
def register_cli_args(self, argparser: "ArgumentParser") -> None:
"""Add arguments derived from self.executor_settings to given
argparser."""
for _, plugin in self.plugins.items():
plugin.register_cli_args(argparser)

def collect_plugins(self):
def collect_plugins(self) -> None:
"""Collect plugins and call register_plugin for each."""
self.plugins = dict()

Expand All @@ -77,7 +85,7 @@ def collect_plugins(self):
module = importlib.import_module(moduleinfo.name)
self.register_plugin(moduleinfo.name, module)

def register_plugin(self, name: str, plugin: types.ModuleType):
def register_plugin(self, name: str, plugin: types.ModuleType) -> None:
"""Validate and register a plugin.

Does nothing if the plugin is already registered.
Expand All @@ -92,7 +100,7 @@ def register_plugin(self, name: str, plugin: types.ModuleType):

self.plugins[plugin_name] = self.load_plugin(plugin_name, plugin)

def validate_plugin(self, name: str, module: types.ModuleType):
def validate_plugin(self, name: str, module: types.ModuleType) -> None:
"""Validate a plugin for attributes and naming"""
expected_attributes = self.expected_attributes()
for attr, attr_type in expected_attributes.items():
Expand Down Expand Up @@ -123,7 +131,7 @@ def validate_plugin(self, name: str, module: types.ModuleType):
def module_prefix(self) -> str: ...

@abstractmethod
def load_plugin(self, name: str, module: types.ModuleType) -> PluginBase:
def load_plugin(self, name: str, module: types.ModuleType) -> TPlugin:
"""Load a plugin by name."""
...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ class AttributeType:
kind: AttributeKind = AttributeKind.OBJECT

@property
def is_optional(self):
def is_optional(self) -> bool:
return self.mode == AttributeMode.OPTIONAL

@property
def is_class(self):
def is_class(self) -> bool:
return self.kind == AttributeKind.CLASS

def into_required(self):
def into_required(self) -> "AttributeType":
return AttributeType(cls=self.cls, mode=AttributeMode.REQUIRED, kind=self.kind)
Loading