-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Make SkyPortalClient synchronous * Use alerts endpoint to create all sources * Bump version to 0.10.3a1 * Include comments in source body * Fix up typing
- Loading branch information
1 parent
a275ac9
commit e8037b2
Showing
5 changed files
with
318 additions
and
933 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,12 +6,9 @@ | |
# Date: 03.11.2020 | ||
# Last Modified By: Jakob van Santen <[email protected]> | ||
|
||
import asyncio | ||
from collections.abc import Iterable | ||
from typing import Any | ||
|
||
import nest_asyncio | ||
|
||
from ampel.abstract.AbsBufferComplement import AbsBufferComplement | ||
from ampel.secret.NamedSecret import NamedSecret | ||
from ampel.struct.AmpelBuffer import AmpelBuffer | ||
|
@@ -29,12 +26,12 @@ class FritzReport(SkyPortalClient, AbsBufferComplement): | |
#: API token | ||
token: NamedSecret[str] = NamedSecret[str](label="fritz/jno/ampelbot") | ||
|
||
async def get_catalog_item(self, names: tuple[str, ...]) -> None | dict[str, Any]: | ||
def get_catalog_item(self, names: tuple[str, ...]) -> None | dict[str, Any]: | ||
"""Get catalog entry associated with the stock name""" | ||
for name in names: | ||
if name.startswith("ZTF"): | ||
try: | ||
record = await self.get(f"sources/{name}") | ||
record = self.get(f"sources/{name}") | ||
except SkyPortalAPIError: | ||
return None | ||
# strip out Fritz chatter | ||
|
@@ -45,28 +42,17 @@ async def get_catalog_item(self, names: tuple[str, ...]) -> None | dict[str, Any | |
} | ||
return None | ||
|
||
async def update_record(self, record: AmpelBuffer) -> None: | ||
def update_record(self, record: AmpelBuffer) -> None: | ||
if (stock := record["stock"]) is None: | ||
raise ValueError(f"{type(self).__name__} requires stock records") | ||
item = await self.get_catalog_item( | ||
item = self.get_catalog_item( | ||
tuple(name for name in (stock["name"] or []) if isinstance(name, str)) | ||
) | ||
if record.get("extra") is None or record["extra"] is None: | ||
record["extra"] = {self.__class__.__name__: item} | ||
else: | ||
record["extra"][self.__class__.__name__] = item | ||
|
||
async def update_records(self, records: Iterable[AmpelBuffer]) -> None: | ||
async with self.session(): | ||
await asyncio.gather(*[self.update_record(record) for record in records]) | ||
|
||
def complement(self, records: Iterable[AmpelBuffer], t3s: T3Store) -> None: | ||
# Patch event loop to be reentrant if it is already running, e.g. | ||
# within a notebook | ||
try: | ||
if asyncio.get_event_loop().is_running(): | ||
nest_asyncio.apply() | ||
except RuntimeError: | ||
# second call raises: RuntimeError: There is no current event loop in thread 'MainThread'. | ||
... | ||
asyncio.run(self.update_records(records)) | ||
for record in records: | ||
self.update_record(record) |
Oops, something went wrong.