Skip to content

customize runtime's initial pg-settings for raster and vector, as done for stac #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions infrastructure/handlers/raster_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,18 @@
import os

from eoapi.raster.app import app
from eoapi.raster.config import ApiSettings
from eoapi.raster.config import PostgresSettings
from mangum import Mangum
from titiler.pgstac.db import connect_to_db

logging.getLogger("mangum.lifespan").setLevel(logging.ERROR)
logging.getLogger("mangum.http").setLevel(logging.ERROR)

settings = ApiSettings()


@app.on_event("startup")
async def startup_event() -> None:
"""Connect to database on startup."""
await connect_to_db(
app,
settings=settings.load_postgres_settings(),
)
await connect_to_db(app, settings=PostgresSettings())


handler = Mangum(app, lifespan="off")
Expand Down
13 changes: 3 additions & 10 deletions infrastructure/handlers/vector_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,17 @@
import asyncio
import logging
import os
from importlib.resources import files as resources_files

from eoapi.vector.app import app
from eoapi.vector.config import ApiSettings
from eoapi.vector.config import PostgresSettings
from mangum import Mangum
from tipg.collections import register_collection_catalog
from tipg.database import connect_to_db

logging.getLogger("mangum.lifespan").setLevel(logging.ERROR)
logging.getLogger("mangum.http").setLevel(logging.ERROR)

settings = ApiSettings()

try:
from importlib.resources import files as resources_files # type: ignore
except ImportError:
# Try backported to PY<39 `importlib_resources`.
from importlib_resources import files as resources_files # type: ignore


CUSTOM_SQL_DIRECTORY = resources_files("eoapi.vector") / "sql"
sql_files = list(CUSTOM_SQL_DIRECTORY.glob("*.sql")) # type: ignore
Expand All @@ -31,7 +24,7 @@ async def startup_event() -> None:
"""Connect to database on startup."""
await connect_to_db(
app,
settings=settings.load_postgres_settings(),
settings=PostgresSettings(),
# We enable both pgstac and public schemas (pgstac will be used by custom functions)
schemas=["pgstac", "public"],
user_sql_files=sql_files,
Expand Down
4 changes: 2 additions & 2 deletions runtimes/eoapi/raster/eoapi/raster/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from titiler.pgstac.reader import PgSTACReader

from . import __version__ as eoapi_raster_version
from .config import ApiSettings
from .config import ApiSettings, PostgresSettings
from .logs import init_logging

settings = ApiSettings()
Expand Down Expand Up @@ -90,7 +90,7 @@
async def lifespan(app: FastAPI):
"""FastAPI Lifespan."""
logger.debug("Connecting to db...")
await connect_to_db(app, settings=settings.load_postgres_settings())
await connect_to_db(app, settings=PostgresSettings())
logger.debug("Connected to db.")

yield
Expand Down
40 changes: 22 additions & 18 deletions runtimes/eoapi/raster/eoapi/raster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import base64
import json
from typing import Optional
from typing import Any, Dict, Optional

import boto3
from pydantic import field_validator
from pydantic import field_validator, model_validator
from pydantic_settings import BaseSettings
from titiler.pgstac.settings import PostgresSettings
from titiler.pgstac.settings import PostgresSettings as _PostgresSettings


def get_secret_dict(secret_name: str):
def get_secret_dict(secret_name: str) -> Dict:
"""Retrieve secrets from AWS Secrets Manager

Args:
Expand Down Expand Up @@ -43,8 +43,6 @@ class ApiSettings(BaseSettings):
debug: bool = False
root_path: str = ""

pgstac_secret_arn: Optional[str] = None

model_config = {
"env_file": ".env",
"extra": "allow",
Expand All @@ -60,18 +58,24 @@ def parse_cors_methods(cls, v):
"""Parse CORS methods."""
return [method.strip() for method in v.split(",")]

def load_postgres_settings(self) -> "PostgresSettings":
"""Load postgres connection params from AWS secret"""

if self.pgstac_secret_arn:
secret = get_secret_dict(self.pgstac_secret_arn)
class PostgresSettings(_PostgresSettings):
"""Extent titiler-pgstac PostgresSettings settings"""

return PostgresSettings(
postgres_host=secret["host"],
postgres_dbname=secret["dbname"],
postgres_user=secret["username"],
postgres_pass=secret["password"],
postgres_port=secret["port"],
pgstac_secret_arn: Optional[str] = None

@model_validator(mode="before")
def get_postgres_setting(cls, data: Any) -> Any:
if arn := data.get("pgstac_secret_arn"):
secret = get_secret_dict(arn)
data.update(
{
"postgres_host": secret["host"],
"postgres_dbname": secret["dbname"],
"postgres_user": secret["username"],
"postgres_pass": secret["password"],
"postgres_port": secret["port"],
}
)
else:
return PostgresSettings()

return data
2 changes: 1 addition & 1 deletion runtimes/eoapi/raster/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"titiler.pgstac==1.6.0",
"titiler.pgstac==1.7.1",
"titiler.extensions",
"starlette-cramjam>=0.4,<0.5",
"importlib_resources>=1.1.0;python_version<'3.9'",
Expand Down
4 changes: 2 additions & 2 deletions runtimes/eoapi/stac/eoapi/stac/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import base64
import json
from typing import Any, Optional
from typing import Any, Dict, Optional

import boto3
from pydantic import model_validator
from stac_fastapi.pgstac import config


def get_secret_dict(secret_name: str):
def get_secret_dict(secret_name: str) -> Dict:
"""Retrieve secrets from AWS Secrets Manager

Args:
Expand Down
7 changes: 3 additions & 4 deletions runtimes/eoapi/vector/eoapi/vector/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from contextlib import asynccontextmanager
from importlib.resources import files as resources_files # type: ignore
from importlib.resources import files as resources_files

import jinja2
from eoapi.auth_utils import OpenIdConnectAuth, OpenIdConnectSettings
Expand All @@ -17,13 +17,12 @@
from tipg.middleware import CacheControlMiddleware, CatalogUpdateMiddleware

from . import __version__ as eoapi_vector_version
from .config import ApiSettings
from .config import ApiSettings, PostgresSettings
from .logs import init_logging

CUSTOM_SQL_DIRECTORY = resources_files(__package__) / "sql"

settings = ApiSettings()
postgres_settings = settings.load_postgres_settings()
auth_settings = OpenIdConnectSettings()

# Logs
Expand Down Expand Up @@ -53,7 +52,7 @@ async def lifespan(app: FastAPI):
logger.debug("Connecting to db...")
await connect_to_db(
app,
settings=postgres_settings,
settings=PostgresSettings(),
# We enable both pgstac and public schemas (pgstac will be used by custom functions)
schemas=["pgstac", "public"],
user_sql_files=list(CUSTOM_SQL_DIRECTORY.glob("*.sql")), # type: ignore
Expand Down
39 changes: 22 additions & 17 deletions runtimes/eoapi/vector/eoapi/vector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import base64
import json
from typing import Optional
from typing import Any, Dict, Optional

import boto3
from pydantic import field_validator
from pydantic import field_validator, model_validator
from pydantic_settings import BaseSettings
from tipg.settings import PostgresSettings
from tipg.settings import PostgresSettings as _PostgresSettings


def get_secret_dict(secret_name: str):
def get_secret_dict(secret_name: str) -> Dict:
"""Retrieve secrets from AWS Secrets Manager

Args:
Expand Down Expand Up @@ -43,7 +43,6 @@ class ApiSettings(BaseSettings):
debug: bool = False
root_path: str = ""

pgstac_secret_arn: Optional[str] = None
catalog_ttl: int = 300

model_config = {
Expand All @@ -61,18 +60,24 @@ def parse_cors_methods(cls, v):
"""Parse CORS methods."""
return [method.strip() for method in v.split(",")]

def load_postgres_settings(self) -> "PostgresSettings":
"""Load postgres connection params from AWS secret"""

if self.pgstac_secret_arn:
secret = get_secret_dict(self.pgstac_secret_arn)
class PostgresSettings(_PostgresSettings):
"""Extent tipg PostgresSettings settings"""

pgstac_secret_arn: Optional[str] = None

return PostgresSettings(
postgres_host=secret["host"],
postgres_dbname=secret["dbname"],
postgres_user=secret["username"],
postgres_pass=secret["password"],
postgres_port=secret["port"],
@model_validator(mode="before")
def get_postgres_setting(cls, data: Any) -> Any:
if arn := data.get("pgstac_secret_arn"):
secret = get_secret_dict(arn)
data.update(
{
"postgres_host": secret["host"],
"postgres_dbname": secret["dbname"],
"postgres_user": secret["username"],
"postgres_pass": secret["password"],
"postgres_port": secret["port"],
}
)
else:
return PostgresSettings()

return data