Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Aug 8, 2024
1 parent 4f28103 commit d9720b3
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,11 @@ def _refresh_access_token(self) -> None:
def _needs_refreshed_access_token(self) -> bool:
# The access token expire every 3 minutes in Airbyte Cloud.
# Refresh is needed after 2.5 minutes to avoid the "token expired" error message.
return not self._access_token_value or self._access_token_timestamp <= datetime.timestamp(
datetime.now() - timedelta(seconds=150)
return (
not self._access_token_value
or not self._access_token_timestamp
or self._access_token_timestamp
<= datetime.timestamp(datetime.now() - timedelta(seconds=150))
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ def test_assets_cloud() -> None:
)

with responses.RequestsMock() as rsps:
rsps.add(
rsps.POST,
f"{ab_url}/applications/token",
json={"access_token": "some_access_token"},
)
rsps.add(
rsps.POST,
f"{ab_url}/jobs",
Expand All @@ -238,11 +243,6 @@ def test_assets_cloud() -> None:
f"{ab_url}/jobs/1",
json={"jobId": 1, "status": "succeeded", "jobType": "sync"},
)
rsps.add(
rsps.POST,
f"{ab_url}/applications/token",
json={"access_token": "some_access_token"},
)

res = materialize_to_memory(
ab_assets,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@

@responses.activate
def test_trigger_connection() -> None:
ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0)
ab_resource = AirbyteCloudResource(client_id="some_client_id", client_secret="some_client_secret", poll_interval=0)
responses.add(
responses.POST,
f"{ab_resource.api_base_url}/applications/token",
json={"access_token": "some_access_token"},
)
responses.add(
method=responses.POST,
url=ab_resource.api_base_url + "/jobs",
Expand All @@ -19,8 +24,17 @@ def test_trigger_connection() -> None:
assert resp == {"job": {"id": 1, "status": "pending"}}


@responses.activate
def test_trigger_connection_fail() -> None:
ab_resource = AirbyteCloudResource(api_key="some_key")
ab_resource = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret"
)
responses.add(
responses.POST,
f"{ab_resource.api_base_url}/applications/token",
json={"access_token": "some_access_token"},
)

with pytest.raises(
Failure,
match=re.escape("Max retries (3) exceeded with url: https://api.airbyte.com/v1/jobs."),
Expand All @@ -34,7 +48,13 @@ def test_trigger_connection_fail() -> None:
[AirbyteState.SUCCEEDED, AirbyteState.CANCELLED, AirbyteState.ERROR, "unrecognized"],
)
def test_sync_and_poll(state) -> None:
ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0)
ab_resource = AirbyteCloudResource(client_id="some_client_id", client_secret="some_client_secret", poll_interval=0)

responses.add(
responses.POST,
f"{ab_resource.api_base_url}/applications/token",
json={"access_token": "some_access_token"},
)
responses.add(
method=responses.POST,
url=ab_resource.api_base_url + "/jobs",
Expand Down Expand Up @@ -78,8 +98,13 @@ def test_sync_and_poll(state) -> None:

@responses.activate
def test_start_sync_bad_out_fail() -> None:
ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0)
ab_resource = AirbyteCloudResource(client_id="some_client_id", client_secret="some_client_secret", poll_interval=0)

responses.add(
responses.POST,
f"{ab_resource.api_base_url}/applications/token",
json={"access_token": "some_access_token"},
)
responses.add(
method=responses.POST,
url=ab_resource.api_base_url + "/jobs",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ def downstream_asset(dagster_tags):


def test_load_from_instance_cloud() -> None:
airbyte_cloud_instance = AirbyteCloudResource(api_key="foo", poll_interval=0)
airbyte_cloud_instance = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret", poll_interval=0
)

with pytest.raises(
DagsterInvalidInvocationError,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from base64 import b64encode

import pytest
Expand Down Expand Up @@ -102,7 +103,9 @@ def airbyte_sync_job():


def test_airbyte_sync_op_cloud() -> None:
ab_resource = AirbyteCloudResource(api_key="some_key")
ab_resource = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret"
)
ab_url = ab_resource.api_base_url

@op
Expand All @@ -127,6 +130,11 @@ def airbyte_sync_job() -> None:
airbyte_sync_op(start_after=foo_op())

with responses.RequestsMock() as rsps:
rsps.add(
rsps.POST,
f"{ab_url}/applications/token",
json={"access_token": "some_access_token"},
)
rsps.add(
rsps.POST,
f"{ab_url}/jobs",
Expand All @@ -150,5 +158,14 @@ def airbyte_sync_job() -> None:
connection_details={},
)

for call in rsps.calls:
assert call.request.headers["Authorization"] == "Bearer some_key"
# The first call is to get the access token.
access_token_call = rsps.calls[0]
api_calls = rsps.calls[1:]

assert "Authorization" not in access_token_call.request.headers
access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8"))
assert access_token_call_body["client_id"] == "some_client_id"
assert access_token_call_body["client_secret"] == "some_client_secret"

for call in api_calls:
assert call.request.headers["Authorization"] == "Bearer some_access_token"

0 comments on commit d9720b3

Please sign in to comment.