Skip to content

Commit

Permalink
Merge pull request #929 from lupko/master
Browse files Browse the repository at this point in the history
RELATED: CQ-1005 - make FlexConnect function call deadline configurable
  • Loading branch information
lupko authored Dec 12, 2024
2 parents a576d4f + a94898b commit 484fc20
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask

_LOGGER = structlog.get_logger("gooddata_flexconnect.rpc")
_DEFAULT_TASK_WAIT = 60.0


class _FlexConnectServerMethods(FlightServerMethods):
def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry) -> None:
def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry, call_deadline_ms: float) -> None:
self._ctx = ctx
self._registry = registry
self._call_deadline = call_deadline_ms / 1000

@staticmethod
def _create_descriptor(fun_name: str, metadata: Optional[dict]) -> pyarrow.flight.FlightDescriptor:
Expand Down Expand Up @@ -148,8 +148,13 @@ def get_flight_info(

try:
# XXX: this should be enhanced to implement polling
task_result = self._ctx.task_executor.wait_for_result(task.task_id, _DEFAULT_TASK_WAIT)
task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline)
except TaskWaitTimeoutError:
cancelled = self._ctx.task_executor.cancel(task.task_id)
_LOGGER.warning(
"flexconnect_fun_call_timeout", task_id=task.task_id, fun=task.fun_name, cancelled=cancelled
)

raise ErrorInfo.for_reason(
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task.task_id}."
).to_timeout_error()
Expand Down Expand Up @@ -195,6 +200,27 @@ def do_get(

_FLEX_CONNECT_CONFIG_SECTION = "flexconnect"
_FLEX_CONNECT_FUNCTION_LIST = "functions"
_FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms"
_DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000


def _read_call_deadline_ms(ctx: ServerContext) -> int:
call_deadline = ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_CALL_DEADLINE_MS}")
if call_deadline is None:
return _DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS

try:
call_deadline_ms = int(call_deadline)
if call_deadline_ms <= 0:
raise ValueError()

return call_deadline_ms
except ValueError:
raise ValueError(
f"Value of {_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_CALL_DEADLINE_MS} must "
f"be a positive number - duration, in milliseconds, that FlexConnect function "
f"calls are expected to run."
)


@flight_server_methods
Expand All @@ -209,7 +235,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods
:return: new instance of Flight RPC server methods to integrate into the server
"""
modules = list(ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_FUNCTION_LIST}") or [])
call_deadline_ms = _read_call_deadline_ms(ctx)

_LOGGER.info("flexconnect_init", modules=modules)
registry = FlexConnectFunctionRegistry().load(ctx, modules)

return _FlexConnectServerMethods(ctx, registry)
return _FlexConnectServerMethods(ctx, registry, call_deadline_ms)
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def run(self) -> Union[TaskResult, TaskError]:
headers=self._headers,
)

# switch task to non-cancellable state; once the code creates
# and returns the result, the task successfully executed and there
# is nothing to cancel.
#
# NOTE: if the switch finds that task got cancelled already, it
# bails and raises error.
self.switch_non_cancellable()

return FlightDataTaskResult.for_data(result)

def on_task_cancel(self) -> None:
Expand Down
7 changes: 4 additions & 3 deletions gooddata-flexconnect/tests/server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def flexconnect_server(
tls: bool = False,
mtls: bool = False,
) -> GoodDataFlightServer:
envvar = ", ".join([f'"{module}"' for module in modules])
envvar = f"[{envvar}]"
funs = ", ".join([f'"{module}"' for module in modules])
funs = f"[{funs}]"

os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = envvar
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = funs
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "500"

with server(create_flexconnect_flight_methods, tls, mtls) as s:
yield s
Expand Down
4 changes: 2 additions & 2 deletions gooddata-flexconnect/tests/server/funs/fun1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from gooddata_flight_server import ArrowData


class _SimpleFun(FlexConnectFunction):
Name = "SimpleFun"
class _SimpleFun1(FlexConnectFunction):
Name = "SimpleFun1"
Schema = pyarrow.schema(
fields=[
pyarrow.field("col1", pyarrow.int64()),
Expand Down
6 changes: 3 additions & 3 deletions gooddata-flexconnect/tests/server/funs/fun2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
_DATA: Optional[pyarrow.Table] = None


class _SimpleFun(FlexConnectFunction):
Name = "SimpleFun"
class _SimpleFun2(FlexConnectFunction):
Name = "SimpleFun2"
Schema = pyarrow.schema(
fields=[
pyarrow.field("col1", pyarrow.int64()),
Expand Down Expand Up @@ -37,5 +37,5 @@ def on_load(ctx: ServerContext) -> None:
"col2": ["a", "b", "c"],
"col3": [True, False, True],
},
schema=_SimpleFun.Schema,
schema=_SimpleFun2.Schema,
)
39 changes: 39 additions & 0 deletions gooddata-flexconnect/tests/server/funs/fun3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# (C) 2024 GoodData Corporation
import time
from typing import Optional

import pyarrow
from gooddata_flexconnect.function.function import FlexConnectFunction
from gooddata_flight_server import ArrowData

_DATA: Optional[pyarrow.Table] = None


class _LongRunningFun(FlexConnectFunction):
    # Test function whose call deliberately takes longer than the call
    # deadline configured for the test server.
    Name = "LongRunningFun"
    Schema = pyarrow.schema(
        fields=[
            pyarrow.field("col1", pyarrow.int64()),
            pyarrow.field("col2", pyarrow.string()),
            pyarrow.field("col3", pyarrow.bool_()),
        ]
    )

    def call(
        self,
        parameters: dict,
        columns: tuple[str, ...],
        headers: dict[str, list[str]],
    ) -> ArrowData:
        """
        Sleep for one second and then return a fixed three-row table.

        The arguments are accepted but not used.

        :return: table with columns col1, col2 and col3 built using the
         function's Schema
        """
        # the sleep is intentionally set up to be longer than the deadline
        # for the function invocation (see conftest.py // flexconnect_server fixture)
        time.sleep(1)

        rows = {
            "col1": [1, 2, 3],
            "col2": ["a", "b", "c"],
            "col3": [True, False, True],
        }
        return pyarrow.table(data=rows, schema=self.Schema)
42 changes: 37 additions & 5 deletions gooddata-flexconnect/tests/server/test_flexconnect_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# (C) 2024 GoodData Corporation
import orjson
import pyarrow.flight
import pytest
from gooddata_flight_server import ErrorCode

from tests.assert_error_info import assert_error_code
from tests.server.conftest import flexconnect_server


Expand All @@ -16,12 +19,12 @@ def test_basic_function():
assert fun_info.descriptor.command is not None
assert len(fun_info.descriptor.command)
cmd = orjson.loads(fun_info.descriptor.command)
assert cmd["functionName"] == "SimpleFun"
assert cmd["functionName"] == "SimpleFun1"

descriptor = pyarrow.flight.FlightDescriptor.for_command(
orjson.dumps(
{
"functionName": "SimpleFun",
"functionName": "SimpleFun1",
"parameters": {"test1": 1, "test2": 2, "test3": 3},
}
)
Expand All @@ -45,7 +48,7 @@ def test_function_with_on_load():
descriptor = pyarrow.flight.FlightDescriptor.for_command(
orjson.dumps(
{
"functionName": "SimpleFun",
"functionName": "SimpleFun2",
"parameters": {"test1": 1, "test2": 2, "test3": 3},
}
)
Expand All @@ -69,12 +72,12 @@ def test_basic_function_tls(tls_ca_cert):
assert fun_info.descriptor.command is not None
assert len(fun_info.descriptor.command)
cmd = orjson.loads(fun_info.descriptor.command)
assert cmd["functionName"] == "SimpleFun"
assert cmd["functionName"] == "SimpleFun1"

descriptor = pyarrow.flight.FlightDescriptor.for_command(
orjson.dumps(
{
"functionName": "SimpleFun",
"functionName": "SimpleFun1",
"parameters": {"test1": 1, "test2": 2, "test3": 3},
}
)
Expand All @@ -84,3 +87,32 @@ def test_basic_function_tls(tls_ca_cert):

assert len(data) == 3
assert data.column_names == ["col1", "col2", "col3"]


def test_function_with_call_deadline():
    """
    The Flight RPC implementation that invokes FlexConnect functions can be
    set up with a deadline for the invocation duration (the invocation is
    done by GetFlightInfo).

    If the function invocation (or the wait for the invocation) exceeds the
    deadline, GetFlightInfo fails with a timeout and the underlying task
    is cancelled (if possible).

    In these cases, GetFlightInfo raises FlightTimedOutError with the
    appropriate error code.
    """
    command = orjson.dumps(
        {
            "functionName": "LongRunningFun",
            "parameters": {"test1": 1, "test2": 2, "test3": 3},
        }
    )

    with flexconnect_server(["tests.server.funs.fun3"]) as server:
        client = pyarrow.flight.FlightClient(server.location)
        descriptor = pyarrow.flight.FlightDescriptor.for_command(command)

        with pytest.raises(pyarrow.flight.FlightTimedOutError) as exc_info:
            client.get_flight_info(descriptor)

        # checked after the raises block so that the captured error is available
        assert_error_code(ErrorCode.TIMEOUT, exc_info.value)
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def cancel(self) -> bool:
"""
Cancels the execution.
IMPORTANT: task executor most not hold any locks at the time of cancellation.
IMPORTANT: task executor must not hold any locks at the time of cancellation.
:return: True if cancel was successful, false if it was not possible
"""
Expand Down

0 comments on commit 484fc20

Please sign in to comment.