Skip to content

Commit

Permalink
fix: Use splat in pass_through_invoker (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon authored Jan 18, 2023
1 parent 3c09cd2 commit e5c3aa5
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 18 deletions.
18 changes: 10 additions & 8 deletions meltano/edk/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import sys
from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import Any

import structlog
import yaml
from devtools.prettier import pformat

from meltano.edk import models
from meltano.edk.types import ExecArg


class DescribeFormat(str, Enum):
Expand All @@ -24,7 +24,7 @@ class DescribeFormat(str, Enum):
class ExtensionBase(metaclass=ABCMeta):
"""Basic extension interface that must be implemented by all extensions."""

def pre_invoke(self, invoke_name: str | None, *invoke_args: Any) -> None:
def pre_invoke(self, invoke_name: str | None, *invoke_args: ExecArg) -> None:
"""Called before the extension is invoked.
Args:
Expand All @@ -45,7 +45,7 @@ def initialize(self, force: bool = False) -> None:
pass

@abstractmethod
def invoke(self, command_name: str | None, *command_args: Any) -> None:
def invoke(self, command_name: str | None, *command_args: ExecArg) -> None:
"""Invoke method.
This method is called when the extension is invoked.
Expand All @@ -56,7 +56,7 @@ def invoke(self, command_name: str | None, *command_args: Any) -> None:
"""
pass

def post_invoke(self, invoked_name: str | None, *invoked_args: Any) -> None:
def post_invoke(self, invoked_name: str | None, *invoked_args: ExecArg) -> None:
"""Called after the extension is invoked.
Args:
Expand Down Expand Up @@ -100,7 +100,9 @@ def describe_formatted( # type: ignore[return]
)

def pass_through_invoker(
self, logger: structlog.BoundLogger, *command_args: Any
self,
logger: structlog.BoundLogger,
*command_args: ExecArg,
) -> None:
"""Pass-through invoker.
Expand All @@ -117,23 +119,23 @@ def pass_through_invoker(
command_args=command_args,
)
try:
self.pre_invoke(None, command_args)
self.pre_invoke(None, *command_args)
except Exception:
logger.exception(
"pre_invoke failed with uncaught exception, please report to maintainer"
)
sys.exit(1)

try:
self.invoke(None, command_args)
self.invoke(None, *command_args)
except Exception:
logger.exception(
"invoke failed with uncaught exception, please report to maintainer"
)
sys.exit(1)

try:
self.post_invoke(None, command_args)
self.post_invoke(None, *command_args)
except Exception:
logger.exception(
"post_invoke failed with uncaught exception, please report to maintainer" # noqa: E501
Expand Down
13 changes: 7 additions & 6 deletions meltano/edk/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import asyncio
import os
import subprocess
from typing import IO, Any, Union
from typing import IO, Any

import structlog

from meltano.edk.types import ExecArg

log = structlog.get_logger()
_ExecArg = Union[str, bytes]


def log_subprocess_error(
Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(

def run(
self,
*args: _ExecArg,
*args: ExecArg,
stdout: None | int | IO = subprocess.PIPE,
stderr: None | int | IO = subprocess.PIPE,
text: bool = True,
Expand Down Expand Up @@ -114,9 +115,9 @@ async def _log_stdio(reader: asyncio.streams.StreamReader) -> None:
async def _exec(
self,
sub_command: str | None = None,
*args: _ExecArg,
*args: ExecArg,
) -> asyncio.subprocess.Process:
popen_args: list[_ExecArg] = []
popen_args: list[ExecArg] = []
if sub_command:
popen_args.append(sub_command)
if args:
Expand Down Expand Up @@ -154,7 +155,7 @@ async def _exec(
def run_and_log(
self,
sub_command: str | None = None,
*args: _ExecArg,
*args: ExecArg,
) -> None:
"""Run a subprocess and stream the output to the logger.
Expand Down
7 changes: 7 additions & 0 deletions meltano/edk/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Types used in the Meltano EDK."""

from __future__ import annotations

from typing import Union

ExecArg = Union[str, bytes]
38 changes: 34 additions & 4 deletions tests/test_meltano_edk.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
from __future__ import annotations

import structlog

from meltano.edk import models
from meltano.edk.extension import ExtensionBase
from meltano.edk.types import ExecArg


class CustomExtension(ExtensionBase):
def __init__(self) -> None:
super().__init__()
self.history: list[tuple[str | None, tuple[ExecArg, ...]]] = []

def pre_invoke(self, invoke_name: str | None, *command_args: ExecArg) -> None:
self.history.append((invoke_name, ("pre", *command_args)))

def invoke(self, command_name: str | None, *command_args: ExecArg) -> None:
self.history.append((command_name, command_args))

class TestExtension(ExtensionBase):
def invoke(self, command_name: str | None, *command_args) -> None:
pass
def post_invoke(self, invoked_name: str | None, *command_args: ExecArg) -> None:
self.history.append((invoked_name, ("post", *command_args)))

def describe(self) -> models.Describe:
return models.Describe(
Expand All @@ -15,7 +28,24 @@ def describe(self) -> models.Describe:


def test_canary():
test = TestExtension()
test = CustomExtension()
assert test.describe() == models.Describe(
commands=[models.ExtensionCommand(name="test", description="test")]
)


def test_invoke():
test = CustomExtension()
test.invoke("echo", "test")
assert test.history == [("echo", ("test",))]


def test_pass_through():
test = CustomExtension()
logger = structlog.getLogger("test")
test.pass_through_invoker(logger, "test")
assert test.history == [
(None, ("pre", "test")),
(None, ("test",)),
(None, ("post", "test")),
]

0 comments on commit e5c3aa5

Please sign in to comment.