diff --git a/meltano/edk/extension.py b/meltano/edk/extension.py index fb5e024..68e7613 100644 --- a/meltano/edk/extension.py +++ b/meltano/edk/extension.py @@ -4,13 +4,13 @@ import sys from abc import ABCMeta, abstractmethod from enum import Enum -from typing import Any import structlog import yaml from devtools.prettier import pformat from meltano.edk import models +from meltano.edk.types import ExecArg class DescribeFormat(str, Enum): @@ -24,7 +24,7 @@ class DescribeFormat(str, Enum): class ExtensionBase(metaclass=ABCMeta): """Basic extension interface that must be implemented by all extensions.""" - def pre_invoke(self, invoke_name: str | None, *invoke_args: Any) -> None: + def pre_invoke(self, invoke_name: str | None, *invoke_args: ExecArg) -> None: """Called before the extension is invoked. Args: @@ -45,7 +45,7 @@ def initialize(self, force: bool = False) -> None: pass @abstractmethod - def invoke(self, command_name: str | None, *command_args: Any) -> None: + def invoke(self, command_name: str | None, *command_args: ExecArg) -> None: """Invoke method. This method is called when the extension is invoked. @@ -56,7 +56,7 @@ def invoke(self, command_name: str | None, *command_args: Any) -> None: """ pass - def post_invoke(self, invoked_name: str | None, *invoked_args: Any) -> None: + def post_invoke(self, invoked_name: str | None, *invoked_args: ExecArg) -> None: """Called after the extension is invoked. Args: @@ -100,7 +100,9 @@ def describe_formatted( # type: ignore[return] ) def pass_through_invoker( - self, logger: structlog.BoundLogger, *command_args: Any + self, + logger: structlog.BoundLogger, + *command_args: ExecArg, ) -> None: """Pass-through invoker. @@ -117,7 +119,7 @@ def pass_through_invoker( command_args=command_args, ) try: - self.pre_invoke(None, command_args) + self.pre_invoke(None, *command_args) except Exception: logger.exception( "pre_invoke failed with uncaught exception, please report to maintainer" @@ -125,7 +127,7 @@ def pass_through_invoker( sys.exit(1) try: - self.invoke(None, command_args) + self.invoke(None, *command_args) except Exception: logger.exception( "invoke failed with uncaught exception, please report to maintainer" @@ -133,7 +135,7 @@ def pass_through_invoker( sys.exit(1) try: - self.post_invoke(None, command_args) + self.post_invoke(None, *command_args) except Exception: logger.exception( "post_invoke failed with uncaught exception, please report to maintainer" # noqa: E501 diff --git a/meltano/edk/process.py b/meltano/edk/process.py index 179dc87..92b30e5 100644 --- a/meltano/edk/process.py +++ b/meltano/edk/process.py @@ -4,12 +4,13 @@ import asyncio import os import subprocess -from typing import IO, Any, Union +from typing import IO, Any import structlog +from meltano.edk.types import ExecArg + log = structlog.get_logger() -_ExecArg = Union[str, bytes] def log_subprocess_error( @@ -54,7 +55,7 @@ def __init__( def run( self, - *args: _ExecArg, + *args: ExecArg, stdout: None | int | IO = subprocess.PIPE, stderr: None | int | IO = subprocess.PIPE, text: bool = True, @@ -114,9 +115,9 @@ async def _log_stdio(reader: asyncio.streams.StreamReader) -> None: async def _exec( self, sub_command: str | None = None, - *args: _ExecArg, + *args: ExecArg, ) -> asyncio.subprocess.Process: - popen_args: list[_ExecArg] = [] + popen_args: list[ExecArg] = [] if sub_command: popen_args.append(sub_command) if args: @@ -154,7 +155,7 @@ async def _exec( def run_and_log( self, sub_command: str | None = None, - *args: _ExecArg, + *args: ExecArg, ) -> None: """Run a subprocess and stream the output to the logger. diff --git a/meltano/edk/types.py b/meltano/edk/types.py new file mode 100644 index 0000000..15303fd --- /dev/null +++ b/meltano/edk/types.py @@ -0,0 +1,7 @@ +"""Types used in the Meltano EDK.""" + +from __future__ import annotations + +from typing import Union + +ExecArg = Union[str, bytes] diff --git a/tests/test_meltano_edk.py b/tests/test_meltano_edk.py index f19b44a..bc7f5ee 100644 --- a/tests/test_meltano_edk.py +++ b/tests/test_meltano_edk.py @@ -1,12 +1,25 @@ from __future__ import annotations +import structlog + from meltano.edk import models from meltano.edk.extension import ExtensionBase +from meltano.edk.types import ExecArg + + +class CustomExtension(ExtensionBase): + def __init__(self) -> None: + super().__init__() + self.history: list[tuple[str | None, tuple[ExecArg, ...]]] = [] + + def pre_invoke(self, invoke_name: str | None, *command_args: ExecArg) -> None: + self.history.append((invoke_name, ("pre", *command_args))) + def invoke(self, command_name: str | None, *command_args: ExecArg) -> None: + self.history.append((command_name, command_args)) -class TestExtension(ExtensionBase): - def invoke(self, command_name: str | None, *command_args) -> None: - pass + def post_invoke(self, invoked_name: str | None, *command_args: ExecArg) -> None: + self.history.append((invoked_name, ("post", *command_args))) def describe(self) -> models.Describe: return models.Describe( @@ -15,7 +28,24 @@ def describe(self) -> models.Describe: def test_canary(): - test = TestExtension() + test = CustomExtension() assert test.describe() == models.Describe( commands=[models.ExtensionCommand(name="test", description="test")] ) + + +def test_invoke(): + test = CustomExtension() + test.invoke("echo", "test") + assert test.history == [("echo", ("test",))] + + +def test_pass_through(): + test = CustomExtension() + logger = structlog.getLogger("test") + test.pass_through_invoker(logger, "test") + assert test.history == [ + (None, ("pre", "test")), + (None, ("test",)), + (None, ("post", "test")), + ]