Skip to content

Commit

Permalink
SkyPortalPublisher: make cutout posting configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
jvansanten committed Jun 1, 2024
1 parent 300e256 commit 9b48ce8
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 38 deletions.
94 changes: 57 additions & 37 deletions ampel/ztf/t3/skyportal/SkyPortalClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from matplotlib.colors import Normalize
from matplotlib.figure import Figure

from ampel.base.AmpelBaseModel import AmpelBaseModel
from ampel.base.AmpelUnit import AmpelUnit
from ampel.enum.DocumentCode import DocumentCode
from ampel.log.AmpelLogger import AmpelLogger
Expand Down Expand Up @@ -146,7 +147,17 @@ def render_thumbnail(cutout_data: bytes) -> str:


ZTF_FILTERS = {1: "ztfg", 2: "ztfr", 3: "ztfi"}
CUTOUT_TYPES = {"science": "new", "template": "ref", "difference": "sub"}


class CutoutSpec(AmpelBaseModel):
#: where to find cutouts
key: str = "ZTFCutoutImages"
#: mapping from cutout names to SkyPortal thumbnail types
types: dict[str, str] = {
"cutoutScience": "new",
"cutoutTemplate": "ref",
"cutoutDifference": "sub",
}


class SkyPortalAPIError(IOError):
Expand Down Expand Up @@ -242,15 +253,17 @@ async def request(
url = self.base_url + "/api/" + endpoint
labels = (verb, endpoint.split("/")[0])
async with self._semaphore:
with stat_http_time.labels(*labels).time(), stat_http_errors.labels(
*labels
).count_exceptions(
( # type: ignore[arg-type]
aiohttp.ClientResponseError,
aiohttp.ClientConnectionError,
asyncio.TimeoutError,
)
), stat_concurrent_requests.labels(*labels).track_inprogress():
with (
stat_http_time.labels(*labels).time(),
stat_http_errors.labels(*labels).count_exceptions(
( # type: ignore[arg-type]
aiohttp.ClientResponseError,
aiohttp.ClientConnectionError,
asyncio.TimeoutError,
)
),
stat_concurrent_requests.labels(*labels).track_inprogress(),
):
async with self._session.request(
verb, url, **{**self._request_kwargs, **kwargs}
) as response:
Expand Down Expand Up @@ -599,10 +612,12 @@ async def post_t2_comments(
async def post_candidate(
self,
view: "TransientView",
*,
filters: None | list[str] = None,
groups: None | list[str] = None,
instrument: None | str = None,
post_photometry: bool = True,
post_cutouts: None | CutoutSpec = None,
annotate: bool = False,
) -> PostReport:
"""
Expand Down Expand Up @@ -650,7 +665,9 @@ async def post_candidate(
)
}
group_ids = {await self.get_by_name("groups", name) for name in (groups or [])}
assert "tag" in view.stock, f"{self.__class__} requires stocks with a `tag` field. Did you remember to set AlertConsumer.compiler_opts?"
assert (
"tag" in view.stock
), f"{self.__class__} requires stocks with a `tag` field. Did you remember to set AlertConsumer.compiler_opts?"
assert view.stock["tag"] is not None
instrument_id = (
await self.get_by_name("instrument", instrument)
Expand All @@ -659,7 +676,9 @@ async def post_candidate(
)

assert view.stock
assert "name" in view.stock, f"{self.__class__} requires stocks with a `name` field. Did you remember to set AlertConsumer.compiler_opts?"
assert (
"name" in view.stock
), f"{self.__class__} requires stocks with a `name` field. Did you remember to set AlertConsumer.compiler_opts?"
assert view.stock["name"] is not None
name = next(
n for n in view.stock["name"] if isinstance(n, str) and n.startswith("ZTF")
Expand Down Expand Up @@ -783,31 +802,32 @@ async def post_candidate(
except SkyPortalAPIError as exc:
ret["photometry_error"] = exc.args[0]

# SkyPortal only supports one of thumbnail per object and type
# ('new', 'ref', 'sub', 'sdss', 'dr8', 'new_gz', 'ref_gz', 'sub_gz')
# Post one of each type only if they do not yet exist.
existing_cutouts: set[str] = (
{t["type"] for t in response["data"]["thumbnails"]}
if response["status"] == "success"
else set()
)
for cutouts in (view.extra or {}).get("ZTFCutoutImages", {}).values():
for kind, blob in (cutouts or {}).items():
if CUTOUT_TYPES[kind] in existing_cutouts:
continue
assert isinstance(blob, bytes)
# FIXME: switch back to FITS when SkyPortal supports it
await self.post(
"thumbnail",
json={
"obj_id": name,
"data": render_thumbnail(blob),
"ttype": CUTOUT_TYPES[kind],
},
raise_exc=True,
)
existing_cutouts.add(CUTOUT_TYPES[kind])
ret["thumbnail_count"] += 1
if post_cutouts is not None:
# SkyPortal only supports one of thumbnail per object and type
# ('new', 'ref', 'sub', 'sdss', 'dr8', 'new_gz', 'ref_gz', 'sub_gz')
# Post one of each type only if they do not yet exist.
existing_cutouts: set[str] = (
{t["type"] for t in response["data"]["thumbnails"]}
if response["status"] == "success"
else set()
)
for cutouts in (view.extra or {}).get(post_cutouts.key, {}).values():
for kind, blob in (cutouts or {}).items():
if post_cutouts.types[kind] in existing_cutouts:
continue
assert isinstance(blob, bytes)
# FIXME: switch back to FITS when SkyPortal supports it
await self.post(
"thumbnail",
json={
"obj_id": name,
"data": render_thumbnail(blob),
"ttype": post_cutouts.types[kind],
},
raise_exc=True,
)
existing_cutouts.add(post_cutouts.types[kind])
ret["thumbnail_count"] += 1

# represent latest T2 results as a comments
latest_t2: dict[str, "T2DocView"] = {}
Expand Down
4 changes: 3 additions & 1 deletion ampel/ztf/t3/skyportal/SkyPortalPublisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ampel.abstract.AbsPhotoT3Unit import AbsPhotoT3Unit
from ampel.struct.JournalAttributes import JournalAttributes
from ampel.types import StockId
from ampel.ztf.t3.skyportal.SkyPortalClient import BaseSkyPortalPublisher
from ampel.ztf.t3.skyportal.SkyPortalClient import BaseSkyPortalPublisher, CutoutSpec

if TYPE_CHECKING:
from ampel.content.JournalRecord import JournalRecord
Expand All @@ -33,6 +33,7 @@ class SkyPortalPublisher(BaseSkyPortalPublisher, AbsPhotoT3Unit):
#: Explicitly post photometry for each stock. If False, rely on some backend
#: service (like Kowalski on Fritz) to fill in photometry for sources.
include_photometry: bool = True
cutouts: None | CutoutSpec = CutoutSpec()

process_name: None | str = None

Expand Down Expand Up @@ -103,6 +104,7 @@ async def post_view(
filters=self.filters,
annotate=self.annotate,
post_photometry=self.include_photometry,
post_cutouts=self.cutouts,
)
)
)

0 comments on commit 9b48ce8

Please sign in to comment.