diff --git a/infrastructure/handlers/raster_handler.py b/infrastructure/handlers/raster_handler.py index 955ccb9..439a2de 100644 --- a/infrastructure/handlers/raster_handler.py +++ b/infrastructure/handlers/raster_handler.py @@ -5,23 +5,18 @@ import os from eoapi.raster.app import app -from eoapi.raster.config import ApiSettings +from eoapi.raster.config import PostgresSettings from mangum import Mangum from titiler.pgstac.db import connect_to_db logging.getLogger("mangum.lifespan").setLevel(logging.ERROR) logging.getLogger("mangum.http").setLevel(logging.ERROR) -settings = ApiSettings() - @app.on_event("startup") async def startup_event() -> None: """Connect to database on startup.""" - await connect_to_db( - app, - settings=settings.load_postgres_settings(), - ) + await connect_to_db(app, settings=PostgresSettings()) handler = Mangum(app, lifespan="off") diff --git a/infrastructure/handlers/vector_handler.py b/infrastructure/handlers/vector_handler.py index 73d3285..e507dba 100644 --- a/infrastructure/handlers/vector_handler.py +++ b/infrastructure/handlers/vector_handler.py @@ -3,9 +3,10 @@ import asyncio import logging import os +from importlib.resources import files as resources_files from eoapi.vector.app import app -from eoapi.vector.config import ApiSettings +from eoapi.vector.config import PostgresSettings from mangum import Mangum from tipg.collections import register_collection_catalog from tipg.database import connect_to_db @@ -13,14 +14,6 @@ logging.getLogger("mangum.lifespan").setLevel(logging.ERROR) logging.getLogger("mangum.http").setLevel(logging.ERROR) -settings = ApiSettings() - -try: - from importlib.resources import files as resources_files # type: ignore -except ImportError: - # Try backported to PY<39 `importlib_resources`. - from importlib_resources import files as resources_files # type: ignore - CUSTOM_SQL_DIRECTORY = resources_files("eoapi.vector") / "sql" sql_files = list(CUSTOM_SQL_DIRECTORY.glob("*.sql")) # type: ignore @@ -31,7 +24,7 @@ async def startup_event() -> None: """Connect to database on startup.""" await connect_to_db( app, - settings=settings.load_postgres_settings(), + settings=PostgresSettings(), # We enable both pgstac and public schemas (pgstac will be used by custom functions) schemas=["pgstac", "public"], user_sql_files=sql_files, diff --git a/runtimes/eoapi/raster/eoapi/raster/app.py b/runtimes/eoapi/raster/eoapi/raster/app.py index 90108fa..95acc00 100644 --- a/runtimes/eoapi/raster/eoapi/raster/app.py +++ b/runtimes/eoapi/raster/eoapi/raster/app.py @@ -47,7 +47,7 @@ from titiler.pgstac.reader import PgSTACReader from . import __version__ as eoapi_raster_version -from .config import ApiSettings +from .config import ApiSettings, PostgresSettings from .logs import init_logging settings = ApiSettings() @@ -90,7 +90,7 @@ async def lifespan(app: FastAPI): """FastAPI Lifespan.""" logger.debug("Connecting to db...") - await connect_to_db(app, settings=settings.load_postgres_settings()) + await connect_to_db(app, settings=PostgresSettings()) logger.debug("Connected to db.") yield diff --git a/runtimes/eoapi/raster/eoapi/raster/config.py b/runtimes/eoapi/raster/eoapi/raster/config.py index 0e9f74f..4a4d0fd 100644 --- a/runtimes/eoapi/raster/eoapi/raster/config.py +++ b/runtimes/eoapi/raster/eoapi/raster/config.py @@ -2,15 +2,15 @@ import base64 import json -from typing import Optional +from typing import Any, Dict, Optional import boto3 -from pydantic import field_validator +from pydantic import field_validator, model_validator from pydantic_settings import BaseSettings -from titiler.pgstac.settings import PostgresSettings +from titiler.pgstac.settings import PostgresSettings as _PostgresSettings -def get_secret_dict(secret_name: str): +def get_secret_dict(secret_name: str) -> Dict: """Retrieve secrets from AWS Secrets Manager Args: @@ -43,8 +43,6 @@ class ApiSettings(BaseSettings): debug: bool = False root_path: str = "" - pgstac_secret_arn: Optional[str] = None - model_config = { "env_file": ".env", "extra": "allow", @@ -60,18 +58,24 @@ def parse_cors_methods(cls, v): """Parse CORS methods.""" return [method.strip() for method in v.split(",")] - def load_postgres_settings(self) -> "PostgresSettings": - """Load postgres connection params from AWS secret""" - if self.pgstac_secret_arn: - secret = get_secret_dict(self.pgstac_secret_arn) +class PostgresSettings(_PostgresSettings): + """Extent titiler-pgstac PostgresSettings settings""" - return PostgresSettings( - postgres_host=secret["host"], - postgres_dbname=secret["dbname"], - postgres_user=secret["username"], - postgres_pass=secret["password"], - postgres_port=secret["port"], + pgstac_secret_arn: Optional[str] = None + + @model_validator(mode="before") + def get_postgres_setting(cls, data: Any) -> Any: + if arn := data.get("pgstac_secret_arn"): + secret = get_secret_dict(arn) + data.update( + { + "postgres_host": secret["host"], + "postgres_dbname": secret["dbname"], + "postgres_user": secret["username"], + "postgres_pass": secret["password"], + "postgres_port": secret["port"], + } ) - else: - return PostgresSettings() + + return data diff --git a/runtimes/eoapi/raster/pyproject.toml b/runtimes/eoapi/raster/pyproject.toml index a7c3442..faae65d 100644 --- a/runtimes/eoapi/raster/pyproject.toml +++ b/runtimes/eoapi/raster/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "titiler.pgstac==1.6.0", + "titiler.pgstac==1.7.1", "titiler.extensions", "starlette-cramjam>=0.4,<0.5", "importlib_resources>=1.1.0;python_version<'3.9'", diff --git a/runtimes/eoapi/stac/eoapi/stac/config.py b/runtimes/eoapi/stac/eoapi/stac/config.py index e5c1bb1..8d521af 100644 --- a/runtimes/eoapi/stac/eoapi/stac/config.py +++ b/runtimes/eoapi/stac/eoapi/stac/config.py @@ -2,14 +2,14 @@ import base64 import json -from typing import Any, Optional +from typing import Any, Dict, Optional import boto3 from pydantic import model_validator from stac_fastapi.pgstac import config -def get_secret_dict(secret_name: str): +def get_secret_dict(secret_name: str) -> Dict: """Retrieve secrets from AWS Secrets Manager Args: diff --git a/runtimes/eoapi/vector/eoapi/vector/app.py b/runtimes/eoapi/vector/eoapi/vector/app.py index dce4e20..a1618f2 100644 --- a/runtimes/eoapi/vector/eoapi/vector/app.py +++ b/runtimes/eoapi/vector/eoapi/vector/app.py @@ -2,7 +2,7 @@ import logging from contextlib import asynccontextmanager -from importlib.resources import files as resources_files # type: ignore +from importlib.resources import files as resources_files import jinja2 from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings @@ -17,13 +17,12 @@ from tipg.middleware import CacheControlMiddleware, CatalogUpdateMiddleware from . import __version__ as eoapi_vector_version -from .config import ApiSettings +from .config import ApiSettings, PostgresSettings from .logs import init_logging CUSTOM_SQL_DIRECTORY = resources_files(__package__) / "sql" settings = ApiSettings() -postgres_settings = settings.load_postgres_settings() auth_settings = OpenIdConnectSettings() # Logs @@ -53,7 +52,7 @@ async def lifespan(app: FastAPI): logger.debug("Connecting to db...") await connect_to_db( app, - settings=postgres_settings, + settings=PostgresSettings(), # We enable both pgstac and public schemas (pgstac will be used by custom functions) schemas=["pgstac", "public"], user_sql_files=list(CUSTOM_SQL_DIRECTORY.glob("*.sql")), # type: ignore diff --git a/runtimes/eoapi/vector/eoapi/vector/config.py b/runtimes/eoapi/vector/eoapi/vector/config.py index 0b4894c..fc2045d 100644 --- a/runtimes/eoapi/vector/eoapi/vector/config.py +++ b/runtimes/eoapi/vector/eoapi/vector/config.py @@ -2,15 +2,15 @@ import base64 import json -from typing import Optional +from typing import Any, Dict, Optional import boto3 -from pydantic import field_validator +from pydantic import field_validator, model_validator from pydantic_settings import BaseSettings -from tipg.settings import PostgresSettings +from tipg.settings import PostgresSettings as _PostgresSettings -def get_secret_dict(secret_name: str): +def get_secret_dict(secret_name: str) -> Dict: """Retrieve secrets from AWS Secrets Manager Args: @@ -43,7 +43,6 @@ class ApiSettings(BaseSettings): debug: bool = False root_path: str = "" - pgstac_secret_arn: Optional[str] = None catalog_ttl: int = 300 model_config = { @@ -61,18 +60,24 @@ def parse_cors_methods(cls, v): """Parse CORS methods.""" return [method.strip() for method in v.split(",")] - def load_postgres_settings(self) -> "PostgresSettings": - """Load postgres connection params from AWS secret""" - if self.pgstac_secret_arn: - secret = get_secret_dict(self.pgstac_secret_arn) +class PostgresSettings(_PostgresSettings): + """Extent tipg PostgresSettings settings""" + + pgstac_secret_arn: Optional[str] = None - return PostgresSettings( - postgres_host=secret["host"], - postgres_dbname=secret["dbname"], - postgres_user=secret["username"], - postgres_pass=secret["password"], - postgres_port=secret["port"], + @model_validator(mode="before") + def get_postgres_setting(cls, data: Any) -> Any: + if arn := data.get("pgstac_secret_arn"): + secret = get_secret_dict(arn) + data.update( + { + "postgres_host": secret["host"], + "postgres_dbname": secret["dbname"], + "postgres_user": secret["username"], + "postgres_pass": secret["password"], + "postgres_port": secret["port"], + } ) - else: - return PostgresSettings() + + return data