Skip to content

add extra headers to outgoing requests #173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion launch/api_client/paths/v1_async_tasks/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class instances
for serialized_value in serialized_data.values():
used_path += serialized_value

_headers = HTTPHeaderDict()
_headers = HTTPHeaderDict(self.api_client.default_headers)
# TODO add cookie handling
if accept_content_types:
for accept_content_type in accept_content_types:
Expand Down
2 changes: 1 addition & 1 deletion launch/api_client/paths/v1_streaming_tasks/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class instances
for serialized_value in serialized_data.values():
used_path += serialized_value

_headers = HTTPHeaderDict()
_headers = HTTPHeaderDict(self.api_client.default_headers)
# TODO add cookie handling
if accept_content_types:
for accept_content_type in accept_content_types:
Expand Down
2 changes: 1 addition & 1 deletion launch/api_client/paths/v1_sync_tasks/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class instances
for serialized_value in serialized_data.values():
used_path += serialized_value

_headers = HTTPHeaderDict()
_headers = HTTPHeaderDict(self.api_client.default_headers)
# TODO add cookie handling
if accept_content_types:
for accept_content_type in accept_content_types:
Expand Down
8 changes: 8 additions & 0 deletions launch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1955,6 +1955,7 @@ def _streaming_request(
url: Optional[str] = None,
args: Optional[Dict] = None,
return_pickled: bool = False,
extra_headers: Optional[Dict[str, str]] = None,
) -> requests.Response:
"""
Not recommended for use, instead use functions provided by StreamingEndpoint. Makes a
Expand Down Expand Up @@ -1991,6 +1992,7 @@ def _streaming_request(
json=payload,
auth=(self.configuration.username, self.configuration.password),
stream=True,
headers=extra_headers or {},
)
return response

Expand All @@ -2000,6 +2002,7 @@ def _sync_request(
url: Optional[str] = None,
args: Optional[Dict] = None,
return_pickled: bool = False,
extra_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
"""
Not recommended for use, instead use functions provided by SyncEndpoint Makes a request
Expand Down Expand Up @@ -2039,6 +2042,8 @@ def _sync_request(
endpoint = self.get_model_endpoint(endpoint_name)
endpoint_id = endpoint.model_endpoint.id # type: ignore
with ApiClient(self.configuration) as api_client:
for key, value in (extra_headers or {}).items():
api_client.set_default_header(key, value)
api_instance = DefaultApi(api_client)
payload = dict_not_none(return_pickled=return_pickled, url=url, args=args)
request = EndpointPredictV1Request(**payload)
Expand All @@ -2064,6 +2069,7 @@ def _async_request(
callback_auth_cert: Optional[str] = None,
callback_auth_key: Optional[str] = None,
return_pickled: bool = False,
extra_headers: Optional[Dict[str, str]] = None,
) -> str:
"""
Makes a request to the Async Model Endpoint at endpoint_id, and immediately returns a key
Expand Down Expand Up @@ -2120,6 +2126,8 @@ def _async_request(
validate_task_request(url=url, args=args)
endpoint = self.get_model_endpoint(endpoint_name)
with ApiClient(self.configuration) as api_client:
for key, value in (extra_headers or {}).items():
api_client.set_default_header(key, value)
api_instance = DefaultApi(api_client)
if callback_auth_kind is not None:
callback_auth = CallbackAuth(
Expand Down
8 changes: 8 additions & 0 deletions launch/model_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ class EndpointRequest:
request_id: (deprecated) A user-specifiable id for requests.
Should be unique among EndpointRequests made in the same batch call.
If one isn't provided the client will generate its own.

extra_headers: An optional dictionary which is passed on to the model endpoint
as extra HTTP headers.
"""

def __init__(
Expand All @@ -161,6 +164,7 @@ def __init__(
callback_auth_key: Optional[str] = None,
return_pickled: Optional[bool] = False,
request_id: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
):
# TODO: request_id is pretty much here only to support the clientside AsyncEndpointBatchResponse
# so it should be removed when we get proper batch endpoints working.
Expand All @@ -177,6 +181,7 @@ def __init__(
self.callback_auth_key = callback_auth_key
self.return_pickled = return_pickled
self.request_id: str = request_id
self.extra_headers = extra_headers


class EndpointResponse:
Expand Down Expand Up @@ -406,6 +411,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponse:
url=request.url,
args=request.args,
return_pickled=request.return_pickled,
extra_headers=request.extra_headers,
)

raw_response = {
Expand Down Expand Up @@ -457,6 +463,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponseStream:
url=request.url,
args=request.args,
return_pickled=request.return_pickled,
extra_headers=request.extra_headers,
)
return EndpointResponseStream(response=raw_response)

Expand Down Expand Up @@ -517,6 +524,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponseFuture:
callback_auth_cert=request.callback_auth_cert,
callback_auth_key=request.callback_auth_key,
return_pickled=request.return_pickled,
extra_headers=request.extra_headers,
)
async_task_id = response["task_id"]
return EndpointResponseFuture(
Expand Down