diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py index 6ebab47f..f345bc2b 100644 --- a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py +++ b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py @@ -21,13 +21,13 @@ from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask _LOGGER = structlog.get_logger("gooddata_flexconnect.rpc") -_DEFAULT_TASK_WAIT = 60.0 class _FlexConnectServerMethods(FlightServerMethods): - def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry) -> None: + def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry, call_deadline_ms: float) -> None: self._ctx = ctx self._registry = registry + self._call_deadline = call_deadline_ms / 1000 @staticmethod def _create_descriptor(fun_name: str, metadata: Optional[dict]) -> pyarrow.flight.FlightDescriptor: @@ -148,8 +148,13 @@ def get_flight_info( try: # XXX: this should be enhanced to implement polling - task_result = self._ctx.task_executor.wait_for_result(task.task_id, _DEFAULT_TASK_WAIT) + task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline) except TaskWaitTimeoutError: + cancelled = self._ctx.task_executor.cancel(task.task_id) + _LOGGER.warning( + "flexconnect_fun_call_timeout", task_id=task.task_id, fun=task.fun_name, cancelled=cancelled + ) + raise ErrorInfo.for_reason( ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task.task_id}." ).to_timeout_error() @@ -195,6 +200,27 @@ def do_get( _FLEX_CONNECT_CONFIG_SECTION = "flexconnect" _FLEX_CONNECT_FUNCTION_LIST = "functions" +_FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms" +_DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000 + + +def _read_call_deadline_ms(ctx: ServerContext) -> int: + call_deadline = ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_CALL_DEADLINE_MS}") + if call_deadline is None: + return _DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS + + try: + call_deadline_ms = int(call_deadline) + if call_deadline_ms <= 0: + raise ValueError() + + return call_deadline_ms + except ValueError: + raise ValueError( + f"Value of {_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_CALL_DEADLINE_MS} must " + f"be a positive number - duration, in milliseconds, that FlexConnect function " + f"calls are expected to run." + ) @flight_server_methods @@ -209,7 +235,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods :return: new instance of Flight RPC server methods to integrate into the server """ modules = list(ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_FUNCTION_LIST}") or []) + call_deadline_ms = _read_call_deadline_ms(ctx) + _LOGGER.info("flexconnect_init", modules=modules) registry = FlexConnectFunctionRegistry().load(ctx, modules) - return _FlexConnectServerMethods(ctx, registry) + return _FlexConnectServerMethods(ctx, registry, call_deadline_ms) diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/function_task.py b/gooddata-flexconnect/gooddata_flexconnect/function/function_task.py index baf11bae..0da016a3 100644 --- a/gooddata-flexconnect/gooddata_flexconnect/function/function_task.py +++ b/gooddata-flexconnect/gooddata_flexconnect/function/function_task.py @@ -45,6 +45,14 @@ def run(self) -> Union[TaskResult, TaskError]: headers=self._headers, ) + # switch task to non-cancellable state; once the code creates + # and returns the result, the task successfully executed and there + # is nothing to cancel. + # + # NOTE: if the switch finds that task got cancelled already, it + # bails and raises error. + self.switch_non_cancellable() + return FlightDataTaskResult.for_data(result) def on_task_cancel(self) -> None: diff --git a/gooddata-flexconnect/tests/server/conftest.py b/gooddata-flexconnect/tests/server/conftest.py index 6b5e5ef0..8e87f3ed 100644 --- a/gooddata-flexconnect/tests/server/conftest.py +++ b/gooddata-flexconnect/tests/server/conftest.py @@ -76,10 +76,11 @@ def flexconnect_server( tls: bool = False, mtls: bool = False, ) -> GoodDataFlightServer: - envvar = ", ".join([f'"{module}"' for module in modules]) - envvar = f"[{envvar}]" + funs = ", ".join([f'"{module}"' for module in modules]) + funs = f"[{funs}]" - os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = envvar + os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = funs + os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "500" with server(create_flexconnect_flight_methods, tls, mtls) as s: yield s diff --git a/gooddata-flexconnect/tests/server/funs/fun1.py b/gooddata-flexconnect/tests/server/funs/fun1.py index 106be05b..ee61f41f 100644 --- a/gooddata-flexconnect/tests/server/funs/fun1.py +++ b/gooddata-flexconnect/tests/server/funs/fun1.py @@ -5,8 +5,8 @@ from gooddata_flight_server import ArrowData -class _SimpleFun(FlexConnectFunction): - Name = "SimpleFun" +class _SimpleFun1(FlexConnectFunction): + Name = "SimpleFun1" Schema = pyarrow.schema( fields=[ pyarrow.field("col1", pyarrow.int64()), diff --git a/gooddata-flexconnect/tests/server/funs/fun2.py b/gooddata-flexconnect/tests/server/funs/fun2.py index 393e1599..8d646a8e 100644 --- a/gooddata-flexconnect/tests/server/funs/fun2.py +++ b/gooddata-flexconnect/tests/server/funs/fun2.py @@ -8,8 +8,8 @@ _DATA: Optional[pyarrow.Table] = None -class _SimpleFun(FlexConnectFunction): - Name = "SimpleFun" +class _SimpleFun2(FlexConnectFunction): + Name = "SimpleFun2" Schema = pyarrow.schema( fields=[ pyarrow.field("col1", pyarrow.int64()), @@ -37,5 +37,5 @@ def on_load(ctx: ServerContext) -> None: "col2": ["a", "b", "c"], "col3": [True, False, True], }, - schema=_SimpleFun.Schema, + schema=_SimpleFun2.Schema, ) diff --git a/gooddata-flexconnect/tests/server/funs/fun3.py b/gooddata-flexconnect/tests/server/funs/fun3.py new file mode 100644 index 00000000..42ac365e --- /dev/null +++ b/gooddata-flexconnect/tests/server/funs/fun3.py @@ -0,0 +1,39 @@ +# (C) 2024 GoodData Corporation +import time +from typing import Optional + +import pyarrow +from gooddata_flexconnect.function.function import FlexConnectFunction +from gooddata_flight_server import ArrowData + +_DATA: Optional[pyarrow.Table] = None + + +class _LongRunningFun(FlexConnectFunction): + Name = "LongRunningFun" + Schema = pyarrow.schema( + fields=[ + pyarrow.field("col1", pyarrow.int64()), + pyarrow.field("col2", pyarrow.string()), + pyarrow.field("col3", pyarrow.bool_()), + ] + ) + + def call( + self, + parameters: dict, + columns: tuple[str, ...], + headers: dict[str, list[str]], + ) -> ArrowData: + # sleep is intentionally setup to be longer than the deadline for + # the function invocation (see conftest.py // flexconnect_server fixture) + time.sleep(1) + + return pyarrow.table( + data={ + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + "col3": [True, False, True], + }, + schema=self.Schema, + ) diff --git a/gooddata-flexconnect/tests/server/test_flexconnect_server.py b/gooddata-flexconnect/tests/server/test_flexconnect_server.py index a18969bb..42fc3ffd 100644 --- a/gooddata-flexconnect/tests/server/test_flexconnect_server.py +++ b/gooddata-flexconnect/tests/server/test_flexconnect_server.py @@ -1,7 +1,10 @@ # (C) 2024 GoodData Corporation import orjson import pyarrow.flight +import pytest +from gooddata_flight_server import ErrorCode +from tests.assert_error_info import assert_error_code from tests.server.conftest import flexconnect_server @@ -16,12 +19,12 @@ def test_basic_function(): assert fun_info.descriptor.command is not None assert len(fun_info.descriptor.command) cmd = orjson.loads(fun_info.descriptor.command) - assert cmd["functionName"] == "SimpleFun" + assert cmd["functionName"] == "SimpleFun1" descriptor = pyarrow.flight.FlightDescriptor.for_command( orjson.dumps( { - "functionName": "SimpleFun", + "functionName": "SimpleFun1", "parameters": {"test1": 1, "test2": 2, "test3": 3}, } ) @@ -45,7 +48,7 @@ def test_function_with_on_load(): descriptor = pyarrow.flight.FlightDescriptor.for_command( orjson.dumps( { - "functionName": "SimpleFun", + "functionName": "SimpleFun2", "parameters": {"test1": 1, "test2": 2, "test3": 3}, } ) @@ -69,12 +72,12 @@ def test_basic_function_tls(tls_ca_cert): assert fun_info.descriptor.command is not None assert len(fun_info.descriptor.command) cmd = orjson.loads(fun_info.descriptor.command) - assert cmd["functionName"] == "SimpleFun" + assert cmd["functionName"] == "SimpleFun1" descriptor = pyarrow.flight.FlightDescriptor.for_command( orjson.dumps( { - "functionName": "SimpleFun", + "functionName": "SimpleFun1", "parameters": {"test1": 1, "test2": 2, "test3": 3}, } ) @@ -84,3 +87,32 @@ def test_basic_function_tls(tls_ca_cert): assert len(data) == 3 assert data.column_names == ["col1", "col2", "col3"] + + +def test_function_with_call_deadline(): + """ + Flight RPC implementation that invokes FlexConnect can be setup with + deadline for the invocation duration (done by GetFlightInfo). + + If the function invocation (or wait for the invocation) exceeds the + deadline, the GetFlightInfo will fail with timeout and the underlying + task will be cancelled (if possible). + + In these cases, the GetFlightInfo raises FlightTimedOutError with + appropriate error code. + """ + with flexconnect_server(["tests.server.funs.fun3"]) as s: + c = pyarrow.flight.FlightClient(s.location) + descriptor = pyarrow.flight.FlightDescriptor.for_command( + orjson.dumps( + { + "functionName": "LongRunningFun", + "parameters": {"test1": 1, "test2": 2, "test3": 3}, + } + ) + ) + + with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: + c.get_flight_info(descriptor) + + assert_error_code(ErrorCode.TIMEOUT, e.value) diff --git a/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py b/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py index a1530c59..f5cfe408 100644 --- a/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py +++ b/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py @@ -270,7 +270,7 @@ def cancel(self) -> bool: """ Cancels the execution. - IMPORTANT: task executor most not hold any locks at the time of cancellation. + IMPORTANT: task executor must not hold any locks at the time of cancellation. :return: True if cancel was successful, false if it was not possible """