Skip to content

Implements "trailing_headers" for HTTP/2 #1012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 45 additions & 9 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,14 @@ def __init__(
h2.events.ResponseReceived
| h2.events.DataReceived
| h2.events.StreamEnded
| h2.events.StreamReset,
| h2.events.StreamReset
| h2.events.TrailersReceived,
],
] = {}

# Mapping from stream ID to trailing headers
self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {}

# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: h2.events.ConnectionTerminated | None = None
Expand Down Expand Up @@ -152,15 +156,22 @@ async def handle_async_request(self, request: Request) -> Response:
)
trace.return_value = (status, headers)

extensions = {
"http_version": b"HTTP/2",
"network_stream": self._network_stream,
"stream_id": stream_id,
}

return Response(
status=status,
headers=headers,
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={
"http_version": b"HTTP/2",
"network_stream": self._network_stream,
"stream_id": stream_id,
},
content=HTTP2ConnectionByteStream(
connection=self,
request=request,
stream_id=stream_id,
extensions=extensions,
),
extensions=extensions,
)
except BaseException as exc: # noqa: PIE786
with AsyncShieldCancellation():
Expand Down Expand Up @@ -326,7 +337,12 @@ async def _receive_response_body(

async def _receive_stream_event(
self, request: Request, stream_id: int
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
) -> (
h2.events.ResponseReceived
| h2.events.DataReceived
| h2.events.StreamEnded
| h2.events.TrailersReceived
):
"""
Return the next available event for a given stream ID.

Expand All @@ -337,6 +353,13 @@ async def _receive_stream_event(
event = self._events[stream_id].pop(0)
if isinstance(event, h2.events.StreamReset):
raise RemoteProtocolError(event)
elif isinstance(event, h2.events.TrailersReceived):
if event.stream_id in self._events and event.headers is not None:
self._trailing_headers[event.stream_id] = []
for k, v in event.headers:
if not k.startswith(b":"):
self._trailing_headers[event.stream_id].append((k, v))

return event

async def _receive_events(
Expand Down Expand Up @@ -377,6 +400,7 @@ async def _receive_events(
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
h2.events.TrailersReceived,
),
):
if event.stream_id in self._events:
Expand Down Expand Up @@ -409,6 +433,8 @@ async def _receive_remote_settings_change(
async def _response_closed(self, stream_id: int) -> None:
await self._max_streams_semaphore.release()
del self._events[stream_id]
if stream_id in self._trailing_headers:
del self._trailing_headers[stream_id]
async with self._state_lock:
if self._connection_terminated and not self._events:
await self.aclose()
Expand Down Expand Up @@ -561,12 +587,17 @@ async def __aexit__(

class HTTP2ConnectionByteStream:
def __init__(
self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
self,
connection: AsyncHTTP2Connection,
request: Request,
stream_id: int,
extensions: typing.MutableMapping[str, typing.Any],
) -> None:
self._connection = connection
self._request = request
self._stream_id = stream_id
self._closed = False
self._extensions = extensions

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
kwargs = {"request": self._request, "stream_id": self._stream_id}
Expand All @@ -576,6 +607,11 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
request=self._request, stream_id=self._stream_id
):
yield chunk

if self._stream_id in self._connection._trailing_headers:
self._extensions["trailing_headers"] = (
self._connection._trailing_headers[self._stream_id]
)
except BaseException as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
Expand Down
54 changes: 45 additions & 9 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,14 @@ def __init__(
h2.events.ResponseReceived
| h2.events.DataReceived
| h2.events.StreamEnded
| h2.events.StreamReset,
| h2.events.StreamReset
| h2.events.TrailersReceived,
],
] = {}

# Mapping from stream ID to trailing headers
self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {}

# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: h2.events.ConnectionTerminated | None = None
Expand Down Expand Up @@ -152,15 +156,22 @@ def handle_request(self, request: Request) -> Response:
)
trace.return_value = (status, headers)

extensions = {
"http_version": b"HTTP/2",
"network_stream": self._network_stream,
"stream_id": stream_id,
}

return Response(
status=status,
headers=headers,
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={
"http_version": b"HTTP/2",
"network_stream": self._network_stream,
"stream_id": stream_id,
},
content=HTTP2ConnectionByteStream(
connection=self,
request=request,
stream_id=stream_id,
extensions=extensions,
),
extensions=extensions,
)
except BaseException as exc: # noqa: PIE786
with ShieldCancellation():
Expand Down Expand Up @@ -326,7 +337,12 @@ def _receive_response_body(

def _receive_stream_event(
self, request: Request, stream_id: int
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
) -> (
h2.events.ResponseReceived
| h2.events.DataReceived
| h2.events.StreamEnded
| h2.events.TrailersReceived
):
"""
Return the next available event for a given stream ID.

Expand All @@ -337,6 +353,13 @@ def _receive_stream_event(
event = self._events[stream_id].pop(0)
if isinstance(event, h2.events.StreamReset):
raise RemoteProtocolError(event)
elif isinstance(event, h2.events.TrailersReceived):
if event.stream_id in self._events and event.headers is not None:
self._trailing_headers[event.stream_id] = []
for k, v in event.headers:
if not k.startswith(b":"):
self._trailing_headers[event.stream_id].append((k, v))

return event

def _receive_events(
Expand Down Expand Up @@ -377,6 +400,7 @@ def _receive_events(
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
h2.events.TrailersReceived,
),
):
if event.stream_id in self._events:
Expand Down Expand Up @@ -409,6 +433,8 @@ def _receive_remote_settings_change(
def _response_closed(self, stream_id: int) -> None:
self._max_streams_semaphore.release()
del self._events[stream_id]
if stream_id in self._trailing_headers:
del self._trailing_headers[stream_id]
with self._state_lock:
if self._connection_terminated and not self._events:
self.close()
Expand Down Expand Up @@ -561,12 +587,17 @@ def __exit__(

class HTTP2ConnectionByteStream:
def __init__(
self, connection: HTTP2Connection, request: Request, stream_id: int
self,
connection: HTTP2Connection,
request: Request,
stream_id: int,
extensions: typing.MutableMapping[str, typing.Any],
) -> None:
self._connection = connection
self._request = request
self._stream_id = stream_id
self._closed = False
self._extensions = extensions

def __iter__(self) -> typing.Iterator[bytes]:
kwargs = {"request": self._request, "stream_id": self._stream_id}
Expand All @@ -576,6 +607,11 @@ def __iter__(self) -> typing.Iterator[bytes]:
request=self._request, stream_id=self._stream_id
):
yield chunk

if self._stream_id in self._connection._trailing_headers:
self._extensions["trailing_headers"] = (
self._connection._trailing_headers[self._stream_id]
)
except BaseException as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
Expand Down
156 changes: 156 additions & 0 deletions tests/_async/test_http2_trailing_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import hpack
import hyperframe.frame
import pytest

import httpcore


@pytest.mark.anyio
async def test_http2_connection_with_trailing_headers():
"""
Test that trailing headers are correctly received and processed.
"""
origin = httpcore.Origin(b"https", b"example.com", 443)
stream = httpcore.AsyncMockStream(
[
hyperframe.frame.SettingsFrame().serialize(),
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(),
# Send trailing headers
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b"x-trailer-1", b"trailer-value-1"),
(b"x-trailer-2", b"trailer-value-2"),
]
),
flags=["END_HEADERS", "END_STREAM"],
).serialize(),
]
)
async with httpcore.AsyncHTTP2Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
response = await conn.request("GET", "https://example.com/")
assert response.status == 200
assert response.content == b"Hello, world!"

# Check that trailing headers are included in extensions
assert "trailing_headers" in response.extensions
assert response.extensions["trailing_headers"] == [
(b"x-trailer-1", b"trailer-value-1"),
(b"x-trailer-2", b"trailer-value-2"),
]


@pytest.mark.anyio
async def test_http2_connection_with_body_and_trailing_headers():
"""
Test that trailing headers are correctly received and processed
when reading the response body in chunks.
"""
origin = httpcore.Origin(b"https", b"example.com", 443)
stream = httpcore.AsyncMockStream(
[
hyperframe.frame.SettingsFrame().serialize(),
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, ").serialize(),
hyperframe.frame.DataFrame(stream_id=1, data=b"world!").serialize(),
# Send trailing headers
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b"x-trailer-1", b"trailer-value-1"),
(b"x-trailer-2", b"trailer-value-2"),
]
),
flags=["END_HEADERS", "END_STREAM"],
).serialize(),
]
)

async with httpcore.AsyncHTTP2Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
async with conn.stream("GET", "https://example.com/") as response:
content = b""
async for chunk in response.aiter_stream():
content += chunk

assert response.status == 200
assert content == b"Hello, world!"

# Check that trailing headers are included in extensions
assert "trailing_headers" in response.extensions
assert response.extensions["trailing_headers"] == [
(b"x-trailer-1", b"trailer-value-1"),
(b"x-trailer-2", b"trailer-value-2"),
]


@pytest.mark.anyio
async def test_http2_connection_with_trailing_headers_pseudo_removed():
"""
Test that pseudo-headers in trailing headers are correctly filtered out.
"""
origin = httpcore.Origin(b"https", b"example.com", 443)
stream = httpcore.AsyncMockStream(
[
hyperframe.frame.SettingsFrame().serialize(),
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":status", b"200"),
(b"content-type", b"plain/text"),
]
),
flags=["END_HEADERS"],
).serialize(),
hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(),
# Send trailing headers with a pseudo-header which should be filtered out
hyperframe.frame.HeadersFrame(
stream_id=1,
data=hpack.Encoder().encode(
[
(b":pseudo", b"should-be-filtered"),
(b"x-trailer", b"trailer-value"),
]
),
flags=["END_HEADERS", "END_STREAM"],
).serialize(),
]
)
async with httpcore.AsyncHTTP2Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
response = await conn.request("GET", "https://example.com/")
assert response.status == 200
assert response.content == b"Hello, world!"

# Check that trailing headers are included in extensions but pseudo-headers are filtered
assert "trailing_headers" in response.extensions
assert len(response.extensions["trailing_headers"]) == 1
assert response.extensions["trailing_headers"] == [
(b"x-trailer", b"trailer-value"),
]
Loading