From 6353ae44de4fcf269ba4cfe25a38d34b85c677d2 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Thu, 20 Mar 2025 13:48:27 +0100 Subject: [PATCH] STAC: update stac-fastapi version to 5.0 --- infrastructure/handlers/stac_handler.py | 3 +- runtimes/eoapi/stac/eoapi/stac/app.py | 5 +- runtimes/eoapi/stac/eoapi/stac/client.py | 211 +++++++++-------------- runtimes/eoapi/stac/eoapi/stac/config.py | 10 +- runtimes/eoapi/stac/pyproject.toml | 2 +- 5 files changed, 96 insertions(+), 135 deletions(-) diff --git a/infrastructure/handlers/stac_handler.py b/infrastructure/handlers/stac_handler.py index 1bc4a66..fd6245d 100644 --- a/infrastructure/handlers/stac_handler.py +++ b/infrastructure/handlers/stac_handler.py @@ -5,6 +5,7 @@ import os from eoapi.stac.app import app +from eoapi.stac.config import PostgresSettings from mangum import Mangum from stac_fastapi.pgstac.db import connect_to_db @@ -15,7 +16,7 @@ @app.on_event("startup") async def startup_event() -> None: """Connect to database on startup.""" - await connect_to_db(app) + await connect_to_db(app, postgres_settings=PostgresSettings()) handler = Mangum(app, lifespan="off") diff --git a/runtimes/eoapi/stac/eoapi/stac/app.py b/runtimes/eoapi/stac/eoapi/stac/app.py index 12070a7..0e660d6 100644 --- a/runtimes/eoapi/stac/eoapi/stac/app.py +++ b/runtimes/eoapi/stac/eoapi/stac/app.py @@ -44,7 +44,7 @@ from . import __version__ as eoapi_devseed_version from .api import StacApi from .client import FiltersClient, PgSTACClient -from .config import Settings +from .config import PostgresSettings, Settings from .extensions import ( HTMLorGeoMultiOutputExtension, HTMLorGeoOutputExtension, @@ -65,6 +65,7 @@ templates = Jinja2Templates(env=jinja2_env) settings = Settings() +pg_settings = PostgresSettings() auth_settings = OpenIdConnectSettings() @@ -172,7 +173,7 @@ @asynccontextmanager async def lifespan(app: FastAPI): """FastAPI Lifespan.""" - await connect_to_db(app) + await connect_to_db(app, postgres_settings=pg_settings) yield await close_db_connection(app) diff --git a/runtimes/eoapi/stac/eoapi/stac/client.py b/runtimes/eoapi/stac/eoapi/stac/client.py index 3814471..7539333 100644 --- a/runtimes/eoapi/stac/eoapi/stac/client.py +++ b/runtimes/eoapi/stac/eoapi/stac/client.py @@ -14,20 +14,19 @@ Type, get_args, ) -from urllib.parse import unquote_plus, urlencode, urljoin +from urllib.parse import unquote_plus, urlencode import attr import jinja2 import orjson -from fastapi import Request +from fastapi import HTTPException, Request from geojson_pydantic.geometries import parse_geometry_obj +from pydantic import ValidationError from stac_fastapi.api.models import JSONResponse from stac_fastapi.pgstac.core import CoreCrudClient from stac_fastapi.pgstac.extensions.filter import FiltersClient as PgSTACFiltersClient from stac_fastapi.pgstac.models.links import ItemCollectionLinks from stac_fastapi.pgstac.types.search import PgstacSearch -from stac_fastapi.types.errors import NotFoundError -from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.stac import ( Collection, Collections, @@ -275,7 +274,6 @@ class PgSTACClient(CoreCrudClient): async def landing_page( self, - request: Request, f: Optional[str] = None, **kwargs, ) -> LandingPage: @@ -287,67 +285,9 @@ async def landing_page( API landing page, serving as an entry point to the API. """ - base_url = get_base_url(request) - - landing_page = self._landing_page( - base_url=base_url, - conformance_classes=self.conformance_classes(), - extension_schemas=[], - ) - - # Add Queryables link - if self.extension_is_enabled("FilterExtension") or self.extension_is_enabled( - "SearchFilterExtension" - ): - landing_page["links"].append( - { - "rel": Relations.queryables.value, - "type": MimeTypes.jsonschema.value, - "title": "Queryables", - "href": urljoin(base_url, "queryables"), - } - ) - - # Add Aggregation links - if self.extension_is_enabled("AggregationExtension"): - landing_page["links"].extend( - [ - { - "rel": "aggregate", - "type": "application/json", - "title": "Aggregate", - "href": urljoin(base_url, "aggregate"), - }, - { - "rel": "aggregations", - "type": "application/json", - "title": "Aggregations", - "href": urljoin(base_url, "aggregations"), - }, - ] - ) - - # Add OpenAPI URL - landing_page["links"].append( - { - "rel": Relations.service_desc.value, - "type": MimeTypes.openapi.value, - "title": "OpenAPI service description", - "href": str(request.url_for("openapi")), - } - ) + request: Request = kwargs["request"] - # Add human readable service-doc - landing_page["links"].append( - { - "rel": Relations.service_doc.value, - "type": MimeTypes.html.value, - "title": "OpenAPI service documentation", - "href": str(request.url_for("swagger_ui_html")), - } - ) - - landing = LandingPage(**landing_page) + landing = await super().landing_page(**kwargs) output_type: Optional[MimeTypes] if f: @@ -476,6 +416,37 @@ async def get_collection( return collection + async def get_item( + self, + item_id: str, + collection_id: str, + request: Request, + f: Optional[str] = None, + **kwargs, + ) -> Item: + item = await super().get_item(item_id, collection_id, request, **kwargs) + + output_type: Optional[MimeTypes] + if f: + output_type = MimeTypes[f] + else: + accepted_media = [MimeTypes[v] for v in get_args(GeoResponseType)] + output_type = accept_media_type( + request.headers.get("accept", ""), accepted_media + ) + + if output_type == MimeTypes.html: + return create_html_response( + request, + item, + template_name="item", + title=f"{collection_id}/{item_id} item", + ) + + return item + + # NOTE: We can't use `super.item_collection(...)` because of the `fields` extension + # which, when used, might return a JSONResponse directly instead of a ItemCollection (TypeDict) async def item_collection( self, collection_id: str, @@ -493,16 +464,6 @@ async def item_collection( f: Optional[str] = None, **kwargs, ) -> ItemCollection: - output_type: Optional[MimeTypes] - if f: - output_type = MimeTypes[f] - else: - accepted_media = [MimeTypes[v] for v in get_args(GeoMultiResponseType)] - output_type = accept_media_type( - request.headers.get("accept", ""), accepted_media - ) - - # Check if collection exist await self.get_collection(collection_id, request=request) base_args = { @@ -521,12 +482,30 @@ async def item_collection( sortby=sortby, ) - search_request = self.pgstac_search_model(**clean) + try: + search_request = self.pgstac_search_model(**clean) + except ValidationError as e: + raise HTTPException( + status_code=400, detail=f"Invalid parameters provided {e}" + ) from e + item_collection = await self._search_base(search_request, request=request) item_collection["links"] = await ItemCollectionLinks( collection_id=collection_id, request=request ).get_links(extra_links=item_collection["links"]) + ####################################################################### + # Custom Responses + ####################################################################### + output_type: Optional[MimeTypes] + if f: + output_type = MimeTypes[f] + else: + accepted_media = [MimeTypes[v] for v in get_args(GeoMultiResponseType)] + output_type = accept_media_type( + request.headers.get("accept", ""), accepted_media + ) + # Additional Headers for StreamingResponse additional_headers = {} links = item_collection.get("links", []) @@ -581,45 +560,8 @@ async def item_collection( return ItemCollection(**item_collection) - async def get_item( - self, - item_id: str, - collection_id: str, - request: Request, - f: Optional[str] = None, - **kwargs, - ) -> Item: - output_type: Optional[MimeTypes] - if f: - output_type = MimeTypes[f] - else: - accepted_media = [MimeTypes[v] for v in get_args(GeoResponseType)] - output_type = accept_media_type( - request.headers.get("accept", ""), accepted_media - ) - - # Check if collection exist - await self.get_collection(collection_id, request=request) - - search_request = self.pgstac_search_model( - ids=[item_id], collections=[collection_id], limit=1 - ) - item_collection = await self._search_base(search_request, request=request) - if not item_collection["features"]: - raise NotFoundError( - f"Item {item_id} in Collection {collection_id} does not exist." - ) - - if output_type == MimeTypes.html: - return create_html_response( - request, - item_collection["features"][0], - template_name="item", - title=f"{collection_id}/{item_id} item", - ) - - return Item(**item_collection["features"][0]) - + # NOTE: We can't use `super.get_search(...)` because of the `fields` extension + # which, when used, might return a JSONResponse directly instead of a ItemCollection (TypeDict) async def get_search( self, request: Request, @@ -639,16 +581,6 @@ async def get_search( f: Optional[str] = None, **kwargs, ) -> ItemCollection: - output_type: Optional[MimeTypes] - if f: - output_type = MimeTypes[f] - else: - accepted_media = [MimeTypes[v] for v in get_args(GeoMultiResponseType)] - output_type = accept_media_type( - request.headers.get("accept", ""), accepted_media - ) - - # Parse request parameters base_args = { "collections": collections, "ids": ids, @@ -668,9 +600,27 @@ async def get_search( filter_lang=filter_lang, ) - search_request = self.pgstac_search_model(**clean) + try: + search_request = self.pgstac_search_model(**clean) + except ValidationError as e: + raise HTTPException( + status_code=400, detail=f"Invalid parameters provided {e}" + ) from e + item_collection = await self._search_base(search_request, request=request) + ####################################################################### + # Custom Responses + ####################################################################### + output_type: Optional[MimeTypes] + if f: + output_type = MimeTypes[f] + else: + accepted_media = [MimeTypes[v] for v in get_args(GeoMultiResponseType)] + output_type = accept_media_type( + request.headers.get("accept", ""), accepted_media + ) + # Additional Headers for StreamingResponse additional_headers = {} links = item_collection.get("links", []) @@ -720,19 +670,24 @@ async def get_search( return ItemCollection(**item_collection) + # NOTE: We can't use `super.post_search(...)` because of the `fields` extension + # which, when used, might return a JSONResponse directly instead of a ItemCollection (TypeDict) async def post_search( self, search_request: PgstacSearch, request: Request, **kwargs, ) -> ItemCollection: + item_collection = await self._search_base(search_request, request=request) + + ####################################################################### + # Custom Responses + ####################################################################### accepted_media = [MimeTypes[v] for v in get_args(PostMultiResponseType)] output_type = accept_media_type( request.headers.get("accept", ""), accepted_media ) - item_collection = await self._search_base(search_request, request=request) - # Additional Headers for StreamingResponse additional_headers = {} links = item_collection.get("links", []) diff --git a/runtimes/eoapi/stac/eoapi/stac/config.py b/runtimes/eoapi/stac/eoapi/stac/config.py index 1f03025..e5c1bb1 100644 --- a/runtimes/eoapi/stac/eoapi/stac/config.py +++ b/runtimes/eoapi/stac/eoapi/stac/config.py @@ -33,7 +33,7 @@ def get_secret_dict(secret_name: str): class Settings(config.Settings): - """Extent stac-fastapi-pgstac settings""" + """Extent stac-fastapi-pgstac API settings""" stac_fastapi_title: str = "eoAPI-stac" stac_fastapi_description: str = "Custom stac-fastapi application for eoAPI-Devseed" @@ -41,13 +41,17 @@ class Settings(config.Settings): cachecontrol: str = "public, max-age=3600" - pgstac_secret_arn: Optional[str] = None - titiler_endpoint: Optional[str] = None enable_transaction: bool = False debug: bool = False + +class PostgresSettings(config.PostgresSettings): + """Extent stac-fastapi-pgstac PostgresSettings settings""" + + pgstac_secret_arn: Optional[str] = None + @model_validator(mode="before") def get_postgres_setting(cls, data: Any) -> Any: if arn := data.get("pgstac_secret_arn"): diff --git a/runtimes/eoapi/stac/pyproject.toml b/runtimes/eoapi/stac/pyproject.toml index d594ee4..04f4dbc 100644 --- a/runtimes/eoapi/stac/pyproject.toml +++ b/runtimes/eoapi/stac/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "stac-fastapi.pgstac>=4.0.2,<4.1", + "stac-fastapi.pgstac>=5.0,<5.1", "jinja2>=2.11.2,<4.0.0", "starlette-cramjam>=0.4,<0.5", "psycopg_pool",