Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bot] Merge master/484fc20c into rel/dev #930

Merged
merged 5 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask

_LOGGER = structlog.get_logger("gooddata_flexconnect.rpc")
_DEFAULT_TASK_WAIT = 60.0


class _FlexConnectServerMethods(FlightServerMethods):
def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry) -> None:
def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry, call_deadline_ms: float) -> None:
    """
    Stores the collaborators and the call deadline used by the RPC methods.

    :param ctx: server context
    :param registry: registry of FlexConnect functions
    :param call_deadline_ms: deadline for a function call, in milliseconds
    """
    # the deadline is kept in seconds; the value arrives in milliseconds
    self._call_deadline = call_deadline_ms / 1000
    self._registry = registry
    self._ctx = ctx

@staticmethod
def _create_descriptor(fun_name: str, metadata: Optional[dict]) -> pyarrow.flight.FlightDescriptor:
Expand Down Expand Up @@ -148,8 +148,13 @@ def get_flight_info(

try:
# XXX: this should be enhanced to implement polling
task_result = self._ctx.task_executor.wait_for_result(task.task_id, _DEFAULT_TASK_WAIT)
task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline)
except TaskWaitTimeoutError:
cancelled = self._ctx.task_executor.cancel(task.task_id)
_LOGGER.warning(
"flexconnect_fun_call_timeout", task_id=task.task_id, fun=task.fun_name, cancelled=cancelled
)

raise ErrorInfo.for_reason(
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task.task_id}."
).to_timeout_error()
Expand Down Expand Up @@ -195,6 +200,27 @@ def do_get(

# Settings section under which the FlexConnect-specific keys below are looked up.
_FLEX_CONNECT_CONFIG_SECTION = "flexconnect"
# Key (within the section) holding the list of modules handed to the function registry.
_FLEX_CONNECT_FUNCTION_LIST = "functions"
# Key (within the section) holding the deadline for function calls, in milliseconds.
_FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms"
# Deadline used when the setting is absent: 180 000 ms, i.e. three minutes.
_DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000


def _read_call_deadline_ms(ctx: ServerContext) -> int:
    """
    Reads the deadline for FlexConnect function calls from the server settings.

    :param ctx: server context whose settings are consulted
    :return: the configured deadline in milliseconds, or the built-in default
     when the setting is not present
    :raises ValueError: if the configured value cannot be converted to an
     integer or is not greater than zero
    """
    setting_name = f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_CALL_DEADLINE_MS}"
    call_deadline = ctx.settings.get(setting_name)
    if call_deadline is None:
        return _DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS

    error_message = (
        f"Value of {setting_name} must "
        "be a positive number - duration, in milliseconds, that FlexConnect function "
        "calls are expected to run."
    )

    try:
        call_deadline_ms = int(call_deadline)
    except (TypeError, ValueError) as e:
        # int() raises TypeError (not ValueError) for values such as lists or
        # dicts; report those with the same explanatory message instead of
        # letting a bare TypeError escape.
        raise ValueError(error_message) from e

    if call_deadline_ms <= 0:
        raise ValueError(error_message)

    return call_deadline_ms


@flight_server_methods
Expand All @@ -209,7 +235,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods
:return: new instance of Flight RPC server methods to integrate into the server
"""
modules = list(ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_FUNCTION_LIST}") or [])
call_deadline_ms = _read_call_deadline_ms(ctx)

_LOGGER.info("flexconnect_init", modules=modules)
registry = FlexConnectFunctionRegistry().load(ctx, modules)

return _FlexConnectServerMethods(ctx, registry)
return _FlexConnectServerMethods(ctx, registry, call_deadline_ms)
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def run(self) -> Union[TaskResult, TaskError]:
headers=self._headers,
)

# switch task to non-cancellable state; once the code creates
# and returns the result, the task successfully executed and there
# is nothing to cancel.
#
# NOTE: if the switch finds that task got cancelled already, it
# bails and raises error.
self.switch_non_cancellable()

return FlightDataTaskResult.for_data(result)

def on_task_cancel(self) -> None:
Expand Down
7 changes: 4 additions & 3 deletions gooddata-flexconnect/tests/server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def flexconnect_server(
tls: bool = False,
mtls: bool = False,
) -> GoodDataFlightServer:
envvar = ", ".join([f'"{module}"' for module in modules])
envvar = f"[{envvar}]"
funs = ", ".join([f'"{module}"' for module in modules])
funs = f"[{funs}]"

os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = envvar
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = funs
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "500"

with server(create_flexconnect_flight_methods, tls, mtls) as s:
yield s
Expand Down
4 changes: 2 additions & 2 deletions gooddata-flexconnect/tests/server/funs/fun1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from gooddata_flight_server import ArrowData


class _SimpleFun(FlexConnectFunction):
Name = "SimpleFun"
class _SimpleFun1(FlexConnectFunction):
Name = "SimpleFun1"
Schema = pyarrow.schema(
fields=[
pyarrow.field("col1", pyarrow.int64()),
Expand Down
6 changes: 3 additions & 3 deletions gooddata-flexconnect/tests/server/funs/fun2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
_DATA: Optional[pyarrow.Table] = None


class _SimpleFun(FlexConnectFunction):
Name = "SimpleFun"
class _SimpleFun2(FlexConnectFunction):
Name = "SimpleFun2"
Schema = pyarrow.schema(
fields=[
pyarrow.field("col1", pyarrow.int64()),
Expand Down Expand Up @@ -37,5 +37,5 @@ def on_load(ctx: ServerContext) -> None:
"col2": ["a", "b", "c"],
"col3": [True, False, True],
},
schema=_SimpleFun.Schema,
schema=_SimpleFun2.Schema,
)
39 changes: 39 additions & 0 deletions gooddata-flexconnect/tests/server/funs/fun3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# (C) 2024 GoodData Corporation
import time
from typing import Optional

import pyarrow
from gooddata_flexconnect.function.function import FlexConnectFunction
from gooddata_flight_server import ArrowData

_DATA: Optional[pyarrow.Table] = None


class _LongRunningFun(FlexConnectFunction):
    """
    Test function that blocks for one second and then returns a fixed,
    three-row table matching its schema.
    """

    Name = "LongRunningFun"
    Schema = pyarrow.schema(
        fields=[
            pyarrow.field(col_name, col_type)
            for col_name, col_type in (
                ("col1", pyarrow.int64()),
                ("col2", pyarrow.string()),
                ("col3", pyarrow.bool_()),
            )
        ]
    )

    def call(
        self,
        parameters: dict,
        columns: tuple[str, ...],
        headers: dict[str, list[str]],
    ) -> ArrowData:
        """
        Sleeps for one second, then builds the fixed result table.

        None of the arguments are used by this function.

        :return: table with columns col1, col2 and col3 and three rows
        """
        # the sleep is deliberately longer than the function invocation
        # deadline (see conftest.py // flexconnect_server fixture)
        time.sleep(1)

        fixed_rows = {
            "col1": [1, 2, 3],
            "col2": ["a", "b", "c"],
            "col3": [True, False, True],
        }

        return pyarrow.table(data=fixed_rows, schema=self.Schema)
42 changes: 37 additions & 5 deletions gooddata-flexconnect/tests/server/test_flexconnect_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# (C) 2024 GoodData Corporation
import orjson
import pyarrow.flight
import pytest
from gooddata_flight_server import ErrorCode

from tests.assert_error_info import assert_error_code
from tests.server.conftest import flexconnect_server


Expand All @@ -16,12 +19,12 @@ def test_basic_function():
assert fun_info.descriptor.command is not None
assert len(fun_info.descriptor.command)
cmd = orjson.loads(fun_info.descriptor.command)
assert cmd["functionName"] == "SimpleFun"
assert cmd["functionName"] == "SimpleFun1"

descriptor = pyarrow.flight.FlightDescriptor.for_command(
orjson.dumps(
{
"functionName": "SimpleFun",
"functionName": "SimpleFun1",
"parameters": {"test1": 1, "test2": 2, "test3": 3},
}
)
Expand All @@ -45,7 +48,7 @@ def test_function_with_on_load():
descriptor = pyarrow.flight.FlightDescriptor.for_command(
orjson.dumps(
{
"functionName": "SimpleFun",
"functionName": "SimpleFun2",
"parameters": {"test1": 1, "test2": 2, "test3": 3},
}
)
Expand All @@ -69,12 +72,12 @@ def test_basic_function_tls(tls_ca_cert):
assert fun_info.descriptor.command is not None
assert len(fun_info.descriptor.command)
cmd = orjson.loads(fun_info.descriptor.command)
assert cmd["functionName"] == "SimpleFun"
assert cmd["functionName"] == "SimpleFun1"

descriptor = pyarrow.flight.FlightDescriptor.for_command(
orjson.dumps(
{
"functionName": "SimpleFun",
"functionName": "SimpleFun1",
"parameters": {"test1": 1, "test2": 2, "test3": 3},
}
)
Expand All @@ -84,3 +87,32 @@ def test_basic_function_tls(tls_ca_cert):

assert len(data) == 3
assert data.column_names == ["col1", "col2", "col3"]


def test_function_with_call_deadline():
    """
    GetFlightInfo must fail with a timeout when the invoked FlexConnect
    function runs past the configured call deadline.

    The test calls the long-running function on a server started through the
    flexconnect_server fixture and expects the client to see
    FlightTimedOutError carrying the TIMEOUT error code.
    """
    command = orjson.dumps(
        {
            "functionName": "LongRunningFun",
            "parameters": {"test1": 1, "test2": 2, "test3": 3},
        }
    )
    descriptor = pyarrow.flight.FlightDescriptor.for_command(command)

    with flexconnect_server(["tests.server.funs.fun3"]) as server:
        client = pyarrow.flight.FlightClient(server.location)

        with pytest.raises(pyarrow.flight.FlightTimedOutError) as raised:
            client.get_flight_info(descriptor)

        assert_error_code(ErrorCode.TIMEOUT, raised.value)
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def cancel(self) -> bool:
"""
Cancels the execution.

IMPORTANT: task executor most not hold any locks at the time of cancellation.
IMPORTANT: task executor must not hold any locks at the time of cancellation.

:return: True if cancel was successful, false if it was not possible
"""
Expand Down
Loading