Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions src/labthings_fastapi/actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
InvocationCancelledError,
invocation_logger,
)
from ..outputs.blob import BlobIOContextDep
from ..outputs.blob import Blob, BlobDataManager

if TYPE_CHECKING:
# We only need these imports for type hints, so this avoids circular imports.
Expand All @@ -40,6 +40,7 @@ def __init__(
self,
action: ActionDescriptor,
thing: Thing,
blob_data_manager: BlobDataManager,
input: Optional[BaseModel] = None,
dependencies: Optional[dict[str, Any]] = None,
default_stop_timeout: float = 5,
Expand All @@ -56,6 +57,8 @@ def __init__(
self.dependencies = dependencies if dependencies is not None else {}
self.cancel_hook = cancel_hook

self._blob_data_manager = blob_data_manager

# A UUID for the Invocation (not the same as the threading.Thread ident)
self._ID = id if id is not None else uuid.uuid4() # Task ID

Expand Down Expand Up @@ -181,6 +184,9 @@ def run(self):
ret = action.__get__(thing)(**kwargs, **self.dependencies)

with self._status_lock:
if isinstance(ret, Blob):
blob_id = self._blob_data_manager.add_blob(ret.data)
ret.href = f"/blob/{blob_id}"
self._return_value = ret
self._status = InvocationStatus.COMPLETED
self.action.emit_changed_event(self.thing, self._status)
Expand Down Expand Up @@ -241,7 +247,8 @@ def emit(self, record):
class ActionManager:
"""A class to manage a collection of actions"""

def __init__(self):
def __init__(self, server):
self._server = server
self._invocations = {}
self._invocations_lock = Lock()

Expand Down Expand Up @@ -271,6 +278,7 @@ def invoke_action(
dependencies=dependencies,
id=id,
cancel_hook=cancel_hook,
blob_data_manager=self._server.blob_data_manager,
)
self.append_invocation(thread)
thread.start()
Expand Down Expand Up @@ -312,17 +320,15 @@ def attach_to_app(self, app: FastAPI):
"""Add /action_invocations and /action_invocation/{id} endpoints to FastAPI"""

@app.get(ACTION_INVOCATIONS_PATH, response_model=list[InvocationModel])
def list_all_invocations(request: Request, _blob_manager: BlobIOContextDep):
def list_all_invocations(request: Request):
return self.list_invocations(as_responses=True, request=request)

@app.get(
ACTION_INVOCATIONS_PATH + "/{id}",
response_model=InvocationModel,
responses={404: {"description": "Invocation ID not found"}},
)
def action_invocation(
id: uuid.UUID, request: Request, _blob_manager: BlobIOContextDep
):
def action_invocation(id: uuid.UUID, request: Request):
try:
with self._invocations_lock:
return self._invocations[id].response(request=request)
Expand All @@ -346,7 +352,7 @@ def action_invocation(
503: {"description": "No result is available for this invocation"},
},
)
def action_invocation_output(id: uuid.UUID, _blob_manager: BlobIOContextDep):
def action_invocation_output(id: uuid.UUID):
"""Get the output of an action invocation

This returns just the "output" component of the action invocation. If the
Expand Down
3 changes: 1 addition & 2 deletions src/labthings_fastapi/descriptors/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
input_model_from_signature,
return_type,
)
from ..outputs.blob import BlobIOContextDep

from ..thing_description import type_to_dataschema
from ..thing_description.model import ActionAffordance, ActionOp, Form, Union
from ..utilities import labthings_data, get_blocking_portal
Expand Down Expand Up @@ -178,7 +178,6 @@ def add_to_fastapi(self, app: FastAPI, thing: Thing):
# the function to the decorator.
def start_action(
action_manager: ActionManagerContextDep,
_blob_manager: BlobIOContextDep,
request: Request,
body,
id: InvocationID,
Expand Down
177 changes: 13 additions & 164 deletions src/labthings_fastapi/outputs/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,24 @@ def get_image(self) -> MyImageBlob:
"""

from __future__ import annotations
from contextvars import ContextVar
import io
import os
import re
import shutil
from typing import (
Annotated,
Callable,
Literal,
Mapping,
Optional,
)
from weakref import WeakValueDictionary
from typing_extensions import TypeAlias
from tempfile import TemporaryDirectory
import uuid

from fastapi import FastAPI, Depends, Request
from fastapi import FastAPI
from fastapi.responses import FileResponse, Response
from pydantic import (
BaseModel,
create_model,
model_serializer,
model_validator,
)
from labthings_fastapi.dependencies.thing_server import find_thing_server
from starlette.exceptions import HTTPException
from typing_extensions import Self, Protocol, runtime_checkable


Expand Down Expand Up @@ -203,88 +194,25 @@ class Blob(BaseModel):
documentation.
"""

href: str
href: str = "blob://local"
"""The URL where the data may be retrieved. This will be `blob://local`
if the data is stored locally."""
media_type: str = "*/*"
"""The MIME type of the data. This should be overridden in subclasses."""
rel: Literal["output"] = "output"
description: str = (
"The output from this action is not serialised to JSON, so it must be "
"retrieved as a file. This link will return the file."
)
media_type: str = "*/*"
"""The MIME type of the data. This should be overridden in subclasses."""

_data: Optional[ServerSideBlobData] = None
"""This object holds the data, either in memory or as a file.

If `_data` is `None`, then the Blob has not been deserialised yet, and the
`href` should point to a valid address where the data may be downloaded.
"""

@model_validator(mode="after")
def retrieve_data(self):
"""Retrieve the data from the URL

When a [`Blob`](#labthings_fastapi.outputs.blob.Blob) is created
using its constructor, [`pydantic`](https://docs.pydantic.dev/latest/)
will attempt to deserialise it by retrieving the data from the URL
specified in `href`. Currently, this must be a URL pointing to a
[`Blob`](#labthings_fastapi.outputs.blob.Blob) that already exists on
this server.

This validator will only work if the function to resolve URLs to
[`BlobData`](#labthings_fastapi.outputs.blob.BlobData) objects
has been set in the context variable
[`url_to_blobdata_ctx`](#labthings_fastapi.outputs.blob.url_to_blobdata_ctx).
This is done when actions are being invoked over HTTP by the
[`BlobIOContextDep`](#labthings_fastapi.outputs.blob.BlobIOContextDep) dependency.
"""
if self.href == "blob://local":
if self._data:
return self
raise ValueError("Blob objects must have data if the href is blob://local")
try:
url_to_blobdata = url_to_blobdata_ctx.get()
self._data = url_to_blobdata(self.href)
self.href = "blob://local"
except LookupError:
raise LookupError(
"Blobs may only be created from URLs passed in over HTTP."
f"The URL in question was {self.href}."
)
return self
_data: ServerSideBlobData
"""This object holds the data, either in memory or as a file."""

@model_serializer(mode="plain", when_used="always")
def to_dict(self) -> Mapping[str, str]:
"""Serialise the Blob to a dictionary and make it downloadable

When [`pydantic`](https://docs.pydantic.dev/latest/) serialises this object,
it will call this method to convert it to a dictionary. There is a
significant side-effect, which is that we will add the blob to the
[`BlobDataManager`](#labthings_fastapi.outputs.blob.BlobDataManager) so it
can be downloaded.

This serialiser will only work if the function to assign URLs to
[`BlobData`](#labthings_fastapi.outputs.blob.BlobData) objects
has been set in the context variable
[`blobdata_to_url_ctx`](#labthings_fastapi.outputs.blob.blobdata_to_url_ctx).
This is done when actions are being returned over HTTP by the
[`BlobIOContextDep`](#labthings_fastapi.outputs.blob.BlobIOContextDep) dependency.
"""
if self.href == "blob://local":
try:
blobdata_to_url = blobdata_to_url_ctx.get()
# MyPy seems to miss that `self.data` is a property, hence the ignore
href = blobdata_to_url(self.data) # type: ignore[arg-type]
except LookupError:
raise LookupError(
"Blobs may only be serialised inside the "
"context created by BlobIOContextDep."
)
else:
href = self.href
"""Serialise the Blob to a dictionary and make it downloadable"""
return {
"href": href,
"href": self.href,
"media_type": self.media_type,
"rel": self.rel,
"description": self.description,
Expand Down Expand Up @@ -348,9 +276,8 @@ def open(self) -> io.IOBase:
@classmethod
def from_bytes(cls, data: bytes) -> Self:
"""Create a BlobOutput from a bytes object"""
return cls.model_construct( # type: ignore[return-value]
href="blob://local",
_data=BlobBytes(data, media_type=cls.default_media_type()),
return cls.model_construct(
_data=BlobBytes(data, media_type=cls.default_media_type())
)

@classmethod
Expand All @@ -362,8 +289,7 @@ def from_temporary_directory(cls, folder: TemporaryDirectory, file: str) -> Self
collected.
"""
file_path = os.path.join(folder.name, file)
return cls.model_construct( # type: ignore[return-value]
href="blob://local",
return cls.model_construct(
_data=BlobFile(
file_path,
media_type=cls.default_media_type(),
Expand All @@ -381,36 +307,15 @@ def from_file(cls, file: str) -> Self:
temporary. If you are using temporary files, consider creating your
Blob with `from_temporary_directory` instead.
"""
return cls.model_construct( # type: ignore[return-value]
href="blob://local",
_data=BlobFile(file, media_type=cls.default_media_type()),
return cls.model_construct(
_data=BlobFile(file, media_type=cls.default_media_type())
)

def response(self):
""" "Return a suitable response for serving the output"""
return self.data.response()


def blob_type(media_type: str) -> type[Blob]:
"""Create a BlobOutput subclass for a given media type

This convenience function may confuse static type checkers, so it is usually
clearer to make a subclass instead, e.g.:

```python
class MyImageBlob(Blob):
media_type = "image/png"
```
"""
if "'" in media_type or "\\" in media_type:
raise ValueError("media_type must not contain single quotes or backslashes")
return create_model(
f"{media_type.replace('/', '_')}_blob",
__base__=Blob,
media_type=(eval(f"Literal[r'{media_type}']"), media_type),
)


class BlobDataManager:
"""A class to manage BlobData objects

Expand Down Expand Up @@ -452,59 +357,3 @@ def download_blob(self, blob_id: uuid.UUID):
def attach_to_app(self, app: FastAPI):
"""Attach the BlobDataManager to a FastAPI app"""
app.get("/blob/{blob_id}")(self.download_blob)


blobdata_to_url_ctx = ContextVar[Callable[[ServerSideBlobData], str]]("blobdata_to_url")
"""This context variable gives access to a function that makes BlobData objects
downloadable, by assigning a URL and adding them to the
[`BlobDataManager`](#labthings_fastapi.outputs.blob.BlobDataManager).

It is only available within a
[`blob_serialisation_context_manager`](#labthings_fastapi.outputs.blob.blob_serialisation_context_manager)
because it requires access to the `BlobDataManager` and the `url_for` function
from the FastAPI app.
"""

url_to_blobdata_ctx = ContextVar[Callable[[str], BlobData]]("url_to_blobdata")
"""This context variable gives access to a function that makes BlobData objects
from a URL, by retrieving them from the
[`BlobDataManager`](#labthings_fastapi.outputs.blob.BlobDataManager).

It is only available within a
[`blob_serialisation_context_manager`](#labthings_fastapi.outputs.blob.blob_serialisation_context_manager)
because it requires access to the `BlobDataManager`.
"""


async def blob_serialisation_context_manager(request: Request):
"""Set context variables to allow blobs to be [de]serialised"""
thing_server = find_thing_server(request.app)
blob_manager: BlobDataManager = thing_server.blob_data_manager
url_for = request.url_for

def blobdata_to_url(blob: ServerSideBlobData) -> str:
blob_id = blob_manager.add_blob(blob)
return str(url_for("download_blob", blob_id=blob_id))

def url_to_blobdata(url: str) -> BlobData:
m = re.search(r"blob/([0-9a-z\-]+)", url)
if not m:
raise HTTPException(
status_code=404, detail="Could not find blob ID in href"
)
invocation_id = uuid.UUID(m.group(1))
return blob_manager.get_blob(invocation_id)

t1 = blobdata_to_url_ctx.set(blobdata_to_url)
t2 = url_to_blobdata_ctx.set(url_to_blobdata)
try:
yield blob_manager
finally:
blobdata_to_url_ctx.reset(t1)
url_to_blobdata_ctx.reset(t2)


BlobIOContextDep: TypeAlias = Annotated[
BlobDataManager, Depends(blob_serialisation_context_manager)
]
"""A dependency that enables `Blob`s to be serialised and deserialised."""
2 changes: 1 addition & 1 deletion src/labthings_fastapi/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, settings_folder: Optional[str] = None):
self.app = FastAPI(lifespan=self.lifespan)
self.set_cors_middleware()
self.settings_folder = settings_folder or "./settings"
self.action_manager = ActionManager()
self.action_manager = ActionManager(self)
self.action_manager.attach_to_app(self.app)
self.blob_data_manager = BlobDataManager()
self.blob_data_manager.attach_to_app(self.app)
Expand Down
Loading