diff --git a/hub/models.py b/hub/models.py
index 9de04500f..b74a9fdbe 100644
--- a/hub/models.py
+++ b/hub/models.py
@@ -35,6 +35,7 @@
 
 from strawberry.dataloader import DataLoader
 from pyairtable import Api as AirtableAPI, Base as AirtableBase, Table as AirtableTable
+from utils.asyncio import async_query
 
 User = get_user_model()
 
@@ -987,49 +988,79 @@ def get_import_data(self):
     def get_import_dataframe(self):
         return pd.DataFrame(list(self.get_import_data()))
 
-    async def get_loaders (self) -> Loaders:
-        def build_source_loader(source: ExternalDataSource) -> DataLoader:
-            async def fetch_enrichment_data(keys: List[self.EnrichmentLookup]) -> list[str]:
-                print("batch querying", source)
-                updated_keys = keys.copy()
-                # batch fetch the source data, joining on the geography_column
-                print("enrichment_layer", source)
-                # pandas DF for the source data
-                enrichment_data = [datum async for datum in GenericData.objects.filter(**{
-                    f"{source.geography_column}__in": [get(key['postcode_data'], source.geography_column_type) for key in keys]
-                }).all()]
-                print("enrichment_data", enrichment_data)
-                enrichment_df = pd.DataFrame(enrichment_data)
-                print("enrichment_df", enrichment_df)
-                for index, key in enumerate(keys):
+    def data_loader_factory(self):
+        """
+        Build a DataLoader that enriches member records from this source.
+
+        Each key is an EnrichmentLookup. It resolves to the key['source_path']
+        value of the first GenericData row whose self.geography_column equals
+        the member's postcode_data[self.geography_column_type], or to None
+        when the member has no such geography or nothing matches.
+        """
+        async def fetch_enrichment_data(keys: List[self.EnrichmentLookup]) -> list[str]:
+            return_data = []
+            # geography_values_from_keys = list(set([
+            #     get(key['postcode_data'], self.geography_column_type)
+            #     for key
+            #     in keys
+            # ]))
+            # Using async_query because Django's own ORM breaks down in this DataLoader
+            # NOTE(review): this loads every GenericData row for each batch,
+            # not just the rows of this source - confirm that is intended.
+            enrichment_data = await async_query(
+                GenericData.objects.all()
+                # Filtering GenericData produces this error:
+                # psycopg.errors.InvalidTextRepresentation: invalid input syntax for type json
+                # LINE 1: ...ericdata"."json" -> 'council district') IN (Jsonb('Gateshead...
+                # GenericData.objects.filter(**{
+                #   f"json__'{self.geography_column}'__in": geography_values_from_keys
+                # }).values('json')
+                # NOTE(review): the Jsonb(...) literal in that error looks like a
+                # str(queryset.query) artefact; async_query binds parameters now,
+                # so the filter is worth re-testing.
+            )
+            json_list = [
+                json.loads(d.json) if d.json is not None and d.json != "" else {}
+                for d in enrichment_data
+            ]
+            enrichment_df = pd.DataFrame.from_records(json_list)
+            for key in keys:
+                try:
                     # TODO: Use pandas to join dataframes instead
-                    if key['source_id'] == source.id:
-                        # query enrichment_df by matching the postcode_data to the geography_column
+                    # query enrichment_df by matching the postcode_data to the geography_column
+                    relevant_member_geography = get(key['postcode_data'], self.geography_column_type, "")
+                    if relevant_member_geography == "" or relevant_member_geography is None:
+                        return_data.append(None)
+                    else:
                         enrichment_value = enrichment_df.loc[
-                            source.geography_column == get(key['postcode_data'], source.geography_column_type)
-                        ]
-                        updated_keys[index]['source_data'] = enrichment_value
-
-                return updated_keys
-
-            def cache_key_fn (key: self.EnrichmentLookup) -> str:
-                return f"{key['member_id']}_{key['source_id']}"
+                            enrichment_df[self.geography_column] == relevant_member_geography,
+                            key['source_path']
+                        ].values[0]
+                        return_data.append(enrichment_value)
+                except Exception:
+                    # Best-effort: a failed lookup (e.g. unknown column or no matching
+                    # row) leaves this value empty rather than failing the batch.
+                    return_data.append(None)
 
-            return DataLoader(load_fn=fetch_enrichment_data, cache_key_fn=cache_key_fn)
-
-        source_loaders = {}
-
-        async for source in ExternalDataSource.objects.filter(
-            organisation=self.organisation,
-            geography_column__isnull=False,
-            geography_column_type__isnull=False
-        ).all():
-            source_loaders[str(source.id)] = build_source_loader(source)
+            return return_data
+
+        def cache_key_fn (key: self.EnrichmentLookup) -> str:
+            return f"{key['member_id']}_{key['source_id']}"
+
+        return DataLoader(load_fn=fetch_enrichment_data, cache_key_fn=cache_key_fn)
 
+    async def get_loaders (self) -> Loaders:
         loaders = self.Loaders(
             postcodesIO=DataLoader(load_fn=get_bulk_postcode_geo),
             fetch_record=DataLoader(load_fn=self.fetch_many_loader, cache=False),
-            source_loaders=source_loaders
+            source_loaders={
+                str(source.id): source.data_loader_factory()
+                async for source in ExternalDataSource.objects.filter(
+                    organisation=self.organisation_id,
+                    geography_column__isnull=False,
+                    geography_column_type__isnull=False
+                ).all()
+            }
         )
         return loaders
 
@@ -1057,19 +1088,14 @@ async def map_one(self, member: Union[str, dict], loaders: Loaders) -> MappedMem
                     if postcode_data is not None:
                         update_fields[destination_column] = get(postcode_data, source_path)
                 else:
-                    pass
-                    # TODO: fix this — there's an async error when making Django requests inside the source_loader DataLoaders — then re-enable.
-                    # print("Custom enrichment layer requested", source, source_path, destination_column)
-                    # update_value = await loaders['source_loaders'][source].load(
-                    #     self.EnrichmentLookup(
-                    #         member_id=self.get_record_id(member),
-                    #         postcode_data=postcode_data,
-                    #         source_id=source,
-                    #         source_path=source_path
-                    #     )
-                    # )
-                    # print("Custom mapping", source, source_path, update_value)
-                    # update_fields[destination_column] = get(update_value['source_data'], source_path)
+                    update_fields[destination_column] = await loaders['source_loaders'][source].load(
+                        self.EnrichmentLookup(
+                            member_id=self.get_record_id(member),
+                            postcode_data=postcode_data,
+                            source_id=source,
+                            source_path=source_path
+                        )
+                    )
         # Return the member and config data
         return self.MappedMember(
             member=member,
@@ -1263,7 +1289,11 @@ def field_definitions(self):
         return [
             self.FieldDefinition(
                 label=field.name,
-                value=field.id,
+                # For `value`, we use the field name
+                # because in the UI we want users to type the field name, not the field ID
+                # and so self.fetch_all doesn't use table(return_fields_by_field_id=True)
+                # TODO: implement a field ID lookup in the UI, then revisit this
+                value=field.name,
                 description=field.description
             )
             for field in self.table.schema().fields
diff --git a/nextjs/src/components/UpdateMappingForm.tsx b/nextjs/src/components/UpdateMappingForm.tsx
index 5d26c5ebc..35a9b873a 100644
--- a/nextjs/src/components/UpdateMappingForm.tsx
+++ b/nextjs/src/components/UpdateMappingForm.tsx
@@ -1,7 +1,7 @@
 "use client";
 
 import { Button } from "@/components/ui/button";
-import { enrichmentDataSources } from "@/lib/data";
+import { EnrichmentDataSource, enrichmentDataSources } from "@/lib/data";
 import { FormProvider, useFieldArray, useForm } from "react-hook-form";
 import { EnrichmentLayersQuery, ExternalDataSourceInput, PostcodesIoGeographyTypes } from "@/__generated__/graphql";
 import { Input } from "@/components/ui/input";
@@ -68,25 +68,23 @@ export function UpdateMappingForm({
     },
   );
 
-  // TODO: Fix source_loader code in API, then re-enable this
-  // const customEnrichmentLayers = useQuery(ENRICHMENT_LAYERS)
-  // const sources: EnrichmentDataSource[] = useMemo(() => {
-  //   return enrichmentDataSources.concat(
-  //     customEnrichmentLayers.data?.externalDataSources
-  //     .filter(source => !!source.geographyColumn)
-  //     .map((source) => ({
-  //       slug: source.id,
-  //       name: source.name,
-  //       author: "",
-  //       description: "",
-  //       descriptionURL: "",
-  //       colour: "",
-  //       builtIn: false,
-  //       sourcePaths: source.fieldDefinitions || []
-  //     })) || []
-  //   )
-  // }, [enrichmentDataSources, customEnrichmentLayers.data?.externalDataSources])
-  const sources = enrichmentDataSources
+  const customEnrichmentLayers = useQuery(ENRICHMENT_LAYERS)
+  const sources: EnrichmentDataSource[] = useMemo(() => {
+    return enrichmentDataSources.concat(
+      customEnrichmentLayers.data?.externalDataSources
+      .filter(source => !!source.geographyColumn)
+      .map((source) => ({
+        slug: source.id,
+        name: source.name,
+        author: "",
+        description: "",
+        descriptionURL: "",
+        colour: "",
+        builtIn: false,
+        sourcePaths: source.fieldDefinitions || []
+      })) || []
+    )
+  }, [enrichmentDataSources, customEnrichmentLayers.data?.externalDataSources])
 
   return (
 
diff --git a/utils/asyncio.py b/utils/asyncio.py
new file mode 100644
index 000000000..dd6346903
--- /dev/null
+++ b/utils/asyncio.py
@@ -0,0 +1,31 @@
+import psycopg
+from django.db import connection
+from django.db.models.query import QuerySet
+
+
+async def async_query(queryset: QuerySet):
+    """
+    Execute a queryset's SQL on a new async psycopg connection.
+
+    A dedicated connection is opened for every call, using the connection
+    parameters of Django's default connection, and closed again when the
+    query has been fetched. Returns a list of queryset.model instances,
+    one per row.
+    """
+    # Keep the SQL and its parameters separate so psycopg quotes the values.
+    # str(queryset.query) interpolates them unquoted, which Django does not
+    # guarantee to be executable SQL.
+    sql, params = queryset.query.sql_with_params()
+    aconnection = await psycopg.AsyncConnection.connect(
+        **{
+            **connection.get_connection_params(),
+            "cursor_factory": psycopg.AsyncCursor,
+        },
+    )
+    async with aconnection:
+        async with aconnection.cursor() as cursor:
+            await cursor.execute(sql, params)
+            results = await cursor.fetchall()
+            # Rows are passed positionally, so this presumably relies on the
+            # selected columns matching the model's field order - TODO confirm.
+            return [queryset.model(*result) for result in results]
diff --git a/utils/postcodesIO.py b/utils/postcodesIO.py
index 7249cb6d1..e04d5c7c0 100644
--- a/utils/postcodesIO.py
+++ b/utils/postcodesIO.py
@@ -3,6 +3,7 @@
 import requests
 from utils.py import get, get_path, batch_and_aggregate
 from utils.geo import create_point
+import httpx
 
 from dataclasses import dataclass
 
@@ -86,9 +87,13 @@ def get_postcode_geo(postcode: str) -> PostcodesIOResult:
 
 @batch_and_aggregate(settings.POSTCODES_IO_BATCH_MAXIMUM)
 async def get_bulk_postcode_geo(postcodes) -> PostcodesIOBulkResult:
-    response = requests.post(f'{settings.POSTCODES_IO_URL}/postcodes', json={
-        "postcodes": postcodes
-    },)
+    async with httpx.AsyncClient() as client:
+        response = await client.post(f'{settings.POSTCODES_IO_URL}/postcodes', json={
+            "postcodes": postcodes
+        },)
+    if response.status_code != httpx.codes.OK:
+        raise Exception(f'Failed to bulk geocode postcodes: {postcodes}.')
+
     data = response.json()
     status = get(data, 'status')
     result: List[ResultElement] = get(data, 'result')
diff --git a/utils/py.py b/utils/py.py
index 07f830782..a17da298c 100644
--- a/utils/py.py
+++ b/utils/py.py
@@ -1,9 +1,11 @@
 from benedict import benedict
 
-def get(d, path):
+def get(d, path, default=None):
     if isinstance(d, benedict):
-        return d[path]
-    return benedict(d)[path]
+        val = d[path]
+    else:
+        val = benedict(d)[path]
+    return val if val is not None else default
 
 def is_sequence(arg):
     if isinstance(arg, str):