diff --git a/launch/api_client/paths/v1_async_tasks/post.py b/launch/api_client/paths/v1_async_tasks/post.py
index 19d68ae1..7f3c6fd2 100644
--- a/launch/api_client/paths/v1_async_tasks/post.py
+++ b/launch/api_client/paths/v1_async_tasks/post.py
@@ -192,7 +192,7 @@ class instances
             for serialized_value in serialized_data.values():
                 used_path += serialized_value
 
-        _headers = HTTPHeaderDict()
+        _headers = HTTPHeaderDict(self.api_client.default_headers)
         # TODO add cookie handling
         if accept_content_types:
             for accept_content_type in accept_content_types:
diff --git a/launch/api_client/paths/v1_streaming_tasks/post.py b/launch/api_client/paths/v1_streaming_tasks/post.py
index 35563f0a..346acc43 100644
--- a/launch/api_client/paths/v1_streaming_tasks/post.py
+++ b/launch/api_client/paths/v1_streaming_tasks/post.py
@@ -189,7 +189,7 @@ class instances
             for serialized_value in serialized_data.values():
                 used_path += serialized_value
 
-        _headers = HTTPHeaderDict()
+        _headers = HTTPHeaderDict(self.api_client.default_headers)
         # TODO add cookie handling
         if accept_content_types:
             for accept_content_type in accept_content_types:
diff --git a/launch/api_client/paths/v1_sync_tasks/post.py b/launch/api_client/paths/v1_sync_tasks/post.py
index 0faf1568..dd4cce00 100644
--- a/launch/api_client/paths/v1_sync_tasks/post.py
+++ b/launch/api_client/paths/v1_sync_tasks/post.py
@@ -192,7 +192,7 @@ class instances
             for serialized_value in serialized_data.values():
                 used_path += serialized_value
 
-        _headers = HTTPHeaderDict()
+        _headers = HTTPHeaderDict(self.api_client.default_headers)
         # TODO add cookie handling
         if accept_content_types:
             for accept_content_type in accept_content_types:
diff --git a/launch/client.py b/launch/client.py
index 68b64823..89a686a9 100644
--- a/launch/client.py
+++ b/launch/client.py
@@ -1955,6 +1955,7 @@ def _streaming_request(
         url: Optional[str] = None,
         args: Optional[Dict] = None,
         return_pickled: bool = False,
+        extra_headers: Optional[Dict[str, str]] = None,
     ) -> requests.Response:
         """
         Not recommended for use, instead use functions provided by StreamingEndpoint.
         Makes a
@@ -1991,6 +1992,7 @@ def _streaming_request(
             json=payload,
             auth=(self.configuration.username, self.configuration.password),
             stream=True,
+            headers=extra_headers or {},
         )
         return response
 
@@ -2000,6 +2002,7 @@ def _sync_request(
         url: Optional[str] = None,
         args: Optional[Dict] = None,
         return_pickled: bool = False,
+        extra_headers: Optional[Dict[str, str]] = None,
     ) -> Dict[str, Any]:
         """
         Not recommended for use, instead use functions provided by SyncEndpoint
         Makes a request
@@ -2039,6 +2042,8 @@
         endpoint = self.get_model_endpoint(endpoint_name)
         endpoint_id = endpoint.model_endpoint.id  # type: ignore
         with ApiClient(self.configuration) as api_client:
+            for key, value in (extra_headers or {}).items():
+                api_client.set_default_header(key, value)
             api_instance = DefaultApi(api_client)
             payload = dict_not_none(return_pickled=return_pickled, url=url, args=args)
             request = EndpointPredictV1Request(**payload)
@@ -2064,6 +2069,7 @@ def _async_request(
         callback_auth_cert: Optional[str] = None,
         callback_auth_key: Optional[str] = None,
         return_pickled: bool = False,
+        extra_headers: Optional[Dict[str, str]] = None,
     ) -> str:
         """
         Makes a request to the Async Model Endpoint at endpoint_id, and immediately returns a key
@@ -2120,6 +2126,8 @@
         validate_task_request(url=url, args=args)
         endpoint = self.get_model_endpoint(endpoint_name)
         with ApiClient(self.configuration) as api_client:
+            for key, value in (extra_headers or {}).items():
+                api_client.set_default_header(key, value)
             api_instance = DefaultApi(api_client)
             if callback_auth_kind is not None:
                 callback_auth = CallbackAuth(
diff --git a/launch/model_endpoint.py b/launch/model_endpoint.py
index 8e6adb8b..b9f1585b 100644
--- a/launch/model_endpoint.py
+++ b/launch/model_endpoint.py
@@ -147,6 +147,9 @@ class EndpointRequest:
         request_id: (deprecated) A user-specifiable id for requests. Should be unique among
             EndpointRequests made in the same batch call. If one isn't provided the client
             will generate its own.
+
+        extra_headers: An optional dictionary which is passed on to the model endpoint
+            as extra HTTP headers.
     """
 
     def __init__(
@@ -161,6 +164,7 @@ def __init__(
         callback_auth_key: Optional[str] = None,
         return_pickled: Optional[bool] = False,
         request_id: Optional[str] = None,
+        extra_headers: Optional[Dict[str, str]] = None,
     ):
         # TODO: request_id is pretty much here only to support the clientside AsyncEndpointBatchResponse
         # so it should be removed when we get proper batch endpoints working.
@@ -177,6 +181,7 @@ def __init__(
         self.callback_auth_key = callback_auth_key
         self.return_pickled = return_pickled
         self.request_id: str = request_id
+        self.extra_headers = extra_headers
 
 
 class EndpointResponse:
@@ -406,6 +411,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponse:
             url=request.url,
             args=request.args,
             return_pickled=request.return_pickled,
+            extra_headers=request.extra_headers,
         )
 
         raw_response = {
@@ -457,6 +463,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponseStream:
             url=request.url,
             args=request.args,
             return_pickled=request.return_pickled,
+            extra_headers=request.extra_headers,
         )
         return EndpointResponseStream(response=raw_response)
 
@@ -517,6 +524,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponseFuture:
             callback_auth_cert=request.callback_auth_cert,
             callback_auth_key=request.callback_auth_key,
             return_pickled=request.return_pickled,
+            extra_headers=request.extra_headers,
        )
        async_task_id = response["task_id"]
        return EndpointResponseFuture(
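
For context, here is a minimal sketch of how the new `extra_headers` field could be exercised from user code once this change lands. The API key, endpoint name, and payload below are illustrative placeholders, not values taken from this diff; it also assumes `get_model_endpoint` returns one of the endpoint wrappers whose `predict` methods are touched above, and that `EndpointRequest` accepts `args` (suggested by the `request.args` accesses in the diff).

```python
from launch import LaunchClient
from launch.model_endpoint import EndpointRequest

# Hypothetical client setup; the API key and endpoint name are placeholders.
client = LaunchClient(api_key="your-api-key")
endpoint = client.get_model_endpoint("my-endpoint")

# extra_headers is forwarded by the endpoint's predict() to the underlying
# _sync_request/_streaming_request/_async_request call modified in this diff,
# and ends up as extra HTTP headers on the task request.
request = EndpointRequest(
    args={"input": "hello"},
    extra_headers={"x-request-source": "nightly-batch"},
)
response = endpoint.predict(request=request)
```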