diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 3a45ffd04..488fe6de0 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -53,7 +53,7 @@ from tornado.ioloop import IOLoop from tornado.util import Configurable -from typing import Type, Any, Union, Dict, Callable, Optional, cast +from typing import Type, Any, Union, Dict, Callable, Optional, Awaitable, cast class HTTPClient: @@ -372,7 +372,9 @@ def __init__( user_agent: Optional[str] = None, use_gzip: Optional[bool] = None, network_interface: Optional[str] = None, - streaming_callback: Optional[Callable[[bytes], None]] = None, + streaming_callback: Optional[ + Callable[[bytes], Optional[Awaitable[None]]] + ] = None, header_callback: Optional[Callable[[str], None]] = None, prepare_curl_callback: Optional[Callable[[Any], None]] = None, proxy_host: Optional[str] = None, diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index cc1637613..5ed273db3 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -33,7 +33,7 @@ from io import BytesIO import urllib.parse -from typing import Dict, Any, Callable, Optional, Type, Union +from typing import Dict, Any, Callable, Optional, Type, Union, Awaitable from types import TracebackType import typing @@ -687,14 +687,15 @@ def finish(self) -> None: def _on_end_request(self) -> None: self.stream.close() - def data_received(self, chunk: bytes) -> None: + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: if self._should_follow_redirect(): # We're going to follow a redirect so just discard the body. - return + return None if self.request.streaming_callback is not None: - self.request.streaming_callback(chunk) + return self.request.streaming_callback(chunk) else: self.chunks.append(chunk) + return None if __name__ == "__main__": diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index a40435e81..c1ee49c84 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -539,6 +539,26 @@ def test_streaming_follow_redirects(self): num_start_lines = len([h for h in headers if h.startswith("HTTP/")]) self.assertEqual(num_start_lines, 1) + def test_streaming_callback_coroutine(self: typing.Any): + headers = [] # type: typing.List[str] + chunk_bytes = [] # type: typing.List[bytes] + + @gen.coroutine + def _put_chunk(chunk): + chunk_bytes.append(chunk) + yield gen.moment + + self.fetch( + "/hello", + header_callback=headers.append, + streaming_callback=_put_chunk, + ) + chunks = list(map(to_unicode, chunk_bytes)) + self.assertEqual(chunks, ["Hello world!"]) + # Make sure we only got one set of headers. + num_start_lines = len([h for h in headers if h.startswith("HTTP/")]) + self.assertEqual(num_start_lines, 1) + class SimpleHTTPClientTestCase(AsyncHTTPTestCase, SimpleHTTPClientTestMixin): def setUp(self):