Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api/mymdc): use async mymdc client #36

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
15 changes: 15 additions & 0 deletions api/src/damnit_api/base/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pathlib import Path

from fastapi import HTTPException


class DWAError(Exception): ...


class DWAHTTPError(DWAError, HTTPException): ...


class InvalidProposalPathError(DWAError):
def __init__(self, path: Path):
self.path = path
super().__init__(f"Invalid proposal path: {path}")
53 changes: 53 additions & 0 deletions api/src/damnit_api/base/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import re
from pathlib import Path
from typing import NewType, Self

from pydantic import BaseModel

from .exceptions import InvalidProposalPathError

ProposalNumber = NewType("ProposalNumber", int)

_RE_PNFS_SUB = re.compile(
r"/pnfs/xfel\.eu/exfel/archive/XFEL/(?:proc|raw)"
r"/(?P<inst>[^/]+)/(?P<cycle>[^/]+)/p(?P<prop>[^/]+)"
)

_RE_GPFS = re.compile(
r"/gpfs/exfel/exp/(?P<inst>[^/]+)/(?P<cycle>[^/]+)/p(?P<prop>[^/]+)"
)

_RE_GPFS_SUB = re.compile(
r"/gpfs/exfel/(?:u/scratch|u/usr|d/proc|d/raw)"
r"/(?P<inst>[^/]+)/(?P<cycle>[^/]+)/p(?P<prop>[^/]+)"
)

_RE_LIST = [_RE_PNFS_SUB, _RE_GPFS, _RE_GPFS_SUB]


class ProposalPath(BaseModel):
instrument: str
cycle: int
number: ProposalNumber

@property
def dirname(self) -> str:
return f"p{self.number:06d}"

@property
def path(self) -> Path:
return Path(f"/gpfs/exfel/exp/{self.instrument}/{self.cycle}/{self.dirname}")

@classmethod
def from_path(cls, path: Path) -> Self:
match = [m.match(str(path)) for m in _RE_LIST]
match = [m for m in match if m]

if not match:
raise InvalidProposalPathError(path)

group = match[0].groupdict()

inst, cycle, no = group["inst"], group["cycle"], int(group["prop"])

return cls(instrument=inst, cycle=int(cycle), number=ProposalNumber(no))
Empty file.
111 changes: 111 additions & 0 deletions api/src/damnit_api/metadata/mymdc/clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Async MyMdC Client

TODO: I've copy-pasted this code across a few different projects, when/if an async HTTPX
MyMdC client package is created this can be removed and replaced with calls to that."""

import datetime as dt
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any

import httpx
from structlog import get_logger

from ...settings import MyMdCCredentials, Settings

logger = get_logger(__name__)

if TYPE_CHECKING: # pragma: no cover
from fastapi import FastAPI


CLIENT: "MyMdCClient" = None # type: ignore[assignment]


async def _configure(settings: Settings, _: "FastAPI"):
global CLIENT
logger.info("Configuring MyMdC client", settings=settings.mymdc)
auth = MyMdCAuth.model_validate(settings.mymdc, from_attributes=True)
await auth.acquire_token()
CLIENT = MyMdCClient(auth=auth)


class MyMdCAuth(httpx.Auth, MyMdCCredentials):
async def acquire_token(self):
"""Acquires a new token if none is stored or if the existing token expired,
otherwise reuses the existing token.

Token data stored under `_access_token` and `_expires_at`.
"""
expired = self._expires_at <= dt.datetime.now(tz=dt.UTC)
if self._access_token and not expired:
logger.debug("Reusing existing MyMdC token", expires_at=self._expires_at)
return self._access_token

logger.info(
"Requesting new MyMdC token",
access_token_none=not self._access_token,
expires_at=self._expires_at,
expired=expired,
)

async with httpx.AsyncClient() as client:
data = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret.get_secret_value(),
}

if self.scope:
data["scope"] = self.scope

response = await client.post(str(self.token_url), data=data)

data = response.json()

if any(k not in data for k in ["access_token", "expires_in"]):
logger.critical(
"Response from MyMdC missing required fields, check webservice "
"`user-id` and `user-secret`.",
response=response.text,
status_code=response.status_code,
)
msg = "Invalid response from MyMdC"
raise ValueError(msg) # TODO: custom exception, frontend feedback

expires_in = dt.timedelta(seconds=data["expires_in"])
self._access_token = data["access_token"]
self._expires_at = dt.datetime.now(tz=dt.UTC) + expires_in

logger.info("Acquired new MyMdC token", expires_at=self._expires_at)
return self._access_token

async def async_auth_flow(
self, request: httpx.Request
) -> AsyncGenerator[httpx.Request, Any]:
"""Fetches bearer token (if required) and adds required authorization headers to
the request.

Yields:
AsyncGenerator[httpx.Request, Any]: yields `request` with additional headers
"""
bearer_token = await self.acquire_token()

request.headers["Authorization"] = f"Bearer {bearer_token}"
request.headers["accept"] = "application/json; version=1"
request.headers["X-User-Email"] = self.email

yield request


class MyMdCClient(httpx.AsyncClient):
def __init__(self, auth: MyMdCAuth | None = None) -> None:
"""Client for the MyMdC API."""
if auth is None:
auth = MyMdCAuth() # type: ignore[call-arg]

logger.debug("Creating MyMdC client", auth=auth)

super().__init__(
auth=auth,
base_url="https://in.xfel.eu/metadata/",
)
1 change: 1 addition & 0 deletions api/src/damnit_api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class MyMdCCredentials(BaseSettings):
email: str
token_url: HttpUrl
base_url: HttpUrl
scope: str | None = "public"

_access_token: str = ""
_expires_at: datetime = datetime.fromisocalendar(1970, 1, 1).astimezone(UTC)
Expand Down