diff --git a/dash/_callback.py b/dash/_callback.py index df3176ece1..272d9aa6a2 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -62,6 +62,7 @@ def _invoke_callback(func, *args, **kwargs): # used to mark the frame for the d GLOBAL_CALLBACK_LIST = [] GLOBAL_CALLBACK_MAP = {} GLOBAL_INLINE_SCRIPTS = [] +GLOBAL_API_PATHS = {} # pylint: disable=too-many-locals @@ -77,6 +78,7 @@ def callback( cache_args_to_ignore: Optional[list] = None, cache_ignore_triggered=True, on_error: Optional[Callable[[Exception], Any]] = None, + api_endpoint: Optional[str] = None, **_kwargs, ) -> Callable[..., Any]: """ @@ -168,6 +170,7 @@ def callback( ) callback_map = _kwargs.pop("callback_map", GLOBAL_CALLBACK_MAP) callback_list = _kwargs.pop("callback_list", GLOBAL_CALLBACK_LIST) + callback_api_paths = _kwargs.pop("callback_api_paths", GLOBAL_API_PATHS) if background: background_spec: Any = { @@ -207,12 +210,14 @@ def callback( callback_list, callback_map, config_prevent_initial_callbacks, + callback_api_paths, *_args, **_kwargs, background=background_spec, manager=manager, running=running, on_error=on_error, + api_endpoint=api_endpoint, ) @@ -575,7 +580,12 @@ def _prepare_response( # pylint: disable=too-many-branches,too-many-statements def register_callback( - callback_list, callback_map, config_prevent_initial_callbacks, *_args, **_kwargs + callback_list, + callback_map, + config_prevent_initial_callbacks, + callback_api_paths, + *_args, + **_kwargs, ): ( output, @@ -628,6 +638,10 @@ def register_callback( # pylint: disable=too-many-locals def wrap_func(func): + if _kwargs.get("api_endpoint"): + api_endpoint = _kwargs.get("api_endpoint") + callback_api_paths[api_endpoint] = func + if background is None: background_key = None else: diff --git a/dash/dash.py b/dash/dash.py index 92c4a5d542..575076acc7 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -568,6 +568,7 @@ def __init__( # pylint: disable=too-many-statements self.callback_map = {} # same deps as a list to catch duplicate outputs, and to send to the front end self._callback_list = [] + self.callback_api_paths = {} # list of inline scripts self._inline_scripts = [] @@ -778,6 +779,41 @@ def _setup_routes(self): # catch-all for front-end routes, used by dcc.Location self._add_url("", self.index) + def setup_apis(self): + # Copy over global callback data structures assigned with `dash.callback` + for k in list(_callback.GLOBAL_API_PATHS): + if k in self.callback_api_paths: + raise DuplicateCallback( + f"The callback `{k}` provided with `dash.callback` was already " + "assigned with `app.callback`." + ) + self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k) + + def make_parse_body(func): + def _parse_body(): + if flask.request.is_json: + data = flask.request.get_json() + return flask.jsonify(func(**data)) + return flask.jsonify({}) + + return _parse_body + + def make_parse_body_async(func): + async def _parse_body_async(): + if flask.request.is_json: + data = flask.request.get_json() + result = await func(**data) + return flask.jsonify(result) + return flask.jsonify({}) + + return _parse_body_async + + for path, func in self.callback_api_paths.items(): + if asyncio.iscoroutinefunction(func): + self._add_url(path, make_parse_body_async(func), ["POST"]) + else: + self._add_url(path, make_parse_body(func), ["POST"]) + def _setup_plotlyjs(self): # pylint: disable=import-outside-toplevel from plotly.offline import get_plotlyjs_version @@ -1346,6 +1382,7 @@ def callback(self, *_args, **_kwargs) -> Callable[..., Any]: config_prevent_initial_callbacks=self.config.prevent_initial_callbacks, callback_list=self._callback_list, callback_map=self.callback_map, + callback_api_paths=self.callback_api_paths, **_kwargs, ) @@ -1496,6 +1533,7 @@ def dispatch(self): def _setup_server(self): if self._got_first_request["setup_server"]: return + self._got_first_request["setup_server"] = True # Apply _force_eager_loading overrides from modules