Skip to content

Commit

Permalink
Test refresh access token
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Aug 8, 2024
1 parent d9720b3 commit 865a662
Showing 1 changed file with 76 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import datetime
import json
import re
from unittest import mock

import pytest
import responses
Expand All @@ -8,7 +11,9 @@

@responses.activate
def test_trigger_connection() -> None:
ab_resource = AirbyteCloudResource(client_id="some_client_id", client_secret="some_client_secret", poll_interval=0)
ab_resource = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret", poll_interval=0
)
responses.add(
responses.POST,
f"{ab_resource.api_base_url}/applications/token",
Expand Down Expand Up @@ -48,7 +53,9 @@ def test_trigger_connection_fail() -> None:
[AirbyteState.SUCCEEDED, AirbyteState.CANCELLED, AirbyteState.ERROR, "unrecognized"],
)
def test_sync_and_poll(state) -> None:
ab_resource = AirbyteCloudResource(client_id="some_client_id", client_secret="some_client_secret", poll_interval=0)
ab_resource = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret", poll_interval=0
)

responses.add(
responses.POST,
Expand Down Expand Up @@ -98,7 +105,9 @@ def test_sync_and_poll(state) -> None:

@responses.activate
def test_start_sync_bad_out_fail() -> None:
ab_resource = AirbyteCloudResource(client_id="some_client_id", client_secret="some_client_secret", poll_interval=0)
ab_resource = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret", poll_interval=0
)

responses.add(
responses.POST,
Expand All @@ -116,3 +125,67 @@ def test_start_sync_bad_out_fail() -> None:
match=re.escape("Max retries (3) exceeded with url: https://api.airbyte.com/v1/jobs."),
):
ab_resource.start_sync("some_connection")


@responses.activate
def test_refresh_access_token() -> None:
ab_resource = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret", poll_interval=0
)
responses.add(
responses.POST,
f"{ab_resource.api_base_url}/applications/token",
json={"access_token": "some_access_token"},
)
responses.add(
method=responses.POST,
url=ab_resource.api_base_url + "/jobs",
json={"jobId": 1, "status": "pending", "jobType": "sync"},
status=200,
)

test_time_first_call = datetime.datetime(2024, 1, 1, 0, 0, 0)
test_time_before_expiration = datetime.datetime(2024, 1, 1, 0, 2, 0)
test_time_after_expiration = datetime.datetime(2024, 1, 1, 0, 3, 0)
with mock.patch("dagster_airbyte.resources.datetime", wraps=datetime.datetime) as dt:
# Test first call, must get the access token before calling the jobs api
dt.now.return_value = test_time_first_call
ab_resource.start_sync("some_connection")

assert len(responses.calls) == 2
access_token_call = responses.calls[0]
jobs_api_call = responses.calls[1]

assert "Authorization" not in access_token_call.request.headers
access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8"))
assert access_token_call_body["client_id"] == "some_client_id"
assert access_token_call_body["client_secret"] == "some_client_secret"
assert jobs_api_call.request.headers["Authorization"] == "Bearer some_access_token"

responses.calls.reset()

# Test second call, occurs before the access token expiration, only the jobs api is called
dt.now.return_value = test_time_before_expiration
ab_resource.start_sync("some_connection")

assert len(responses.calls) == 1
jobs_api_call = responses.calls[0]

assert jobs_api_call.request.headers["Authorization"] == "Bearer some_access_token"

responses.calls.reset()

# Test third call, occurs after the token expiration,
# must refresh the access token before calling the jobs api
dt.now.return_value = test_time_after_expiration
ab_resource.start_sync("some_connection")

assert len(responses.calls) == 2
access_token_call = responses.calls[0]
jobs_api_call = responses.calls[1]

assert "Authorization" not in access_token_call.request.headers
access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8"))
assert access_token_call_body["client_id"] == "some_client_id"
assert access_token_call_body["client_secret"] == "some_client_secret"
assert jobs_api_call.request.headers["Authorization"] == "Bearer some_access_token"

0 comments on commit 865a662

Please sign in to comment.