diff --git a/ampel/ztf/t3/skyportal/SkyPortalClient.py b/ampel/ztf/t3/skyportal/SkyPortalClient.py index bc1b655..f59c010 100755 --- a/ampel/ztf/t3/skyportal/SkyPortalClient.py +++ b/ampel/ztf/t3/skyportal/SkyPortalClient.py @@ -27,6 +27,7 @@ from matplotlib.colors import Normalize from matplotlib.figure import Figure +from ampel.base.AmpelBaseModel import AmpelBaseModel from ampel.base.AmpelUnit import AmpelUnit from ampel.enum.DocumentCode import DocumentCode from ampel.log.AmpelLogger import AmpelLogger @@ -146,7 +147,17 @@ def render_thumbnail(cutout_data: bytes) -> str: ZTF_FILTERS = {1: "ztfg", 2: "ztfr", 3: "ztfi"} -CUTOUT_TYPES = {"science": "new", "template": "ref", "difference": "sub"} + + +class CutoutSpec(AmpelBaseModel): + #: where to find cutouts + key: str = "ZTFCutoutImages" + #: mapping from cutout names to SkyPortal thumbnail types + types: dict[str, str] = { + "cutoutScience": "new", + "cutoutTemplate": "ref", + "cutoutDifference": "sub", + } class SkyPortalAPIError(IOError): @@ -242,15 +253,17 @@ async def request( url = self.base_url + "/api/" + endpoint labels = (verb, endpoint.split("/")[0]) async with self._semaphore: - with stat_http_time.labels(*labels).time(), stat_http_errors.labels( - *labels - ).count_exceptions( - ( # type: ignore[arg-type] - aiohttp.ClientResponseError, - aiohttp.ClientConnectionError, - asyncio.TimeoutError, - ) - ), stat_concurrent_requests.labels(*labels).track_inprogress(): + with ( + stat_http_time.labels(*labels).time(), + stat_http_errors.labels(*labels).count_exceptions( + ( # type: ignore[arg-type] + aiohttp.ClientResponseError, + aiohttp.ClientConnectionError, + asyncio.TimeoutError, + ) + ), + stat_concurrent_requests.labels(*labels).track_inprogress(), + ): async with self._session.request( verb, url, **{**self._request_kwargs, **kwargs} ) as response: @@ -599,10 +612,12 @@ async def post_t2_comments( async def post_candidate( self, view: "TransientView", + *, filters: None | list[str] = None, groups: None | list[str] = None, instrument: None | str = None, post_photometry: bool = True, + post_cutouts: None | CutoutSpec = None, annotate: bool = False, ) -> PostReport: """ @@ -650,7 +665,9 @@ async def post_candidate( ) } group_ids = {await self.get_by_name("groups", name) for name in (groups or [])} - assert "tag" in view.stock, f"{self.__class__} requires stocks with a `tag` field. Did you remember to set AlertConsumer.compiler_opts?" + assert ( + "tag" in view.stock + ), f"{self.__class__} requires stocks with a `tag` field. Did you remember to set AlertConsumer.compiler_opts?" assert view.stock["tag"] is not None instrument_id = ( await self.get_by_name("instrument", instrument) @@ -659,7 +676,9 @@ async def post_candidate( ) assert view.stock - assert "name" in view.stock, f"{self.__class__} requires stocks with a `name` field. Did you remember to set AlertConsumer.compiler_opts?" + assert ( + "name" in view.stock + ), f"{self.__class__} requires stocks with a `name` field. Did you remember to set AlertConsumer.compiler_opts?" assert view.stock["name"] is not None name = next( n for n in view.stock["name"] if isinstance(n, str) and n.startswith("ZTF") @@ -783,31 +802,32 @@ async def post_candidate( except SkyPortalAPIError as exc: ret["photometry_error"] = exc.args[0] - # SkyPortal only supports one of thumbnail per object and type - # ('new', 'ref', 'sub', 'sdss', 'dr8', 'new_gz', 'ref_gz', 'sub_gz') - # Post one of each type only if they do not yet exist. - existing_cutouts: set[str] = ( - {t["type"] for t in response["data"]["thumbnails"]} - if response["status"] == "success" - else set() - ) - for cutouts in (view.extra or {}).get("ZTFCutoutImages", {}).values(): - for kind, blob in (cutouts or {}).items(): - if CUTOUT_TYPES[kind] in existing_cutouts: - continue - assert isinstance(blob, bytes) - # FIXME: switch back to FITS when SkyPortal supports it - await self.post( - "thumbnail", - json={ - "obj_id": name, - "data": render_thumbnail(blob), - "ttype": CUTOUT_TYPES[kind], - }, - raise_exc=True, - ) - existing_cutouts.add(CUTOUT_TYPES[kind]) - ret["thumbnail_count"] += 1 + if post_cutouts is not None: + # SkyPortal only supports one of thumbnail per object and type + # ('new', 'ref', 'sub', 'sdss', 'dr8', 'new_gz', 'ref_gz', 'sub_gz') + # Post one of each type only if they do not yet exist. + existing_cutouts: set[str] = ( + {t["type"] for t in response["data"]["thumbnails"]} + if response["status"] == "success" + else set() + ) + for cutouts in (view.extra or {}).get(post_cutouts.key, {}).values(): + for kind, blob in (cutouts or {}).items(): + if post_cutouts.types[kind] in existing_cutouts: + continue + assert isinstance(blob, bytes) + # FIXME: switch back to FITS when SkyPortal supports it + await self.post( + "thumbnail", + json={ + "obj_id": name, + "data": render_thumbnail(blob), + "ttype": post_cutouts.types[kind], + }, + raise_exc=True, + ) + existing_cutouts.add(post_cutouts.types[kind]) + ret["thumbnail_count"] += 1 # represent latest T2 results as a comments latest_t2: dict[str, "T2DocView"] = {} diff --git a/ampel/ztf/t3/skyportal/SkyPortalPublisher.py b/ampel/ztf/t3/skyportal/SkyPortalPublisher.py index 21323f9..b009eba 100755 --- a/ampel/ztf/t3/skyportal/SkyPortalPublisher.py +++ b/ampel/ztf/t3/skyportal/SkyPortalPublisher.py @@ -16,7 +16,7 @@ from ampel.abstract.AbsPhotoT3Unit import AbsPhotoT3Unit from ampel.struct.JournalAttributes import JournalAttributes from ampel.types import StockId -from ampel.ztf.t3.skyportal.SkyPortalClient import BaseSkyPortalPublisher +from ampel.ztf.t3.skyportal.SkyPortalClient import BaseSkyPortalPublisher, CutoutSpec if TYPE_CHECKING: from ampel.content.JournalRecord import JournalRecord @@ -33,6 +33,7 @@ class SkyPortalPublisher(BaseSkyPortalPublisher, AbsPhotoT3Unit): #: Explicitly post photometry for each stock. If False, rely on some backend #: service (like Kowalski on Fritz) to fill in photometry for sources. include_photometry: bool = True + cutouts: None | CutoutSpec = CutoutSpec() process_name: None | str = None @@ -103,6 +104,7 @@ async def post_view( filters=self.filters, annotate=self.annotate, post_photometry=self.include_photometry, + post_cutouts=self.cutouts, ) ) )