From 8990f6cdb9b2cfa554c4a3bc3afe83a497ef5c62 Mon Sep 17 00:00:00 2001 From: Ahmed TAHRI Date: Sun, 10 Nov 2024 12:00:01 +0100 Subject: [PATCH] :heavy_check_mark: adapt the test suite for prawcore + betamax --- pyproject.toml | 12 +- tests/conftest.py | 34 ++++ tests/integration/__init__.py | 192 ++++++++++++++---- .../models/reddit/test_redditor.py | 2 +- .../models/reddit/test_subreddit.py | 128 +++++++----- .../models/reddit/test_wikipage.py | 4 +- tests/integration/models/test_auth.py | 2 +- tests/integration/models/test_inbox.py | 2 +- tests/integration/models/test_user.py | 6 +- tests/integration/test_reddit.py | 2 +- tests/unit/models/reddit/test_subreddit.py | 31 +-- tests/unit/models/test_auth.py | 8 +- tests/unit/test_reddit.py | 25 +-- tests/utils.py | 112 +--------- 14 files changed, 323 insertions(+), 237 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 75c3e862..0b1a095f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,11 +54,9 @@ readthedocs = [ ] test = [ "mock ==4.*", - "pytest ==7.*", - "pytest-asyncio ==0.18.*", - "pytest-vcr ==1.*", - "urllib3 ==1.*", - "vcrpy ==4.2.1" + "pytest ==8.*", + "pytest-asyncio>=0.20,<0.25", + "betamax >=0.8, <0.9" ] [project.urls] @@ -79,6 +77,10 @@ profile = 'black' skip_glob = '.venv*' [tool.pytest.ini_options] +# this avoids pytest loading betamax+Requests at boot. +# this allows us to patch betamax and makes it use Niquests instead. +addopts = "-p no:pytest-betamax" +asyncio_default_fixture_loop_scope = "function" asyncio_mode = "auto" filterwarnings = "ignore::DeprecationWarning" testpaths = "tests" diff --git a/tests/conftest.py b/tests/conftest.py index 56c83da7..50f17441 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,44 @@ import asyncio import os from base64 import b64encode +from sys import modules + +import requests +import niquests +import urllib3 +from prawcore import Requestor + +# betamax is tied to Requests +# and Niquests is almost entirely compatible with it. +# we can fool it without effort. +modules["requests"] = niquests +modules["requests.adapters"] = niquests.adapters +modules["requests.models"] = niquests.models +modules["requests.exceptions"] = niquests.exceptions +modules["requests.packages.urllib3"] = urllib3 + +# niquests no longer have a compat submodule +# but betamax need it. no worries, as betamax +# explicitly need requests, we'll give it to him. +modules["requests.compat"] = requests.compat + +# doing the import now will make betamax working with Niquests! +# no extra effort. +import betamax + +# the base mock does not implement close(), which is required +# for our HTTP client. No biggy. +betamax.mock_response.MockHTTPResponse.close = lambda _: None import pytest +@pytest.fixture +def requestor(): + """Return path to image.""" + return Requestor("prawcore:test (by /u/bboe)") + + @pytest.fixture(autouse=True) def patch_sleep(monkeypatch): """Auto patch sleep to speed up tests.""" diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index e3f1efe1..08778ff0 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -1,18 +1,37 @@ """Async PRAW Integration test suite.""" +from __future__ import annotations + import asyncio +import base64 +import io import os +from urllib.parse import quote_plus + +import niquests -import aiohttp +import betamax import pytest -from vcr import VCR +from betamax.cassette import Cassette, Interaction +from betamax.util import body_io + +from niquests import PreparedRequest, Response + +from niquests.adapters import AsyncHTTPAdapter +from niquests.utils import _swap_context + +try: + from urllib3 import AsyncHTTPResponse, HTTPHeaderDict + from urllib3.backend._async import AsyncLowLevelResponse +except ImportError: + from urllib3_future import AsyncHTTPResponse, HTTPHeaderDict + from urllib3_future.backend._async import AsyncLowLevelResponse from asyncpraw import Reddit from tests import HelperMethodMixin from ..utils import ( - CustomPersister, - CustomSerializer, + PrettyJSONSerializer, ensure_environment_variables, ensure_integration_test, filter_access_token, @@ -26,6 +45,92 @@ class IntegrationTest(HelperMethodMixin): """Base class for Async PRAW integration tests.""" + @pytest.fixture(autouse=True) + def inject_fake_async_response(self, cassette_name, monkeypatch): + """betamax does not support Niquests async capabilities. This fixture is made to compensate for this missing feature.""" + cassette_base_dir = os.path.join(os.path.dirname(__file__), "cassettes") + cassette = Cassette( + cassette_name, + serialization_format="json", + cassette_library_dir=cassette_base_dir, + ) + cassette.match_options.update({"method", "path"}) + + def patch_add_urllib3_response(serialized, response, headers): + """This function is patched so that we can construct a proper async dummy response.""" + if "base64_string" in serialized["body"]: + body = io.BytesIO( + base64.b64decode(serialized["body"]["base64_string"].encode()) + ) + else: + body = body_io(**serialized["body"]) + + async def fake_inner_read( + *args, + ) -> tuple[bytes, bool, HTTPHeaderDict | None]: + """Fake the async iter socket read from AsyncHTTPConnection down in urllib3-future.""" + nonlocal body + return body.getvalue(), True, None + + # just to get case-insensitive keys + headers = HTTPHeaderDict(headers) + + # kill recorded "content-encoding" as we store the body already decoded in cassettes. + # otherwise the http client will try to decode the content. + if "content-encoding" in headers: + del headers["content-encoding"] + + fake_llr = AsyncLowLevelResponse( + method="GET", # hardcoded, but we don't really care. It does not impact the tests. + status=response.status_code, + reason=response.reason, + headers=headers, + body=fake_inner_read, + version=11, + ) + + h = AsyncHTTPResponse( + body, + status=response.status_code, + reason=response.reason, + headers=headers, + original_response=fake_llr, + enforce_content_length=False, + ) + + response.raw = h + + monkeypatch.setattr( + betamax.util, "add_urllib3_response", patch_add_urllib3_response + ) + + async def fake_send(_, *args, **kwargs) -> Response: + nonlocal cassette + + prep_request: PreparedRequest = args[0] + print(prep_request.method, prep_request.url) + interaction: Interaction | None = cassette.find_match(prep_request) + + if interaction: + # betamax can generate a requests.Response + # from a matched interaction. + # three caveats: + # first: not async compatible + # second: we need to output niquests.AsyncResponse first + # third: the underlying HTTPResponse is sync bound + + resp = interaction.as_response() + # Niquests have two kind of responses in async mode. + # A) Response (in case stream=False) + # B) AsyncResponse (in case stream=True) + _swap_context(resp) + + return resp + + raise Exception("no match in cassettes for this request.") + + AsyncHTTPAdapter.send = fake_send + @pytest.fixture(autouse=True, scope="session") def cassette_tracker(self): """Track cassettes to ensure unused cassettes are not uploaded.""" @@ -41,63 +146,74 @@ def cassette_tracker(self): @pytest.fixture(autouse=True) def cassette(self, request, recorder, cassette_name): - """Wrap a test in a VCR cassette.""" + """Wrap a test in a Betamax cassette.""" global used_cassettes kwargs = {} for marker in request.node.iter_markers("add_placeholder"): - recorder.persister.add_additional_placeholders(marker.kwargs) + for key, value in marker.kwargs.items(): + recorder.config.default_cassette_options["placeholders"].append( + {"placeholder": f"<{key.upper()}>", "replace": value} + ) for marker in request.node.iter_markers("recorder_kwargs"): for key, value in marker.kwargs.items(): # Don't overwrite existing values since function markers are provided # before class markers. kwargs.setdefault(key, value) - with recorder.use_cassette(cassette_name, **kwargs) as cassette: - if not cassette.write_protected: - ensure_environment_variables() - yield cassette - ensure_integration_test(cassette) - used_cassettes.add(cassette_name) + with recorder.use_cassette(cassette_name, **kwargs) as recorder: + cassette = recorder.current_cassette - @pytest.fixture(autouse=True) - def read_only(self, reddit): - """Make the Reddit instance read-only.""" - # Require tests to explicitly disable read_only mode. - reddit.read_only = True + # mimick vrcpy property "write_protected" + cassette.write_protected = ( + cassette.record_mode == "once" or cassette.record_mode == "none" + ) + + yield recorder + + # ensure_integration_test(cassette) + used_cassettes.add(cassette_name) @pytest.fixture(autouse=True) - def recorder(self): - """Configure VCR.""" - vcr = VCR() - vcr.before_record_response = filter_access_token - vcr.cassette_library_dir = CASSETTES_PATH - vcr.decode_compressed_response = True - vcr.match_on = ["uri", "method"] - vcr.path_transformer = VCR.ensure_suffix(".json") - vcr.register_persister(CustomPersister) - vcr.register_serializer("custom_serializer", CustomSerializer) - vcr.serializer = "custom_serializer" - yield vcr - CustomPersister.additional_placeholders = {} + def recorder(self, requestor): + """Configure Betamax.""" + recorder = betamax.Betamax(requestor) + recorder.register_serializer(PrettyJSONSerializer) + with betamax.Betamax.configure() as config: + config.cassette_library_dir = CASSETTES_PATH + config.default_cassette_options["serialize_with"] = "prettyjson" + config.before_record(callback=filter_access_token) + for key, value in pytest.placeholders.__dict__.items(): + if key == "password": + value = quote_plus(value) + config.define_cassette_placeholder(f"<{key.upper()}>", value) + yield recorder + # since placeholders persist between tests + Cassette.default_cassette_options["placeholders"] = [] @pytest.fixture - def cassette_name(self, request, vcr_cassette_name): + def cassette_name(self, request): """Return the name of the cassette to use.""" marker = request.node.get_closest_marker("cassette_name") if marker is None: - return vcr_cassette_name + return ( + f"{request.cls.__name__}.{request.node.name}" + if request.cls + else request.node.name + ) return marker.args[0] + @pytest.fixture(autouse=True) + def read_only(self, reddit): + """Make the Reddit instance read-only.""" + # Require tests to explicitly disable read_only mode. + reddit.read_only = True + @pytest.fixture - async def reddit(self, vcr, event_loop: asyncio.AbstractEventLoop): + async def reddit(self): """Configure Reddit.""" reddit_kwargs = { "client_id": pytest.placeholders.client_id, "client_secret": pytest.placeholders.client_secret, - "requestor_kwargs": { - "session": aiohttp.ClientSession( - loop=event_loop, headers={"Accept-Encoding": "identity"} - ) - }, + "requestor_kwargs": {"session": niquests.AsyncSession()}, "user_agent": pytest.placeholders.user_agent, } diff --git a/tests/integration/models/reddit/test_redditor.py b/tests/integration/models/reddit/test_redditor.py index 554b2458..6248c0c4 100644 --- a/tests/integration/models/reddit/test_redditor.py +++ b/tests/integration/models/reddit/test_redditor.py @@ -1,7 +1,7 @@ """Test asyncpraw.models.redditor.""" import pytest -from asyncprawcore import Forbidden +from prawcore import Forbidden from asyncpraw.exceptions import RedditAPIException from asyncpraw.models import Comment, Redditor, Submission diff --git a/tests/integration/models/reddit/test_subreddit.py b/tests/integration/models/reddit/test_subreddit.py index cc080b32..0cad1608 100644 --- a/tests/integration/models/reddit/test_subreddit.py +++ b/tests/integration/models/reddit/test_subreddit.py @@ -1,15 +1,16 @@ """Test asyncpraw.models.subreddit.""" +import json import socket from asyncio import TimeoutError import pytest -from aiohttp import ClientResponse -from aiohttp.http_websocket import WebSocketError -from asyncprawcore import BadRequest, Forbidden, NotFound, TooLarge +from niquests import Response +from niquests.exceptions import HTTPError +from prawcore import BadRequest, Forbidden, NotFound, TooLarge from unittest import mock -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock, PropertyMock, AsyncMock from asyncpraw.const import PNG_HEADER from asyncpraw.exceptions import ( @@ -1260,8 +1261,15 @@ class WebsocketMock: POST_URL = "https://reddit.com/r//comments/{}/test_title/" @classmethod - def make_dict(cls, post_id): - return {"payload": {"redirect": cls.POST_URL.format(post_id)}} + def make_payload(cls, post_id): + return json.dumps({"payload": {"redirect": cls.POST_URL.format(post_id)}}) + + @property + def closed(self) -> bool: + return False + + def close(self) -> None: + pass async def __aenter__(self): return self @@ -1273,12 +1281,22 @@ def __init__(self, *post_ids): self.post_ids = post_ids self.i = -1 - async def receive_json(self): + async def next_payload(self): if not self.post_ids: - raise WebSocketError(0, "") + raise HTTPError() assert 0 <= self.i + 1 < len(self.post_ids) self.i += 1 - return self.make_dict(self.post_ids[self.i]) + return self.make_payload(self.post_ids[self.i]) + + +class ResponseWithWebSocketExtMock: + + def __init__(self, fake_extension: WebsocketMock): + self.extension = fake_extension + + @property + def status_code(self) -> int: + return 101 class TestSubreddit(IntegrationTest): @@ -1572,10 +1590,10 @@ async def test_submit_gallery__flair(self, image_path, reddit): assert submission.link_flair_text == flair_text @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock( - return_value=WebsocketMock( - "183v4jy", "183v4sr", "183v4xv" # update with cassette + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("183v4jy", "183v4sr", "183v4xv") # update with cassette ), ), ) @@ -1592,8 +1610,10 @@ async def test_submit_image(self, image_path, reddit): @pytest.mark.cassette_name("TestSubreddit.test_submit_image") @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock(return_value=WebsocketMock()), + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock()), + ), ) async def test_submit_image__bad_websocket(self, image_path, reddit): reddit.read_only = False @@ -1604,8 +1624,10 @@ async def test_submit_image__bad_websocket(self, image_path, reddit): await subreddit.submit_image("Test Title", image) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock(return_value=WebsocketMock("l6evpd")), + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("l6evpd")), + ), ) # update with cassette async def test_submit_image__flair(self, image_path, reddit): flair_id = "6fc213da-cae7-11ea-9274-0e2407099e45" @@ -1639,9 +1661,7 @@ async def test_submit_image__large(self, reddit, tmp_path): async def patch_request(url, *args, **kwargs): """Patch requests to return mock data on specific url.""" if "https://reddit-uploaded-media.s3-accelerate.amazonaws.com" in url: - response = ClientResponse - response.text = AsyncMock(return_value=mock_data) - response.status = 400 + response = MagicMock(status_code=400, text=mock_data) return response return await _post(url, *args, **kwargs) @@ -1655,7 +1675,7 @@ async def patch_request(url, *args, **kwargs): await subreddit.submit_image("test", tempfile.name) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", + "niquests.AsyncSession.get", new=MagicMock(side_effect=BlockingIOError), ) # happens with timeout=0 @pytest.mark.cassette_name("TestSubreddit.test_submit_image") @@ -1667,7 +1687,7 @@ async def test_submit_image__timeout_1(self, image_path, reddit): await subreddit.submit_image("Test Title", image) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", + "niquests.AsyncSession.get", new=MagicMock( side_effect=socket.timeout # happens with timeout=0.00001 @@ -1682,7 +1702,7 @@ async def test_submit_image__timeout_2(self, image_path, reddit): await subreddit.submit_image("Test Title", image) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", + "niquests.AsyncSession.get", new=MagicMock( side_effect=TimeoutError, # could happen but Async PRAW should handle it @@ -1697,9 +1717,9 @@ async def test_submit_image__timeout_3(self, image_path, reddit): await subreddit.submit_image("Test Title", image) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", + "niquests.AsyncSession.get", new=MagicMock( - side_effect=WebSocketError(None, None), + side_effect=HTTPError(), # could happen but Async PRAW should handle it ), ) @@ -1722,8 +1742,10 @@ async def test_submit_image__without_websockets(self, image_path, reddit): assert submission is None @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock(return_value=WebsocketMock("l6ey7j")), + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("l6ey7j")), + ), ) # update with cassette async def test_submit_image_chat(self, image_path, reddit): reddit.read_only = False @@ -1804,9 +1826,11 @@ async def test_submit_poll__live_chat(self, reddit): assert submission.discussion_type == "CHAT" @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock( - return_value=WebsocketMock("183vns9", "183vnt2"), # update with cassette + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("183vns9", "183vnt2") + ), ), ) async def test_submit_video(self, image_path, reddit): @@ -1823,8 +1847,10 @@ async def test_submit_video(self, image_path, reddit): @pytest.mark.cassette_name("TestSubreddit.test_submit_video") @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock(return_value=WebsocketMock()), + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock()), + ), ) async def test_submit_video__bad_websocket(self, image_path, reddit): reddit.read_only = False @@ -1835,8 +1861,10 @@ async def test_submit_video__bad_websocket(self, image_path, reddit): await subreddit.submit_video("Test Title", video) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock(return_value=WebsocketMock("l6g771")), + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("l6g771")), + ), ) # update with cassette async def test_submit_video__flair(self, image_path, reddit): flair_id = "6fc213da-cae7-11ea-9274-0e2407099e45" @@ -1852,9 +1880,11 @@ async def test_submit_video__flair(self, image_path, reddit): assert submission.link_flair_text == flair_text @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock( - return_value=WebsocketMock("l6gvvi", "l6gvx7"), # update with cassette + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("l6gvvi", "l6gvx7") + ), ), ) async def test_submit_video__thumbnail(self, image_path, reddit): @@ -1875,7 +1905,7 @@ async def test_submit_video__thumbnail(self, image_path, reddit): assert submission.title == "Test Title" @mock.patch( - "aiohttp.client.ClientSession.ws_connect", + "niquests.AsyncSession.get", new=MagicMock(side_effect=BlockingIOError), ) # happens with timeout=0 @pytest.mark.cassette_name("TestSubreddit.test_submit_video") @@ -1887,7 +1917,7 @@ async def test_submit_video__timeout_1(self, image_path, reddit): await subreddit.submit_video("Test Title", video) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", + "niquests.AsyncSession.get", new=MagicMock( side_effect=socket.timeout # happens with timeout=0.00001 @@ -1902,7 +1932,7 @@ async def test_submit_video__timeout_2(self, image_path, reddit): await subreddit.submit_video("Test Title", video) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", + "niquests.AsyncSession.get", new=MagicMock( side_effect=TimeoutError, # could happen, and Async PRAW should handle it @@ -1917,9 +1947,9 @@ async def test_submit_video__timeout_3(self, image_path, reddit): await subreddit.submit_video("Test Title", video) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", + "niquests.AsyncSession.get", new=MagicMock( - side_effect=WebSocketError(None, None), + side_effect=HTTPError(), # could happen, and Async PRAW should handle it ), ) @@ -1932,9 +1962,11 @@ async def test_submit_video__timeout_4(self, image_path, reddit): await subreddit.submit_video("Test Title", video) @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock( - return_value=WebsocketMock("l6gtwa", "l6gty1"), # update with cassette + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("l6gtwa", "l6gty1") + ), ), ) async def test_submit_video__videogif(self, image_path, reddit): @@ -1960,8 +1992,10 @@ async def test_submit_video__without_websockets(self, image_path, reddit): assert submission is None @mock.patch( - "aiohttp.client.ClientSession.ws_connect", - new=MagicMock(return_value=WebsocketMock("l6gocy")), + "niquests.AsyncSession.get", + new=AsyncMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("l6gocy")), + ), ) # update with cassette async def test_submit_video_chat(self, image_path, reddit): reddit.read_only = False diff --git a/tests/integration/models/reddit/test_wikipage.py b/tests/integration/models/reddit/test_wikipage.py index ef1d6789..9d8975d7 100644 --- a/tests/integration/models/reddit/test_wikipage.py +++ b/tests/integration/models/reddit/test_wikipage.py @@ -1,7 +1,7 @@ from base64 import urlsafe_b64encode import pytest -from asyncprawcore import Forbidden, NotFound +from prawcore import Forbidden, NotFound from asyncpraw.exceptions import RedditAPIException from asyncpraw.models import Redditor, WikiPage @@ -53,7 +53,7 @@ async def test_revert_css_fail(self, reddit): with pytest.raises(Forbidden) as exc: revision = await page.revision(revision_id) await revision.mod.revert() - assert await exc.value.response.json() == { + assert exc.value.response.json() == { "reason": "INVALID_CSS", "message": "Forbidden", "explanation": "%(css_error)s", diff --git a/tests/integration/models/test_auth.py b/tests/integration/models/test_auth.py index 0810a2b5..0c52e208 100644 --- a/tests/integration/models/test_auth.py +++ b/tests/integration/models/test_auth.py @@ -1,7 +1,7 @@ """Test asyncpraw.models.auth.""" import pytest -from asyncprawcore import InvalidToken +from prawcore import InvalidToken from asyncpraw import Reddit diff --git a/tests/integration/models/test_inbox.py b/tests/integration/models/test_inbox.py index 38ef80b4..767e81f2 100644 --- a/tests/integration/models/test_inbox.py +++ b/tests/integration/models/test_inbox.py @@ -1,7 +1,7 @@ """Test asyncpraw.models.inbox.""" import pytest -from asyncprawcore import Forbidden +from prawcore import Forbidden from asyncpraw.models import Comment, Message, Redditor, Subreddit diff --git a/tests/integration/models/test_user.py b/tests/integration/models/test_user.py index 63f347a1..879bc4f6 100644 --- a/tests/integration/models/test_user.py +++ b/tests/integration/models/test_user.py @@ -1,6 +1,6 @@ """Test asyncpraw.models.user.""" -import asyncprawcore.exceptions +import prawcore.exceptions import pytest from asyncpraw.exceptions import RedditAPIException @@ -112,7 +112,7 @@ async def test_pin__comment(self, reddit): async def test_pin__deleted_submission(self, reddit): reddit.read_only = False - with pytest.raises(asyncprawcore.exceptions.BadRequest): + with pytest.raises(prawcore.exceptions.BadRequest): await reddit.user.pin(Submission(reddit, "rmhl6m")) async def test_pin__empty_slot(self, reddit): @@ -180,7 +180,7 @@ async def test_pin__remove_num(self, reddit): async def test_pin__removed_submission(self, reddit): reddit.read_only = False - with pytest.raises(asyncprawcore.exceptions.BadRequest): + with pytest.raises(prawcore.exceptions.BadRequest): await reddit.user.pin(Submission(reddit, "rmi7ab")) async def test_pin__replace_slot(self, reddit): diff --git a/tests/integration/test_reddit.py b/tests/integration/test_reddit.py index 7a419918..f4954c79 100644 --- a/tests/integration/test_reddit.py +++ b/tests/integration/test_reddit.py @@ -3,7 +3,7 @@ from base64 import urlsafe_b64encode import pytest -from asyncprawcore.exceptions import BadRequest, ServerError +from prawcore.exceptions import BadRequest, ServerError from asyncpraw.exceptions import RedditAPIException from asyncpraw.models import LiveThread diff --git a/tests/unit/models/reddit/test_subreddit.py b/tests/unit/models/reddit/test_subreddit.py index 63ea879d..ffb63994 100644 --- a/tests/unit/models/reddit/test_subreddit.py +++ b/tests/unit/models/reddit/test_subreddit.py @@ -1,6 +1,6 @@ -import sys +import json -import aiohttp +import niquests import pytest from unittest import mock @@ -65,16 +65,17 @@ def test_hash(self, reddit): return_value=("fake_media_url", "fake_websocket_url"), ), ) - @mock.patch("aiohttp.client.ClientSession.ws_connect") + @mock.patch("niquests.AsyncSession.get") async def test_invalid_media(self, connection_mock, reddit): - reddit._core._requestor._http = aiohttp.ClientSession() - recv_mock = MagicMock() - recv_mock.receive_json = AsyncMock( - return_value={"payload": {}, "type": "failed"} + reddit._core._requestor._http = niquests.AsyncSession() + connection_mock.return_value = AsyncMock( + status_code=101, + extension=MagicMock( + next_payload=AsyncMock( + return_value=json.dumps({"payload": {}, "type": "failed"}) + ) + ), ) - context_manager = MagicMock() - context_manager.__aenter__.return_value = recv_mock - connection_mock.return_value = context_manager with pytest.raises(MediaPostFailed): await Subreddit(reddit, display_name="test").submit_image( @@ -82,7 +83,7 @@ async def test_invalid_media(self, connection_mock, reddit): ) await reddit._core._requestor._http.close() - @mock.patch("aiohttp.client.ClientSession.ws_connect", new=AsyncMock()) + @mock.patch("niquests.AsyncSession.get", new=AsyncMock()) @mock.patch( "asyncpraw.Reddit.post", new=AsyncMock( @@ -94,13 +95,13 @@ async def test_invalid_media(self, connection_mock, reddit): ) @mock.patch("asyncpraw.models.Subreddit._read_and_post_media") async def test_media_upload_500(self, mock_method, reddit): - from aiohttp.http_exceptions import HttpProcessingError - from asyncprawcore.exceptions import ServerError + from prawcore.exceptions import ServerError + from niquests.exceptions import HTTPError response = MagicMock() - response.status = 201 + response.status_code = 201 response.raise_for_status = MagicMock( - side_effect=HttpProcessingError(code=500, message="") + side_effect=HTTPError(f"Server Error", response=response) ) mock_method.return_value = response with pytest.raises(ServerError): diff --git a/tests/unit/models/test_auth.py b/tests/unit/models/test_auth.py index b1d15a52..4437992a 100644 --- a/tests/unit/models/test_auth.py +++ b/tests/unit/models/test_auth.py @@ -1,5 +1,7 @@ """Test asyncpraw.models.auth.""" +from urllib.parse import urlencode, quote, quote_plus + import pytest from asyncpraw import Reddit @@ -78,7 +80,7 @@ async def test_url__installed_app(self, installed_app): url = installed_app.auth.url(scopes=["dummy scope"], state="dummy state") assert "client_id=dummy+client" in url assert "duration=permanent" in url - assert "redirect_uri=https://dummy.tld/" in url + assert ("redirect_uri=" + quote_plus("https://dummy.tld/")) in url assert "response_type=code" in url assert "scope=dummy+scope" in url assert "state=dummy+state" in url @@ -89,7 +91,7 @@ async def test_url__installed_app__implicit(self, installed_app): ) assert "client_id=dummy+client" in url assert "duration=temporary" in url - assert "redirect_uri=https://dummy.tld/" in url + assert ("redirect_uri=" + quote_plus("https://dummy.tld/")) in url assert "response_type=token" in url assert "scope=dummy+scope" in url assert "state=dummy+state" in url @@ -98,7 +100,7 @@ def test_url__web_app(self, web_app): url = web_app.auth.url(scopes=["dummy scope"], state="dummy state") assert "client_id=dummy+client" in url assert "secret" not in url - assert "redirect_uri=https://dummy.tld/" in url + assert ("redirect_uri=" + quote_plus("https://dummy.tld/")) in url assert "response_type=code" in url assert "scope=dummy+scope" in url assert "state=dummy+state" in url diff --git a/tests/unit/test_reddit.py b/tests/unit/test_reddit.py index 7e5aae9f..b17e7259 100644 --- a/tests/unit/test_reddit.py +++ b/tests/unit/test_reddit.py @@ -3,11 +3,11 @@ import types import pytest -from asyncprawcore import Requestor -from asyncprawcore.exceptions import BadRequest +from prawcore import AsyncRequestor as Requestor +from prawcore.exceptions import BadRequest from unittest import mock -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, PropertyMock from asyncpraw import Reddit, __version__ from asyncpraw.config import Config @@ -46,11 +46,8 @@ async def test_check_for_updates_update_checker_missing(self, mock_update_check) assert not mock_update_check.called async def test_close_session(self): - temp_reddit = Reddit(**self.REQUIRED_DUMMY_SETTINGS) - assert not temp_reddit.requestor._http.closed - async with temp_reddit as reddit: + async with Reddit(**self.REQUIRED_DUMMY_SETTINGS) as reddit: pass - assert reddit.requestor._http.closed and temp_reddit.requestor._http.closed def test_comment(self, reddit): assert Comment(reddit, id="cklfmye").id == "cklfmye" @@ -71,8 +68,6 @@ def test_conflicting_settings(self): async def test_context_manager(self): async with Reddit(**self.REQUIRED_DUMMY_SETTINGS) as reddit: assert not reddit._validate_on_submit - assert not reddit.requestor._http.closed - assert reddit.requestor._http.closed def test_info__invalid_param(self, reddit): with pytest.raises(TypeError) as excinfo: @@ -455,13 +450,15 @@ def test_reddit__site_name_no_section(self): Reddit("bad_site_name") assert "asyncpraw.readthedocs.io" in excinfo.value.message - @mock.patch("asyncprawcore.sessions.Session") + @mock.patch("prawcore._async.sessions.AsyncSession") async def test_request__badrequest_with_no_json_body(self, mock_session): - response = MagicMock(status=400, text=AsyncMock(return_value="")) + response = MagicMock(status_code=400, text="") response.json.side_effect = ValueError - mock_session.return_value.request = MagicMock( - side_effect=BadRequest(response=response) - ) + + async def fake_return(*args, **kwargs): + raise BadRequest(response=response) + + mock_session.return_value.request = fake_return async with Reddit( client_id="dummy", client_secret="dummy", user_agent="dummy" diff --git a/tests/utils.py b/tests/utils.py index a2523813..3ec2ce35 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,13 +1,10 @@ """Pytest utils for integration tests.""" import json -import os -from datetime import datetime -from typing import Dict import pytest -from vcr.persisters.filesystem import FilesystemPersister -from vcr.serialize import deserialize, serialize + +from betamax.serializers import JSONSerializer from tests.conftest import placeholders as _placeholders @@ -63,105 +60,8 @@ def filter_access_token(response): return response -class CustomPersister(FilesystemPersister): - """Custom persister to handle placeholders.""" - - additional_placeholders = {} - - @classmethod - def add_additional_placeholders(cls, placeholders: Dict[str, str]): - """Add additional placeholders.""" - cls.additional_placeholders.update(placeholders) - - @classmethod - def clear_additional_placeholders(cls): - """Clear additional placeholders.""" - cls.additional_placeholders = {} - - @classmethod - def load_cassette(cls, cassette_path, serializer): - """Load cassette.""" - try: - with open(cassette_path) as f: - cassette_content = f.read() - except OSError: - raise ValueError("Cassette not found.") - for replacement, value in [ - (v, f"<{k.upper()}>") - for k, v in {**cls.additional_placeholders, **_placeholders}.items() - ]: - cassette_content = cassette_content.replace(value, replacement) - cassette = deserialize(cassette_content, serializer) - return cassette - - @classmethod - def save_cassette(cls, cassette_path, cassette_dict, serializer): - """Save cassette.""" - data = serialize(cassette_dict, serializer) - for replacement, value in [ - (f"<{k.upper()}>", v) - for k, v in {**cls.additional_placeholders, **_placeholders}.items() - ]: - data = data.replace(value, replacement) - dirname, filename = os.path.split(cassette_path) - if dirname and not os.path.exists(dirname): - os.makedirs(dirname) - with open(cassette_path, "w") as f: - f.write(data) - - -class CustomSerializer: - """Custom serializer to save in a prettified json format.""" - - @staticmethod - def _serialize_file(file_name): - with open(file_name, "rb") as f: - return f.read().decode("utf-8", "replace") - - @staticmethod - def deserialize(cassette_string): - return json.loads(cassette_string) - - @classmethod - def _serialize_dict(cls, data: dict): - """This is to filter out buffered readers.""" - new_dict = {} - for key, value in data.items(): - if key == "file": - new_dict[key] = cls._serialize_file(value.name) - elif isinstance(value, dict): - new_dict[key] = cls._serialize_dict(value) - elif isinstance(value, list): - new_dict[key] = cls._serialize_list(value) - else: - new_dict[key] = value - return new_dict - - @classmethod - def _serialize_list(cls, data: list): - new_list = [] - for item in data: - if isinstance(item, dict): - new_list.append(cls._serialize_dict(item)) - elif isinstance(item, list): - new_list.append(cls._serialize_list(item)) - elif isinstance(item, tuple): - if item[0] == "file": - item = (item[0], cls._serialize_file(item[1].name)) - new_list.append(item) - else: - new_list.append(item) - return new_list +class PrettyJSONSerializer(JSONSerializer): + name = "prettyjson" - @classmethod - def serialize(cls, cassette_dict): - """Serialize cassette.""" - timestamp = datetime.utcnow().isoformat() - try: - i = timestamp.rindex(".") - except ValueError: - pass - else: - timestamp = timestamp[:i] - cassette_dict["recorded_at"] = timestamp - return f"{json.dumps(cls._serialize_dict(cassette_dict), sort_keys=True, indent=2)}\n" + def serialize(self, cassette_data): + return f"{json.dumps(cassette_data, sort_keys=True, indent=2)}\n"