Skip to content

Commit

Permalink
Mock out raw_headers, and allow custom response_classes.
Browse files Browse the repository at this point in the history
Both of these are features of aiohttp.

These changes were necessary to mock aiobotocore responses.
  • Loading branch information
brycedrennan committed Aug 30, 2017
1 parent d73bbc5 commit e04de1d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,5 @@ ENV/
.ropeproject

*.idea/*
.envrc
.direnv
23 changes: 20 additions & 3 deletions aioresponses/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def __init__(self, url: str, method: str = hdrs.METH_GET,
status: int = 200, body: str = '',
exception: 'Exception' = None,
headers: Dict = None, payload: Dict = None,
content_type: str = 'application/json', ):
content_type: str = 'application/json',
response_class=None):
self.url = self.parse_url(url)
self.method = method.lower()
self.status = status
Expand All @@ -35,9 +36,11 @@ def __init__(self, url: str, method: str = hdrs.METH_GET,
self.exception = exception
self.headers = headers
self.content_type = content_type
self.response_class = response_class or ClientResponse

def parse_url(self, url: str) -> str:
"""Normalize url to make comparisons."""
url = str(url)
_url = url.split('?')[0]
query = urlencode(sorted(parse_qsl(urlparse(url).query)))

Expand All @@ -51,18 +54,30 @@ def match(self, method: str, url: str) -> bool:
def build_response(self) -> 'ClientResponse':
if isinstance(self.exception, Exception):
raise self.exception
self.resp = ClientResponse(self.method, URL(self.url))
self.resp = self.response_class(self.method, URL(self.url))
# we need to initialize headers manually
self.resp.headers = CIMultiDict({hdrs.CONTENT_TYPE: self.content_type})
if self.headers:
self.resp.headers.update(self.headers)
self.resp.raw_headers = self._build_raw_headers(self.resp.headers)
self.resp.status = self.status
self.resp.content = StreamReader()
self.resp.content.feed_data(self.body)
self.resp.content.feed_eof()

return self.resp

def _build_raw_headers(self, headers):
"""
Convert a dict of headers to a tuple of tuples
Mimics the format of ClientResponse.
"""
raw_headers = []
for k, v in headers.items():
raw_headers.append((k.encode('utf8'), v.encode('utf8')))
return tuple(raw_headers)


class aioresponses(object):
"""Mock aiohttp requests made by ClientSession."""
Expand Down Expand Up @@ -145,7 +160,8 @@ def add(self, url: str, method: str = hdrs.METH_GET, status: int = 200,
exception: 'Exception' = None,
content_type: str = 'application/json',
payload: Dict = None,
headers: Dict = None) -> None:
headers: Dict = None,
response_class=None) -> None:
self._responses.append(UrlResponse(
url,
method=method,
Expand All @@ -155,6 +171,7 @@ def add(self, url: str, method: str = hdrs.METH_GET, status: int = 200,
exception=exception,
payload=payload,
headers=headers,
response_class=response_class,
))

def match(self, method: str, url: str) -> 'ClientResponse':
Expand Down
24 changes: 24 additions & 0 deletions tests/test_aioresponses.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ def test_returned_response_headers(self, m):
self.assertEqual(response.headers['Connection'], 'keep-alive')
self.assertEqual(response.headers[hdrs.CONTENT_TYPE], 'text/html')

@aioresponses()
@asyncio.coroutine
def test_returned_response_raw_headers(self, m):
m.get(self.url,
content_type='text/html',
headers={'Connection': 'keep-alive'})
response = yield from self.session.get(self.url)
expected_raw_headers = (
(b'Content-Type', b'text/html'),
(b'Connection', b'keep-alive')
)

self.assertEqual(response.raw_headers, expected_raw_headers)

@aioresponses()
def test_method_dont_match(self, m):
m.get(self.url)
Expand Down Expand Up @@ -197,3 +211,13 @@ def doit():

self.assertEqual(api.status, 200)
self.assertEqual(ext.status, 201)

@aioresponses()
@asyncio.coroutine
def test_custom_response_class(self, m):
class CustomClientResponse(ClientResponse):
pass

m.get(self.url, body='Test', response_class=CustomClientResponse)
resp = yield from self.session.get(self.url)
self.assertTrue(isinstance(resp, CustomClientResponse))

0 comments on commit e04de1d

Please sign in to comment.