Commit 21f859d: Add tests
janbaykara committed Mar 13, 2024
1 parent cd28158
Showing 4 changed files with 174 additions and 75 deletions.
66 changes: 26 additions & 40 deletions hub/models.py
@@ -35,7 +35,7 @@

from strawberry.dataloader import DataLoader
from pyairtable import Api as AirtableAPI, Base as AirtableBase, Table as AirtableTable
-from utils.asyncio import async_query
+from utils.asyncio import async_queryset

User = get_user_model()

@@ -991,39 +991,22 @@ def get_import_data(self):
)

    def get_import_dataframe(self):
-        return pd.DataFrame(list(self.get_import_data()))
+        enrichment_data = self.get_import_data()
+        json_list = [d.json for d in enrichment_data]
+        enrichment_df = pd.DataFrame.from_records(json_list)
+        return enrichment_df
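
For context: each imported `GenericData` row stores its source record as a dict in its `json` field, so `from_records` yields one DataFrame column per Airtable field. A minimal sketch of the resulting shape (sample values, not the commit's data):

```python
import pandas as pd

# Each element stands in for one GenericData.json payload (assumed dict-valued).
json_list = [
    {"council district": "County Durham",
     "mayoral region": "North East Mayoral Combined Authority"},
    {"council district": "Northumberland",
     "mayoral region": "North East Mayoral Combined Authority"},
]
enrichment_df = pd.DataFrame.from_records(json_list)

# One column per JSON key, so geography lookups become simple .loc filters:
enrichment_df.loc[
    enrichment_df["council district"] == "Northumberland", "mayoral region"
]
```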

    def data_loader_factory(self):
        async def fetch_enrichment_data(keys: List[self.EnrichmentLookup]) -> list[str]:
            return_data = []
-            # geography_values_from_keys = list(set([
-            #     get(key['postcode_data'], self.geography_column_type)
-            #     for key
-            #     in keys
-            # ]))
-            # Using async_query because Django's own ORM breaks down in this DataLoader
-            enrichment_data = await async_query(
-                GenericData.objects.all()
-                # Filtering GenericData produces this error:
-                # psycopg.errors.InvalidTextRepresentation: invalid input syntax for type json
-                # LINE 1: ...ericdata"."json" -> 'council district') IN (Jsonb('Gateshead...
-                # GenericData.objects.filter(**{
-                #     f"json__'{self.geography_column}'__in": geography_values_from_keys
-                # }).values('json')
-            )
-            json_list = [
-                json.loads(d.json) if d.json is not None and d.json != "" else {}
-                for d in enrichment_data
-            ]
-            enrichment_df = pd.DataFrame.from_records(json_list)
-            for index, key in enumerate(keys):
+            enrichment_df = await sync_to_async(self.get_import_dataframe)()
+            for key in keys:
                try:
+                    # TODO: Use pandas to join dataframes instead
+                    # query enrichment_df by matching the postcode_data to the geography_column
                    relevant_member_geography = get(key['postcode_data'], self.geography_column_type, "")
                    if relevant_member_geography == "" or relevant_member_geography is None:
                        return_data.append(None)
                    else:
-                        # TODO: Use pandas to join dataframes instead
                        enrichment_value = enrichment_df.loc[
                            enrichment_df[self.geography_column] == relevant_member_geography,
                            key['source_path']
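
The TODO above suggests replacing the per-key `.loc` lookups with a single pandas join. A hedged sketch of what that could look like, reusing the surrounding names (`keys`, `enrichment_df`, `get`, `self.geography_column`); this illustrates the idea, it is not the committed implementation:

```python
# Hypothetical join-based variant of the loop above.
keys_df = pd.DataFrame.from_records([
    {
        "member_geography": get(key['postcode_data'], self.geography_column_type, ""),
        "source_path": key['source_path'],
    }
    for key in keys
])
joined = keys_df.merge(
    enrichment_df,
    left_on="member_geography",
    right_on=self.geography_column,
    how="left",
)
# Each joined row now carries its member's enrichment columns; pick out the
# requested column per row:
return_data = [row.get(row["source_path"]) for row in joined.to_dict("records")]
```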
@@ -1078,16 +1061,19 @@ async def map_one(self, member: Union[str, dict], loaders: Loaders) -> MappedMember:
            if postcode_data is not None:
                update_fields[destination_column] = get(postcode_data, source_path)
            else:
-                source_loader = loaders['source_loaders'].get(source, None)
-                if source_loader is not None:
-                    update_fields[destination_column] = await source_loader.load(
-                        self.EnrichmentLookup(
-                            member_id=self.get_record_id(member),
-                            postcode_data=postcode_data,
-                            source_id=source,
-                            source_path=source_path
-                        )
-                    )
+                try:
+                    source_loader = loaders['source_loaders'].get(source, None)
+                    if source_loader is not None:
+                        update_fields[destination_column] = await source_loader.load(
+                            self.EnrichmentLookup(
+                                member_id=self.get_record_id(member),
+                                postcode_data=postcode_data,
+                                source_id=source,
+                                source_path=source_path
+                            )
+                        )
+                except Exception:
+                    # A failing loader shouldn't abort the whole mapping; skip this column.
+                    continue
        # Return the member and config data
        return self.MappedMember(
            member=member,
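
For context on the `source_loader.load(...)` call above: each entry in `loaders['source_loaders']` is a strawberry `DataLoader`, so loads issued while mapping many members concurrently coalesce into one `fetch_enrichment_data` batch. A self-contained sketch of that batching behaviour (toy load function, not the commit's code):

```python
import asyncio
from strawberry.dataloader import DataLoader

async def load_values(keys: list[str]) -> list[str]:
    # Invoked once per batch, not once per key.
    print(f"one batch of {len(keys)} keys")
    return [key.upper() for key in keys]

async def main():
    loader = DataLoader(load_fn=load_values)
    # Two concurrent .load() calls produce a single load_values() batch.
    results = await asyncio.gather(loader.load("a"), loader.load("b"))
    print(results)  # ['A', 'B']

asyncio.run(main())
```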
@@ -1118,21 +1104,21 @@ async def refresh_one(self, member_id: Union[str, any]):
            return
        loaders = await self.get_loaders()
        mapped_record = await self.map_one(member_id, loaders)
-        await self.update_one(mapped_record=mapped_record)
+        return await self.update_one(mapped_record=mapped_record)

    async def refresh_many(self, member_ids: list[Union[str, any]]):
        if len(self.get_update_mapping()) == 0:
            return
        loaders = await self.get_loaders()
        mapped_records = await self.map_many(member_ids, loaders)
-        await self.update_many(mapped_records=mapped_records)
+        return await self.update_many(mapped_records=mapped_records)

    async def refresh_all(self):
        if len(self.get_update_mapping()) == 0:
            return
        loaders = await self.get_loaders()
        mapped_records = await self.map_all(loaders)
-        await self.update_all(mapped_records=mapped_records)
+        return await self.update_all(mapped_records=mapped_records)

# UI

@@ -1325,10 +1311,10 @@ def get_record_dict(self, record):
        return record['fields']

    async def update_one(self, mapped_record):
-        self.table.update(mapped_record['member']['id'], mapped_record['update_fields'])
+        return self.table.update(mapped_record['member']['id'], mapped_record['update_fields'])

    async def update_many(self, mapped_records):
-        self.table.batch_update([
+        return self.table.batch_update([
            {
                "id": mapped_record['member']['id'],
                "fields": mapped_record['update_fields']
@@ -1337,7 +1323,7 @@ async def update_many(self, mapped_records):
        ])

    async def update_all(self, mapped_records):
-        self.table.batch_update([
+        return self.table.batch_update([
            {
                "id": mapped_record['member']['id'],
                "fields": mapped_record['update_fields']
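
Worth noting on the `return` changes above: pyairtable's `Table.update` and `Table.batch_update` return the updated record dict(s), so passing them through lets `refresh_one`/`refresh_many`/`refresh_all` surface what was written. A hedged usage sketch with placeholder IDs and field values:

```python
# Illustrative only: the record ID and field values are placeholders.
updated = source.table.batch_update([
    {"id": "recXXXXXXXXXXXXXX", "fields": {"constituency": "Falkirk"}},
])
print(updated[0]["fields"]["constituency"])  # -> "Falkirk"
```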
145 changes: 119 additions & 26 deletions hub/tests/test_sources.py
@@ -1,49 +1,75 @@
from django.test import TestCase
from django.conf import settings
from datetime import datetime
-from asgiref.sync import async_to_sync
+from asgiref.sync import async_to_sync, sync_to_async

-from hub.models import AirtableSource
+from hub.models import AirtableSource, Organisation


class TestAirtableSource(TestCase):
    ### Test prep
    source: AirtableSource

    def setUp(self) -> None:
-        self.records_to_delete = []
+        self.records_to_delete: list[tuple[str, AirtableSource]] = []
+
+        self.organisation = Organisation.objects.create(
+            name="Test Organisation",
+            slug="test-organisation"
+        )

+        self.custom_data_layer: AirtableSource = AirtableSource.objects.create(
+            name="Mayoral regions custom data layer",
+            data_type=AirtableSource.DataSourceType.OTHER,
+            organisation=self.organisation,
+            base_id=settings.TEST_AIRTABLE_CUSTOMDATALAYER_BASE_ID,
+            table_id=settings.TEST_AIRTABLE_CUSTOMDATALAYER_TABLE_NAME,
+            api_key=settings.TEST_AIRTABLE_CUSTOMDATALAYER_API_KEY,
+            geography_column="council district",
+            geography_column_type=AirtableSource.PostcodesIOGeographyTypes.COUNCIL,
+        )
+
        self.source: AirtableSource = AirtableSource.objects.create(
-            name="Test Airtable Source",
-            base_id=settings.TEST_AIRTABLE_BASE_ID,
-            table_id=settings.TEST_AIRTABLE_TABLE_NAME,
-            api_key=settings.TEST_AIRTABLE_API_KEY,
+            name="My test Airtable member list",
+            data_type=AirtableSource.DataSourceType.MEMBER,
+            organisation=self.organisation,
+            base_id=settings.TEST_AIRTABLE_MEMBERLIST_BASE_ID,
+            table_id=settings.TEST_AIRTABLE_MEMBERLIST_TABLE_NAME,
+            api_key=settings.TEST_AIRTABLE_MEMBERLIST_API_KEY,
            geography_column="Postcode",
            geography_column_type=AirtableSource.PostcodesIOGeographyTypes.POSTCODE,
            auto_update_enabled=True,
+            update_mapping=[
+                {
+                    "source": "postcodes.io",
+                    "source_path": "parliamentary_constituency_2025",
+                    "destination_column": "constituency"
+                },
+                {
+                    "source": str(self.custom_data_layer.id),
+                    "source_path": "mayoral region",
+                    "destination_column": "mayoral region"
+                }
+            ]
        )

        self.source.teardown_webhooks()

    def tearDown(self) -> None:
-        for record_id in self.records_to_delete:
-            self.source.table.delete(record_id)
+        for record_id, source in self.records_to_delete:
+            source.table.delete(record_id)
        self.source.teardown_webhooks()
        return super().tearDown()

-    def create_test_record(self, record):
-        record = self.source.table.create(record)
-        self.records_to_delete.append(record['id'])
+    def create_test_record(self, record, source=None):
+        source = source or self.source
+        record = source.table.create(record)
+        self.records_to_delete.append((record['id'], source))
        return record

-    def create_many_test_records(self, records):
-        records = self.source.table.batch_create(records)
-        self.records_to_delete += [record['id'] for record in records]
+    def create_many_test_records(self, records, source=None):
+        source = source or self.source
+        records = source.table.batch_create(records)
+        self.records_to_delete += [(record['id'], source) for record in records]
        return records

    ### Tests begin
@@ -57,21 +83,61 @@ async def test_airtable_webhooks(self):
        self.source.setup_webhooks()
        self.assertTrue(self.source.webhook_healthcheck())

+    async def test_import_async(self):
+        self.create_many_test_records([
+            {
+                "council district": "County Durham",
+                "mayoral region": "North East Mayoral Combined Authority"
+            },
+            {
+                "council district": "Northumberland",
+                "mayoral region": "North East Mayoral Combined Authority"
+            }
+        ], source=self.custom_data_layer)
+        await sync_to_async(self.custom_data_layer.import_all)()
+        enrichment_df = await sync_to_async(self.custom_data_layer.get_import_dataframe)()
+        self.assertGreaterEqual(len(enrichment_df.index), 2)

    def test_import_all(self):
        # Confirm the database is empty
-        original_count = self.source.get_import_data().count()
-        assert original_count == 0
+        original_count = self.custom_data_layer.get_import_data().count()
+        self.assertEqual(original_count, 0)
        # Add some test data
        self.create_many_test_records([
-            { "Postcode": "import_test_1" },
-            { "Postcode": "import_test_2" }
-        ])
-        assert len(list(async_to_sync(self.source.fetch_all)())) >= 2
+            {
+                "council district": "County Durham",
+                "mayoral region": "North East Mayoral Combined Authority"
+            },
+            {
+                "council district": "Northumberland",
+                "mayoral region": "North East Mayoral Combined Authority"
+            }
+        ], source=self.custom_data_layer)
+        self.assertGreaterEqual(len(list(async_to_sync(self.custom_data_layer.fetch_all)())), 2)
        # Check that the import is storing it all
-        fetch_count = len(list(async_to_sync(self.source.fetch_all)()))
-        self.source.import_all()
-        import_count = self.source.get_import_data().count()
-        assert import_count == fetch_count
+        fetch_count = len(list(async_to_sync(self.custom_data_layer.fetch_all)()))
+        self.custom_data_layer.import_all()
+        import_data = self.custom_data_layer.get_import_data()
+        import_count = len(import_data)
+        self.assertEqual(import_count, fetch_count)
+        # assert that 'council district' and 'mayoral region' keys are in the JSON object
+        self.assertIn("council district", import_data[0].json)
+        self.assertIn("mayoral region", import_data[0].json)
+        self.assertIn(import_data[0].json['council district'], [
+            "Newcastle upon Tyne",
+            "North Tyneside",
+            "South Tyneside",
+            "Gateshead",
+            "County Durham",
+            "Sunderland",
+            "Northumberland",
+        ])
+        self.assertIn(import_data[0].json['mayoral region'], ["North East Mayoral Combined Authority"])
+        df = self.custom_data_layer.get_import_dataframe()
+        # assert len(df.index) == import_count
+        self.assertIn("council district", list(df.columns.values))
+        self.assertIn("mayoral region", list(df.columns.values))
+        self.assertEqual(len(df.index), import_count)

    async def test_airtable_fetch_one(self):
        record = self.create_test_record({ "Postcode": "EH99 1SP" })
@@ -109,6 +175,33 @@ async def test_airtable_refresh_one(self):
"Edinburgh East and Musselburgh"
)

+    def test_pivot_table(self):
+        '''
+        Test that self.source can be updated with data from self.custom_data_layer,
+        i.e. the pivot table functionality that brings custom campaign data
+        back into the CRM, keyed on geography.
+        '''
+        # Add some test data
+        self.create_many_test_records([
+            {
+                "council district": "County Durham",
+                "mayoral region": "North East Mayoral Combined Authority"
+            },
+            {
+                "council district": "Northumberland",
+                "mayoral region": "North East Mayoral Combined Authority"
+            }
+        ], source=self.custom_data_layer)
+        # Check that the import is storing it all
+        self.custom_data_layer.import_all()
+        # Add a test record
+        record = self.create_test_record({ "Postcode": "NE12 6DD" })
+        mapped_member = async_to_sync(self.source.map_one)(record, loaders=async_to_sync(self.source.get_loaders)())
+        self.assertEqual(
+            mapped_member['update_fields']['mayoral region'],
+            "North East Mayoral Combined Authority"
+        )
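
The pivot-table path this test exercises chains two lookups: the member's postcode resolves to a council district via postcodes.io, and that district keys into the imported custom-data rows. A hedged sketch of the chain (the helper name and exact postcodes.io field are illustrative, not the app's API):

```python
# Illustrative only.
postcode_data = lookup_postcode("NE12 6DD")   # hypothetical postcodes.io helper
district = postcode_data["admin_district"]    # e.g. "North Tyneside"
row = enrichment_df.loc[enrichment_df["council district"] == district]
update_fields["mayoral region"] = row["mayoral region"].iloc[0]
```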

    async def test_airtable_refresh_many(self):
        records = self.create_many_test_records([
            { "Postcode": "G11 5RD" },
18 changes: 12 additions & 6 deletions local_intelligence_hub/settings.py
@@ -32,9 +32,12 @@
    HIDE_DEBUG_TOOLBAR=(bool, False),
    GOOGLE_ANALYTICS=(str, ""),
    GOOGLE_SITE_VERIFICATION=(str, ""),
-    TEST_AIRTABLE_BASE_ID=(str, ""),
-    TEST_AIRTABLE_TABLE_NAME=(str, ""),
-    TEST_AIRTABLE_API_KEY=(str, ""),
+    TEST_AIRTABLE_MEMBERLIST_BASE_ID=(str, ""),
+    TEST_AIRTABLE_MEMBERLIST_TABLE_NAME=(str, ""),
+    TEST_AIRTABLE_MEMBERLIST_API_KEY=(str, ""),
+    TEST_AIRTABLE_CUSTOMDATALAYER_BASE_ID=(str, ""),
+    TEST_AIRTABLE_CUSTOMDATALAYER_TABLE_NAME=(str, ""),
+    TEST_AIRTABLE_CUSTOMDATALAYER_API_KEY=(str, ""),
    DJANGO_LOG_LEVEL=(str, "INFO"),
)
environ.Env.read_env(BASE_DIR / ".env")
@@ -55,9 +58,12 @@
MAPIT_API_KEY = env("MAPIT_API_KEY")
GOOGLE_ANALYTICS = env("GOOGLE_ANALYTICS")
GOOGLE_SITE_VERIFICATION = env("GOOGLE_SITE_VERIFICATION")
-TEST_AIRTABLE_BASE_ID=env("TEST_AIRTABLE_BASE_ID")
-TEST_AIRTABLE_TABLE_NAME=env("TEST_AIRTABLE_TABLE_NAME")
-TEST_AIRTABLE_API_KEY=env("TEST_AIRTABLE_API_KEY")
+TEST_AIRTABLE_MEMBERLIST_BASE_ID=env("TEST_AIRTABLE_MEMBERLIST_BASE_ID")
+TEST_AIRTABLE_MEMBERLIST_TABLE_NAME=env("TEST_AIRTABLE_MEMBERLIST_TABLE_NAME")
+TEST_AIRTABLE_MEMBERLIST_API_KEY=env("TEST_AIRTABLE_MEMBERLIST_API_KEY")
+TEST_AIRTABLE_CUSTOMDATALAYER_BASE_ID=env("TEST_AIRTABLE_CUSTOMDATALAYER_BASE_ID")
+TEST_AIRTABLE_CUSTOMDATALAYER_TABLE_NAME=env("TEST_AIRTABLE_CUSTOMDATALAYER_TABLE_NAME")
+TEST_AIRTABLE_CUSTOMDATALAYER_API_KEY=env("TEST_AIRTABLE_CUSTOMDATALAYER_API_KEY")
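
With the old `TEST_AIRTABLE_*` triple split into member-list and custom-data-layer credentials, a test `.env` now needs both sets; a sketch with placeholder values only:

```
TEST_AIRTABLE_MEMBERLIST_BASE_ID=appXXXXXXXXXXXXXX
TEST_AIRTABLE_MEMBERLIST_TABLE_NAME=tblXXXXXXXXXXXXXX
TEST_AIRTABLE_MEMBERLIST_API_KEY=patXXXXXXXXXXXXXX
TEST_AIRTABLE_CUSTOMDATALAYER_BASE_ID=appYYYYYYYYYYYYYY
TEST_AIRTABLE_CUSTOMDATALAYER_TABLE_NAME=tblYYYYYYYYYYYYYY
TEST_AIRTABLE_CUSTOMDATALAYER_API_KEY=patYYYYYYYYYYYYYY
```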

# make sure CSRF checking still works behind load balancers
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https")
20 changes: 17 additions & 3 deletions utils/asyncio.py
@@ -2,7 +2,20 @@
from django.db import connection
from django.db.models.query import QuerySet

-async def async_query(queryset: QuerySet):
+
+async def async_queryset(queryset: QuerySet, args: list[str] = []):
+    query = str(queryset.query)
+    for arg in args:
+        query = query.replace(str(arg), "%s")
+    args = [str(arg) for arg in args]
+    results = await async_query(
+        query,
+        args
+    )
+    return [queryset.model(*result) for result in results]
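
`async_queryset` parameterises by string-replacing each literal `str(arg)` in the compiled SQL with `%s`, then rebuilding model instances positionally from the raw rows. A hedged usage sketch (the model and filter value are placeholders), with the main caveats in comments:

```python
# Illustrative only: SomeModel and the filter value are placeholders.
qs = SomeModel.objects.filter(label="County Durham")
rows = await async_queryset(qs, args=["County Durham"])  # SomeModel instances

# Caveats: str(qs.query) renders params inline without proper quoting, so the
# replace-with-%s step assumes each arg's string form appears verbatim (and
# unambiguously) in the SQL; instance rebuilding assumes the SELECT covers all
# model fields in declaration order.
```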


+async def async_query(query: str, args: list[str] = []):
    # Find and quote a database table name for a Model with users.
    # table_name = connection.ops.quote_name(GenericData._meta.db_table)
    # Create a new async connection.
@@ -16,7 +29,8 @@ async def async_query(queryset: QuerySet):
    # Create a new async cursor and execute a query.
    async with aconnection.cursor() as cursor:
        await cursor.execute(
-            str(queryset.query)
+            query,
+            args
        )
        results = await cursor.fetchall()
-        return [queryset.model(*result) for result in results]
+        return results
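
`async_query` runs one parameterised statement on a fresh psycopg async connection built from Django's DB settings and returns the raw rows. A hedged direct-usage sketch (the table and JSON key are assumptions, not from the commit):

```python
# Illustrative only: table/column names are placeholders.
rows = await async_query(
    'SELECT "json" FROM "hub_genericdata" WHERE "json" ->> %s = %s',
    ["council district", "County Durham"],
)
# rows is a list of tuples straight from cursor.fetchall().
```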
