From a3fe0f4e631b9f59b67b422f1c5bc3b005ae95a0 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 16:48:13 +0100 Subject: [PATCH 01/21] bump version --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index bb251ec..ae624fb 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -master-v1.0.0 +master-v1.0.1 From b3eb4b5850a9c9002cbe5bb83d742b1bd413dd07 Mon Sep 17 00:00:00 2001 From: Emir K Date: Fri, 26 Aug 2022 21:34:23 +0200 Subject: [PATCH 02/21] remove api a bit still need sites instruments. --- flows/fileio.py | 3 +- flows/instruments/instruments.py | 2 +- flows/instruments/sites.py | 18 ++ flows/photometry.py | 24 +- flows/run_imagematch.py | 30 +- flows/target.py | 7 +- flows/tns.py | 472 +++++++++++++++---------------- flows/visibility.py | 17 +- flows/ztf.py | 138 --------- tests/test_tns.py | 10 +- typings/astropy/__init__.pyi | 1 + 11 files changed, 309 insertions(+), 413 deletions(-) create mode 100644 flows/instruments/sites.py delete mode 100644 flows/ztf.py create mode 100644 typings/astropy/__init__.pyi diff --git a/flows/fileio.py b/flows/fileio.py index 1c37a64..c1ea16e 100644 --- a/flows/fileio.py +++ b/flows/fileio.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Optional, Protocol, Dict, Union +from typing import Optional, Protocol, Dict, TypeVar, Union from configparser import ConfigParser from bottleneck import allnan from tendrils import api, utils @@ -12,6 +12,7 @@ from .filters import get_reference_filter logger = create_logger() +DataFileType = TypeVar("DataFileType", bound=dict) class DirectoryProtocol(Protocol): archive_local: str diff --git a/flows/instruments/instruments.py b/flows/instruments/instruments.py index f162188..c5024cc 100644 --- a/flows/instruments/instruments.py +++ b/flows/instruments/instruments.py @@ -50,7 +50,7 @@ class LCOGT(Instrument): - siteid = None # Can be between 1, 3, 4, 6, 17, 19. @TODO: Refactor to own classes. + siteid = 1 # Can be between 1, 3, 4, 6, 17, 19. @TODO: Refactor to own classes. peakmax: int = 60000 origin = 'LCOGT' diff --git a/flows/instruments/sites.py b/flows/instruments/sites.py new file mode 100644 index 0000000..9f7a5d2 --- /dev/null +++ b/flows/instruments/sites.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from astropy.coordinates import EarthLocation +from typing import Optional +from tendrils import api + +@dataclass +class Site: + siteid: int + sitename: str + longtitude: float + latitude: float + elevation: float + earth_location: EarthLocation + site_keyword: Optional[str] = None + + @classmethod + def get_site(cls, siteid: int) -> 'Site': + return cls(**api.get_site(siteid)) \ No newline at end of file diff --git a/flows/photometry.py b/flows/photometry.py index 98330af..86918a2 100644 --- a/flows/photometry.py +++ b/flows/photometry.py @@ -21,7 +21,7 @@ from astropy.modeling import fitting from astropy.wcs.utils import proj_plane_pixel_area import multiprocessing -from tendrils import api +#from tendrils import api from .magnitudes import instrumental_mag from .result_model import ResultsTable @@ -49,19 +49,19 @@ logger = create_logger() -def get_datafile(fileid: int) -> Dict: - """ - Get datafile from API, log it, return. - """ - datafile = api.get_datafile(fileid) - logger.debug("Datafile: %s", datafile) - return datafile +# def get_datafile(fileid: int) -> Dict: +# """ +# Get datafile from API, log it, return. 
+# """ +# datafile = api.get_datafile(fileid) +# logger.debug("Datafile: %s", datafile) +# return datafile -def get_catalog(targetid: int) -> Dict: - catalog = api.get_catalog(targetid, output='table') - logger.debug(f"catalog obtained for target: {targetid}") - return catalog +# def get_catalog(targetid: int) -> Dict: +# catalog = api.get_catalog(targetid, output='table') +# logger.debug(f"catalog obtained for target: {targetid}") +# return catalog class PSFBuilder: diff --git a/flows/run_imagematch.py b/flows/run_imagematch.py index 418b696..ee9d46f 100644 --- a/flows/run_imagematch.py +++ b/flows/run_imagematch.py @@ -6,7 +6,9 @@ .. codeauthor:: Rasmus Handberg """ +from typing import Dict, AnyStr, Any, Optional import numpy as np +from numpy.typing import NDArray import logging import os import subprocess @@ -18,7 +20,9 @@ from astropy.io import fits from astropy.wcs.utils import proj_plane_pixel_area from tendrils import api + from .load_image import load_image +from .target import Target @@ -38,7 +42,8 @@ # return dist.install_scripts # -------------------------------------------------------------------------------------------------- -def run_imagematch(datafile, target=None, star_coord=None, fwhm=None, pixel_scale=None): +def run_imagematch(datafile: Dict[str, Any], target: Target, fwhm: float = 9, + pixel_scale: Optional[float] = None) -> NDArray: """ Run ImageMatch on a datafile. @@ -60,9 +65,9 @@ def run_imagematch(datafile, target=None, star_coord=None, fwhm=None, pixel_scal # If the target was not provided in the function call, # use the API to get the target information: - if target is None: - catalog = api.get_catalog(datafile['targetid'], output='table') - target = catalog['target'][0] + #if target is None: + # catalog = api.get_catalog(datafile['targetid'], output='table') + # target = catalog['target'][0] # Find the path to where the ImageMatch program is installed. 
# This is to avoid problems with it not being on the users PATH @@ -75,7 +80,10 @@ def run_imagematch(datafile, target=None, star_coord=None, fwhm=None, pixel_scal else: out = subprocess.check_output(["whereis", "ImageMatch"], universal_newlines=True) out = re.match('ImageMatch: (.+)', out.strip()) - imgmatch = out.group(1) + imgmatch = "None" + if out is not None: + imgmatch = out.group(1) + if not os.path.isfile(imgmatch): raise FileNotFoundError("ImageMatch not found") @@ -93,16 +101,16 @@ def run_imagematch(datafile, target=None, star_coord=None, fwhm=None, pixel_scal if not os.path.isfile(config_file): raise FileNotFoundError(config_file) + if pixel_scale is None: - if datafile['site'] in (1, 3, 4, 6): - # LCOGT provides the pixel scale directly in the header - pixel_scale = 'PIXSCALE' - else: - image = load_image(science_image) + image = load_image(science_image) + pixel_scale = image.header.get("PIXSCALE", None) + if pixel_scale is None: pixel_area = proj_plane_pixel_area(image.wcs) pixel_scale = np.sqrt(pixel_area) * 3600 # arcsec/pixel logger.info("Calculated science image pixel scale: %f", pixel_scale) + if datafile['template']['site'] in (1, 3, 4, 6): # LCOGT provides the pixel scale directly in the header mscale = 'PIXSCALE' @@ -134,7 +142,7 @@ def run_imagematch(datafile, target=None, star_coord=None, fwhm=None, pixel_scal cmd = '"{python:s}" "{imgmatch:s}" -cfg "{config_file:s}" -snx {target_ra:.10f}d -sny {target_dec:.10f}d -p {kernel_radius:d} -o {order:d} -s {match:f} -scale {pixel_scale:} -mscale {mscale:} -m "{reference_image:s}" "{science_image:s}"'.format( python=sys.executable, imgmatch=imgmatch, config_file=config_file, reference_image=os.path.basename(reference_image), science_image=os.path.basename(science_image), - target_ra=target['ra'], target_dec=target['decl'], match=match_threshold, kernel_radius=kernel_radius, + target_ra=target.ra, target_dec=target.dec, match=match_threshold, kernel_radius=kernel_radius, pixel_scale=pixel_scale, mscale=mscale, order=1) logger.info("Executing command: %s", cmd) diff --git a/flows/target.py b/flows/target.py index 5bdd35c..0cfaf6b 100644 --- a/flows/target.py +++ b/flows/target.py @@ -2,7 +2,7 @@ from typing import Optional, Dict import numpy as np -from numpy.typing import ArrayLike +from numpy.typing import NDArray from astropy.coordinates import SkyCoord from astropy.wcs import WCS from tendrils import api @@ -27,7 +27,7 @@ def calc_pixels(self, wcs: WCS) -> None: self._add_pixel_coordinates(pixel_pos=pixels) def _add_pixel_coordinates(self, pixel_column: Optional[int] = None, pixel_row: Optional[int] = None, - pixel_pos: Optional[ArrayLike] = None) -> None: + pixel_pos: Optional[NDArray] = None) -> None: """ Add pixel coordinates to target. """ @@ -58,6 +58,9 @@ def from_fid(cls, fid: int, datafile: Optional[Dict] = None) -> 'Target': """ Create target from fileid. 
""" + datafile = datafile or api.get_datafile(fid) + if datafile is None: + raise ValueError(f'No datafile found for fid={fid}') d = api.get_target(datafile['target_name']) | datafile return cls.from_dict(d) diff --git a/flows/tns.py b/flows/tns.py index 7b7ffa9..bae1050 100644 --- a/flows/tns.py +++ b/flows/tns.py @@ -1,236 +1,236 @@ -""" -TNS API FUNCTIONS -Pre-provided helper functions for the TNS API, type annotations added -Obtained from https://wis-tns.weizmann.ac.il/content/tns-getting-started -""" -from __future__ import annotations -import logging -from astropy.table import Table -import astropy.units as u -from astropy.coordinates import SkyCoord -import requests -import json -import datetime -from tendrils.utils import load_config -from typing import Optional, Union - -url_tns_api = 'https://www.wis-tns.org/api/get' -url_tns_search = 'https://www.wis-tns.org/search' -DateType = Union[datetime.datetime, str] - - -class TNSConfigError(RuntimeError): - pass - - -def _load_tns_config() -> dict[str, str]: - logger = logging.getLogger(__name__) - - config = load_config() - api_key = config.get('TNS', 'api_key', fallback=None) - if api_key is None: - raise TNSConfigError("No TNS API-KEY has been defined in config") - - tns_bot_id = config.getint('TNS', 'bot_id', fallback=93222) - tns_bot_name = config.get('TNS', 'bot_name', fallback='AUFLOWS_BOT') - tns_user_id = config.getint('TNS', 'user_id', fallback=None) - tns_user_name = config.get('TNS', 'user_name', fallback=None) - - if tns_user_id and tns_user_name: - logger.debug('Using TNS credentials: user=%s', tns_user_name) - user_agent = 'tns_marker{"tns_id":' + str(tns_user_id) + ',"type":"user","name":"' + tns_user_name + '"}' - elif tns_bot_id and tns_bot_name: - logger.debug('Using TNS credentials: bot=%s', tns_bot_name) - user_agent = 'tns_marker{"tns_id":' + str(tns_bot_id) + ',"type":"bot","name":"' + tns_bot_name + '"}' - else: - raise TNSConfigError("No TNS bot_id or bot_name has been defined in config") - - return {'api-key': api_key, 'user-agent': user_agent} - - -def tns_search(coord: Optional[SkyCoord] = None, radius: u.Quantity = 3 * u.arcsec, objname: Optional[str] = None, - internal_name: Optional[str] = None) -> Optional[dict]: - """ - Cone-search TNS for object near coordinate. - - Parameters: - coord (:class:`astropy.coordinates.SkyCoord`): Central coordinate to search around. - radius (Angle, optional): Radius to search around ``coord``. - objname (str, optional): Search on object name. - internal_name (str, optional): Search on internal name. - - Returns: - dict: Dictionary with TSN response. - """ - - # API key for Bot - tnsconf = _load_tns_config() - - # change json_list to json format - json_file = {'radius': radius.to('arcsec').value, 'units': 'arcsec', 'objname': objname, - 'internal_name': internal_name} - if coord: - json_file['ra'] = coord.icrs.ra.deg - json_file['dec'] = coord.icrs.dec.deg - - # construct the list of (key,value) pairs - headers = {'user-agent': tnsconf['user-agent']} - search_data = [('api_key', (None, tnsconf['api-key'])), ('data', (None, json.dumps(json_file)))] - - # search obj using request module - res = requests.post(url_tns_api + '/search', files=search_data, headers=headers) - res.raise_for_status() - parsed = res.json() - data = parsed['data'] - - if 'reply' in data: - return data['reply'] - return None - - -def tns_get_obj(name: str) -> Optional[dict]: - """ - Search TNS for object by name. - - Parameters: - name (str): Object name to search for. 
- - Returns: - dict: Dictionary with TSN response. - """ - - # API key for Bot - tnsconf = _load_tns_config() - - # construct the list of (key,value) pairs - headers = {'user-agent': tnsconf['user-agent']} - params = {'objname': name, 'photometry': '0', 'spectra': '0'} - get_data = [('api_key', (None, tnsconf['api-key'])), ('data', (None, json.dumps(params)))] - - # get obj using request module - res = requests.post(url_tns_api + '/object', files=get_data, headers=headers) - res.raise_for_status() - parsed = res.json() - data = parsed['data'] - - if 'reply' in data: - reply = data['reply'] - if not reply: - return None - if 'objname' not in reply: # Bit of a cheat, but it is simple and works - return None - - reply['internal_names'] = [name.strip() for name in reply['internal_names'].split(',') if name.strip()] - return reply - return None - - -def tns_getnames(months: Optional[int] = None, date_begin: Optional[DateType] = None, - date_end: Optional[DateType] = None, zmin: Optional[float] = None, - zmax: Optional[float] = None, objtype: tuple[int] = (3, 104)) -> list[str]: - """ - Get SN names from TNS. - - Parameters: - months (int, optional): Only return objects reported within the last X months. - date_begin (date, optional): Discovery date begin. - date_end (date, optional): Discovery date end. - zmin (float, optional): Minimum redshift. - zmax (float, optional): Maximum redshift. - objtype (list, optional): Constraint object type. - Default is to query for - - 3: SN Ia - - 104: SN Ia-91T-like - - Returns: - list: List of names fulfilling search criteria. - """ - - logger = logging.getLogger(__name__) - - # Change formats of input to be ready for query: - if isinstance(date_begin, datetime.datetime): - date_begin = date_begin.date() - elif isinstance(date_begin, str): - date_begin = datetime.datetime.strptime(date_begin, '%Y-%m-%d').date() - - if isinstance(date_end, datetime.datetime): - date_end = date_end.date() - elif isinstance(date_end, str): - date_end = datetime.datetime.strptime(date_end, '%Y-%m-%d').date() - - if isinstance(objtype, (list, tuple)): - objtype = ','.join([str(o) for o in objtype]) - - # Do some sanity checks: - if date_end < date_begin: - raise ValueError("Dates are in the wrong order.") - - date_now = datetime.datetime.now(datetime.timezone.utc).date() - if months is not None and date_end is not None and date_end < date_now - datetime.timedelta(days=months * 30): - logger.warning('Months limit restricts days_begin, consider increasing limit_months.') - - # API key for Bot - tnsconf = _load_tns_config() - - # Parameters for query: - params = {'discovered_period_value': months, # Reported Within The Last - 'discovered_period_units': 'months', 'unclassified_at': 0, # Limit to unclasssified ATs - 'classified_sne': 1, # Limit to classified SNe - 'include_frb': 0, # Include FRBs - # 'name': , - 'name_like': 0, 'isTNS_AT': 'all', 'public': 'all', # 'ra': - # 'decl': - # 'radius': - # 'coords_unit': 'arcsec', - 'reporting_groupid[]': 'null', 'groupid[]': 'null', 'classifier_groupid[]': 'null', 'objtype[]': objtype, - 'at_type[]': 'null', 'date_start[date]': date_begin.isoformat(), 'date_end[date]': date_end.isoformat(), - # 'discovery_mag_min': - # 'discovery_mag_max': - # 'internal_name': - # 'discoverer': - # 'classifier': - # 'spectra_count': - 'redshift_min': zmin, 'redshift_max': zmax, # 'hostname': - # 'ext_catid': - # 'ra_range_min': - # 'ra_range_max': - # 'decl_range_min': - # 'decl_range_max': - 'discovery_instrument[]': 'null', 
'classification_instrument[]': 'null', 'associated_groups[]': 'null', - # 'at_rep_remarks': - # 'class_rep_remarks': - # 'frb_repeat': 'all' - # 'frb_repeater_of_objid': - 'frb_measured_redshift': 0, # 'frb_dm_range_min': - # 'frb_dm_range_max': - # 'frb_rm_range_min': - # 'frb_rm_range_max': - # 'frb_snr_range_min': - # 'frb_snr_range_max': - # 'frb_flux_range_min': - # 'frb_flux_range_max': - 'num_page': 500, 'display[redshift]': 0, 'display[hostname]': 0, 'display[host_redshift]': 0, - 'display[source_group_name]': 0, 'display[classifying_source_group_name]': 0, - 'display[discovering_instrument_name]': 0, 'display[classifing_instrument_name]': 0, - 'display[programs_name]': 0, 'display[internal_name]': 0, 'display[isTNS_AT]': 0, 'display[public]': 0, - 'display[end_pop_period]': 0, 'display[spectra_count]': 0, 'display[discoverymag]': 0, - 'display[discmagfilter]': 0, 'display[discoverydate]': 0, 'display[discoverer]': 0, 'display[remarks]': 0, - 'display[sources]': 0, 'display[bibcode]': 0, 'display[ext_catalogs]': 0, 'format': 'csv'} - - # Query TNS for names: - headers = {'user-agent': tnsconf['user-agent']} - con = requests.get(url_tns_search, params=params, headers=headers) - con.raise_for_status() - - # Parse the CSV table: - # Ensure that there is a newline in table string. - # AstroPy uses this to distinguish file-paths from pure-string inputs: - text = str(con.text) + "\n" - tab = Table.read(text, format='ascii.csv', guess=False, delimiter=',', quotechar='"', header_start=0, data_start=1) - - # Pull out the names only if they begin with "SN": - names_list = [name.replace(' ', '') for name in tab['Name'] if name.startswith('SN')] - names_list = sorted(names_list) - - return names_list +# """ +# TNS API FUNCTIONS +# Pre-provided helper functions for the TNS API, type annotations added +# Obtained from https://wis-tns.weizmann.ac.il/content/tns-getting-started +# """ +# from __future__ import annotations +# import logging +# from astropy.table import Table +# import astropy.units as u +# from astropy.coordinates import SkyCoord +# import requests +# import json +# import datetime +# from tendrils.utils import load_config +# from typing import Optional, Union + +# url_tns_api = 'https://www.wis-tns.org/api/get' +# url_tns_search = 'https://www.wis-tns.org/search' +# DateType = Union[datetime.datetime, str] + + +# class TNSConfigError(RuntimeError): +# pass + + +# def _load_tns_config() -> dict[str, str]: +# logger = logging.getLogger(__name__) + +# config = load_config() +# api_key = config.get('TNS', 'api_key', fallback=None) +# if api_key is None: +# raise TNSConfigError("No TNS API-KEY has been defined in config") + +# tns_bot_id = config.getint('TNS', 'bot_id', fallback=93222) +# tns_bot_name = config.get('TNS', 'bot_name', fallback='AUFLOWS_BOT') +# tns_user_id = config.getint('TNS', 'user_id', fallback=None) +# tns_user_name = config.get('TNS', 'user_name', fallback=None) + +# if tns_user_id and tns_user_name: +# logger.debug('Using TNS credentials: user=%s', tns_user_name) +# user_agent = 'tns_marker{"tns_id":' + str(tns_user_id) + ',"type":"user","name":"' + tns_user_name + '"}' +# elif tns_bot_id and tns_bot_name: +# logger.debug('Using TNS credentials: bot=%s', tns_bot_name) +# user_agent = 'tns_marker{"tns_id":' + str(tns_bot_id) + ',"type":"bot","name":"' + tns_bot_name + '"}' +# else: +# raise TNSConfigError("No TNS bot_id or bot_name has been defined in config") + +# return {'api-key': api_key, 'user-agent': user_agent} + + +# def tns_search(coord: Optional[SkyCoord] 
= None, radius: u.Quantity = 3 * u.arcsec, objname: Optional[str] = None, +# internal_name: Optional[str] = None) -> Optional[dict]: +# """ +# Cone-search TNS for object near coordinate. + +# Parameters: +# coord (:class:`astropy.coordinates.SkyCoord`): Central coordinate to search around. +# radius (Angle, optional): Radius to search around ``coord``. +# objname (str, optional): Search on object name. +# internal_name (str, optional): Search on internal name. + +# Returns: +# dict: Dictionary with TSN response. +# """ + +# # API key for Bot +# tnsconf = _load_tns_config() + +# # change json_list to json format +# json_file = {'radius': radius.to('arcsec').value, 'units': 'arcsec', 'objname': objname, +# 'internal_name': internal_name} +# if coord: +# json_file['ra'] = coord.icrs.ra.deg +# json_file['dec'] = coord.icrs.dec.deg + +# # construct the list of (key,value) pairs +# headers = {'user-agent': tnsconf['user-agent']} +# search_data = [('api_key', (None, tnsconf['api-key'])), ('data', (None, json.dumps(json_file)))] + +# # search obj using request module +# res = requests.post(url_tns_api + '/search', files=search_data, headers=headers) +# res.raise_for_status() +# parsed = res.json() +# data = parsed['data'] + +# if 'reply' in data: +# return data['reply'] +# return None + + +# def tns_get_obj(name: str) -> Optional[dict]: +# """ +# Search TNS for object by name. + +# Parameters: +# name (str): Object name to search for. + +# Returns: +# dict: Dictionary with TSN response. +# """ + +# # API key for Bot +# tnsconf = _load_tns_config() + +# # construct the list of (key,value) pairs +# headers = {'user-agent': tnsconf['user-agent']} +# params = {'objname': name, 'photometry': '0', 'spectra': '0'} +# get_data = [('api_key', (None, tnsconf['api-key'])), ('data', (None, json.dumps(params)))] + +# # get obj using request module +# res = requests.post(url_tns_api + '/object', files=get_data, headers=headers) +# res.raise_for_status() +# parsed = res.json() +# data = parsed['data'] + +# if 'reply' in data: +# reply = data['reply'] +# if not reply: +# return None +# if 'objname' not in reply: # Bit of a cheat, but it is simple and works +# return None + +# reply['internal_names'] = [name.strip() for name in reply['internal_names'].split(',') if name.strip()] +# return reply +# return None + + +# def tns_getnames(months: Optional[int] = None, date_begin: Optional[DateType] = None, +# date_end: Optional[DateType] = None, zmin: Optional[float] = None, +# zmax: Optional[float] = None, objtype: tuple[int] = (3, 104)) -> list[str]: +# """ +# Get SN names from TNS. + +# Parameters: +# months (int, optional): Only return objects reported within the last X months. +# date_begin (date, optional): Discovery date begin. +# date_end (date, optional): Discovery date end. +# zmin (float, optional): Minimum redshift. +# zmax (float, optional): Maximum redshift. +# objtype (list, optional): Constraint object type. +# Default is to query for +# - 3: SN Ia +# - 104: SN Ia-91T-like + +# Returns: +# list: List of names fulfilling search criteria. 
+# """ + +# logger = logging.getLogger(__name__) + +# # Change formats of input to be ready for query: +# if isinstance(date_begin, datetime.datetime): +# date_begin = date_begin.date() +# elif isinstance(date_begin, str): +# date_begin = datetime.datetime.strptime(date_begin, '%Y-%m-%d').date() + +# if isinstance(date_end, datetime.datetime): +# date_end = date_end.date() +# elif isinstance(date_end, str): +# date_end = datetime.datetime.strptime(date_end, '%Y-%m-%d').date() + +# if isinstance(objtype, (list, tuple)): +# objtype = ','.join([str(o) for o in objtype]) + +# # Do some sanity checks: +# if date_end < date_begin: +# raise ValueError("Dates are in the wrong order.") + +# date_now = datetime.datetime.now(datetime.timezone.utc).date() +# if months is not None and date_end is not None and date_end < date_now - datetime.timedelta(days=months * 30): +# logger.warning('Months limit restricts days_begin, consider increasing limit_months.') + +# # API key for Bot +# tnsconf = _load_tns_config() + +# # Parameters for query: +# params = {'discovered_period_value': months, # Reported Within The Last +# 'discovered_period_units': 'months', 'unclassified_at': 0, # Limit to unclasssified ATs +# 'classified_sne': 1, # Limit to classified SNe +# 'include_frb': 0, # Include FRBs +# # 'name': , +# 'name_like': 0, 'isTNS_AT': 'all', 'public': 'all', # 'ra': +# # 'decl': +# # 'radius': +# # 'coords_unit': 'arcsec', +# 'reporting_groupid[]': 'null', 'groupid[]': 'null', 'classifier_groupid[]': 'null', 'objtype[]': objtype, +# 'at_type[]': 'null', 'date_start[date]': date_begin.isoformat(), 'date_end[date]': date_end.isoformat(), +# # 'discovery_mag_min': +# # 'discovery_mag_max': +# # 'internal_name': +# # 'discoverer': +# # 'classifier': +# # 'spectra_count': +# 'redshift_min': zmin, 'redshift_max': zmax, # 'hostname': +# # 'ext_catid': +# # 'ra_range_min': +# # 'ra_range_max': +# # 'decl_range_min': +# # 'decl_range_max': +# 'discovery_instrument[]': 'null', 'classification_instrument[]': 'null', 'associated_groups[]': 'null', +# # 'at_rep_remarks': +# # 'class_rep_remarks': +# # 'frb_repeat': 'all' +# # 'frb_repeater_of_objid': +# 'frb_measured_redshift': 0, # 'frb_dm_range_min': +# # 'frb_dm_range_max': +# # 'frb_rm_range_min': +# # 'frb_rm_range_max': +# # 'frb_snr_range_min': +# # 'frb_snr_range_max': +# # 'frb_flux_range_min': +# # 'frb_flux_range_max': +# 'num_page': 500, 'display[redshift]': 0, 'display[hostname]': 0, 'display[host_redshift]': 0, +# 'display[source_group_name]': 0, 'display[classifying_source_group_name]': 0, +# 'display[discovering_instrument_name]': 0, 'display[classifing_instrument_name]': 0, +# 'display[programs_name]': 0, 'display[internal_name]': 0, 'display[isTNS_AT]': 0, 'display[public]': 0, +# 'display[end_pop_period]': 0, 'display[spectra_count]': 0, 'display[discoverymag]': 0, +# 'display[discmagfilter]': 0, 'display[discoverydate]': 0, 'display[discoverer]': 0, 'display[remarks]': 0, +# 'display[sources]': 0, 'display[bibcode]': 0, 'display[ext_catalogs]': 0, 'format': 'csv'} + +# # Query TNS for names: +# headers = {'user-agent': tnsconf['user-agent']} +# con = requests.get(url_tns_search, params=params, headers=headers) +# con.raise_for_status() + +# # Parse the CSV table: +# # Ensure that there is a newline in table string. 
+# # AstroPy uses this to distinguish file-paths from pure-string inputs: +# text = str(con.text) + "\n" +# tab = Table.read(text, format='ascii.csv', guess=False, delimiter=',', quotechar='"', header_start=0, data_start=1) + +# # Pull out the names only if they begin with "SN": +# names_list = [name.replace(' ', '') for name in tab['Name'] if name.startswith('SN')] +# names_list = sorted(names_list) + +# return names_list diff --git a/flows/visibility.py b/flows/visibility.py index 2736b6b..bb6e108 100644 --- a/flows/visibility.py +++ b/flows/visibility.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Target visibility plotting. +Target visibility plotting. +@TODO: Move to flows-tools. .. codeauthor:: Rasmus Handberg """ @@ -17,10 +18,12 @@ from astropy.coordinates import SkyCoord, AltAz, get_sun, get_moon from astropy.visualization import quantity_support from tendrils import api +from .target import Target +from typing import Optional # -------------------------------------------------------------------------------------------------- -def visibility(target, siteid=None, date=None, output=None, overwrite=True): +def visibility(target: Target, siteid: Optional[int] = None, date=None, output=None, overwrite=True): """ Create visibility plot. @@ -44,10 +47,8 @@ def visibility(target, siteid=None, date=None, output=None, overwrite=True): elif isinstance(date, str): date = datetime.strptime(date, '%Y-%m-%d') - tgt = api.get_target(target) - # Coordinates of object: - obj = SkyCoord(ra=tgt['ra'], dec=tgt['decl'], unit='deg', frame='icrs') + obj = SkyCoord(ra=target.ra, dec=target.dec, unit='deg', frame='icrs') if siteid is None: sites = api.get_all_sites() @@ -61,7 +62,7 @@ def visibility(target, siteid=None, date=None, output=None, overwrite=True): if output: if os.path.isdir(output): plotpath = os.path.join(output, "visibility_%s_%s_site%02d.png" % ( - tgt['target_name'], date.strftime('%Y%m%d'), site['siteid'])) + target.name, date.strftime('%Y%m%d'), site['siteid'])) else: plotpath = output logger.debug("Will save visibility plot to '%s'", plotpath) @@ -103,7 +104,7 @@ def visibility(target, siteid=None, date=None, output=None, overwrite=True): plt.grid(ls=':', lw=0.5) ax.plot(times.datetime, altaz_sun.alt, color='y', label='Sun') ax.plot(times.datetime, altaz_moon.alt, color=[0.75] * 3, ls='--', label='Moon') - objsc = ax.scatter(times.datetime, altaz_obj.alt, c=altaz_obj.az, label=tgt['target_name'], lw=0, s=8, + objsc = ax.scatter(times.datetime, altaz_obj.alt, c=altaz_obj.az, label=target.name, lw=0, s=8, cmap='twilight') ax.fill_between(times.datetime, 0 * u.deg, 90 * u.deg, altaz_sun.alt < -0 * u.deg, color='0.5', zorder=0) # , label='Night' @@ -115,7 +116,7 @@ def visibility(target, siteid=None, date=None, output=None, overwrite=True): ax.minorticks_on() ax.set_xlim(min_time.datetime, max_time.datetime) ax.set_ylim(0 * u.deg, 90 * u.deg) - ax.set_title("%s - %s - %s" % (str(tgt['target_name']), date.strftime('%Y-%m-%d'), site['sitename']), + ax.set_title("%s - %s - %s" % (target.name, date.strftime('%Y-%m-%d'), site['sitename']), fontsize=14) plt.xlabel('Time [UTC]', fontsize=14) plt.ylabel('Altitude [deg]', fontsize=16) diff --git a/flows/ztf.py b/flows/ztf.py deleted file mode 100644 index 75ad66c..0000000 --- a/flows/ztf.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -Query ZTF target information using ALeRCE API. 
-https://alerceapi.readthedocs.io/ -""" - -import numpy as np -import astropy.units as u -from astropy.coordinates import Angle, SkyCoord -from astropy.table import Table -from astropy.time import Time -import datetime -import requests -from tendrils import api - - -# -------------------------------------------------------------------------------------------------- -def query_ztf_id(coo_centre, radius=3 * u.arcsec, discovery_date=None): - """ - Query ALeRCE ZTF api to lookup ZTF identifier. - - In case multiple identifiers are found within the search cone, the one - closest to the centre is returned. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (Angle, optional): Search radius. Default 3 arcsec. - discovery_date (:class:`astropy.time.Time`, optional): Discovery date of target to - match against ZTF. The date is compared to the ZTF first timestamp and ZTF targets - are rejected if they are not within 15 days prior to the discovery date - and 90 days after. - - Returns: - str: ZTF identifier. - - .. codeauthor:: Emir Karamehmetoglu - .. codeauthor:: Rasmus Handberg - """ - - if isinstance(radius, (float, int)): - radius *= u.deg - - # Make json query for Alerce query API - query = {'ra': coo_centre.ra.deg, 'dec': coo_centre.dec.deg, 'radius': Angle(radius).arcsec, 'page_size': 20, - 'count': True} - - # Run http POST json query to alerce following their API - res = requests.get('https://api.alerce.online/ztf/v1/objects', params=query) - res.raise_for_status() - jsn = res.json() - - # If nothing was found, return None: - if jsn['total'] == 0: - return None - - # Start by removing anything marked as likely stellar-like source: - results = jsn['items'] - results = [itm for itm in results if not itm['stellar']] - if not results: - return None - - # Constrain on the discovery date if it is provided: - if discovery_date is not None: - # Extract the time of the first ZTF timestamp and compare it with - # the discovery time: - firstmjd = Time([itm['firstmjd'] for itm in results], format='mjd', scale='utc') - tdelta = firstmjd.utc.mjd - discovery_date.utc.mjd - - # Only keep results that are within the margins: - results = [itm for k, itm in enumerate(results) if -15 <= tdelta[k] <= 90] - if not results: - return None - - # Find target closest to the centre: - coords = SkyCoord(ra=[itm['meanra'] for itm in results], dec=[itm['meandec'] for itm in results], unit='deg', - frame='icrs') - - indx = np.argmin(coords.separation(coo_centre)) - - return results[indx]['oid'] - - -# -------------------------------------------------------------------------------------------------- -def download_ztf_photometry(targetid): - """ - Download ZTF photometry from ALERCE API. - - Parameters: - targetid (int): Target identifier. - - Returns: - :class:`astropy.table.Table`: ZTF photometry table. - - .. codeauthor:: Emir Karamehmetoglu - .. 
codeauthor:: Rasmus Handberg - """ - - # Get target info from Flows API: - tgt = api.get_target(targetid) - oid = tgt['ztf_id'] - target_name = tgt['target_name'] - if oid is None: - return None - - # Query ALERCE for detections of object based on oid - res = requests.get(f'https://api.alerce.online/ztf/v1/objects/{oid:s}/detections') - res.raise_for_status() - jsn = res.json() - - # Create Astropy table, cut out the needed columns - # and rename columns to something better for what we are doing: - tab = Table(data=jsn) - tab = tab[['fid', 'mjd', 'magpsf', 'sigmapsf']] - tab.rename_column('fid', 'photfilter') - tab.rename_column('mjd', 'time') - tab.rename_column('magpsf', 'mag') - tab.rename_column('sigmapsf', 'mag_err') - - # Remove bad values of time and magnitude: - tab['time'] = np.asarray(tab['time'], dtype='float64') - tab['mag'] = np.asarray(tab['mag'], dtype='float64') - tab['mag_err'] = np.asarray(tab['mag_err'], dtype='float64') - indx = np.isfinite(tab['time']) & np.isfinite(tab['mag']) & np.isfinite(tab['mag_err']) - tab = tab[indx] - - # Replace photometric filter numbers with keywords used in Flows: - photfilter_dict = {1: 'gp', 2: 'rp', 3: 'ip'} - tab['photfilter'] = [photfilter_dict[fid] for fid in tab['photfilter']] - - # Sort the table on photfilter and time: - tab.sort(['photfilter', 'time']) - - # Add meta information to table header: - tab.meta['target_name'] = target_name - tab.meta['targetid'] = targetid - tab.meta['ztf_id'] = oid - tab.meta['last_updated'] = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() - - return tab diff --git a/tests/test_tns.py b/tests/test_tns.py index ae9f156..42dab79 100644 --- a/tests/test_tns.py +++ b/tests/test_tns.py @@ -10,7 +10,7 @@ import os import datetime from astropy.coordinates import SkyCoord -from flows import tns +from tendrils.utils import tns # -------------------------------------------------------------------------------------------------- @@ -20,7 +20,8 @@ def test_tns_search(): coo_centre = SkyCoord(ra=191.283890127, dec=-0.45909033652, unit='deg', frame='icrs') res = tns.tns_search(coo_centre) - print(res) + if res is None: + raise ValueError("No results found.") assert res[0]['objname'] == '2019yvr' assert res[0]['prefix'] == 'SN' @@ -31,7 +32,9 @@ def test_tns_search(): def test_tns_get_obj(): res = tns.tns_get_obj('2019yvr') - print(res) + if res is None: + raise ValueError("No results found.") + assert res['objname'] == '2019yvr' assert res['name_prefix'] == 'SN' @@ -41,7 +44,6 @@ def test_tns_get_obj(): reason="Disabled on GitHub Actions to avoid too many requests HTTP error") def test_tns_get_obj_noexist(): res = tns.tns_get_obj('1892doesnotexist') - print(res) assert res is None diff --git a/typings/astropy/__init__.pyi b/typings/astropy/__init__.pyi new file mode 100644 index 0000000..812cf75 --- /dev/null +++ b/typings/astropy/__init__.pyi @@ -0,0 +1 @@ +def __getattr__(name: str) -> Any: ... 
\ No newline at end of file From 0ea41799a5b24d14549805155710e92e83071fd5 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Sun, 28 Aug 2022 19:31:30 +0200 Subject: [PATCH 03/21] add site generation from astropy and user input --- flows/instruments/sites.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/flows/instruments/sites.py b/flows/instruments/sites.py index 9f7a5d2..4650b86 100644 --- a/flows/instruments/sites.py +++ b/flows/instruments/sites.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from astropy.coordinates import EarthLocation +import astropy.units as u from typing import Optional from tendrils import api @@ -7,12 +8,40 @@ class Site: siteid: int sitename: str - longtitude: float + longitude: float latitude: float elevation: float - earth_location: EarthLocation + earth_location: Optional[EarthLocation] = None site_keyword: Optional[str] = None + + def __post_init__(self): + if self.earth_location is None: + return + self.earth_location = EarthLocation( + lat=self.latitude*u.deg, + lon=self.longitude*u.deg, + height=self.elevation*u.m + ) @classmethod - def get_site(cls, siteid: int) -> 'Site': - return cls(**api.get_site(siteid)) \ No newline at end of file + def from_flows(cls, siteid: int) -> 'Site': + return cls(**api.get_site(siteid)) + + @classmethod + def from_astropy(cls, sitename: str) -> 'Site': + loc = EarthLocation.of_site(sitename) + return cls(siteid=999, sitename=sitename, + longitude=loc.long.value, latitude=loc.lat.value, + elevation=loc.height.value, earth_location=loc) + + @classmethod + def from_query(cls) -> 'Site': + sitename = input('Enter a site name for logging: ') + longitude = float(input('Enter longitude in degrees: ')) + lat = float(input('Enter latitude in degrees: ')) + elevation = float(input('Enter elevation in meters: ')) + siteid = 1000 # hardcoded for user defined site + return cls(siteid=999, sitename=sitename, + longitude=longitude, latitude=lat, + elevation=elevation) + \ No newline at end of file From 366263fc7b5aa251281801fdc87bb800b5d76c75 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Sun, 28 Aug 2022 19:31:55 +0200 Subject: [PATCH 04/21] add filter guessing from fits header --- flows/filters.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/flows/filters.py b/flows/filters.py index 149e501..df1f111 100644 --- a/flows/filters.py +++ b/flows/filters.py @@ -1,3 +1,4 @@ +from typing import Optional from .utilities import create_logger logger = create_logger() FILTERS = { @@ -30,3 +31,77 @@ def get_reference_filter(photfilter: str) -> str: f"Using default {FALLBACK_FILTER} filter.") _ref_filter = FILTERS[FALLBACK_FILTER] return _ref_filter + +def clean_value(value: str) -> str: + """ + Clean value. 
+ """ + return value.replace(' ', '').replace('-', '').replace('.', '').replace('_', '').lower() + +COMMON_FILTERS = { + 'B': 'B', 'V': 'V', 'R': 'R', 'g': 'gp', 'r': 'rp', + 'i': 'ip', 'u': 'up', 'z': 'zp', + 'Ks': 'K', 'Hs': 'H', 'Js': 'J', + 'Bessel-B': 'B', 'Bessel-V': 'V', 'Bessell-V': 'V', 'SDSS-U': 'up', + 'SDSS-G': 'gp', 'SDSS-R': 'rp', 'SDSS-I': 'ip', 'SDSS-Z': 'zp', + 'PS1-u': 'up', 'PS1-g': 'gp', 'PS1-r': 'rp', 'PS1-i': 'ip', 'PS1-z': 'zp', + 'PS2-u': 'up', 'PS2-g': 'gp', 'PS2-r': 'rp', 'PS2-i': 'ip', 'PS2-z': 'zp', + 'PS-u': 'up', 'PS-g': 'gp', 'PS-r': 'rp', 'PS-i': 'ip', 'PS-z': 'zp', + 'Yc': 'Y', 'Jc': 'J', 'Hc': 'H', 'Kc': 'K', + 'Yo': 'Y', 'Jo': 'J', 'Ho': 'H', 'Ko': 'K', + "J_Open": "J", "H_Open": "H", "K_Open": "K", + "B_Open": "B", "V_Open": "V", "R_Open": "r", "I_Open": "i", "Y_Open": "Y", + "g_Open": "gp", "r_Open": "rp", "i_Open": "ip", "z_Open": "zp", 'u_Open': 'up', +} + + +COMMON_FILTERS_LOWER = {clean_value(key): value for key, value in COMMON_FILTERS.items()} + + +def match_header_to_filter(header_dict: dict[str,str]) -> str: + """ + Extract flows filter from header. + """ + bad_keys = ["", "NONE", "Clear"] + filt = header_dict.get("FILTER") + if filt is not None and filt not in bad_keys: + filt = match_filter_to_flows(filt) + if filt is not None: + return filt + + for key, value in header_dict.items(): + if "FILT" in key.upper(): + if value not in bad_keys: + filt = match_filter_to_flows(value) + if filt is not None: + return filt + + + raise ValueError("Could not determine filter from header. Add FILTER keyword with a flows filter to header.") + + +def match_filter_to_flows(header_filter: str) -> Optional[str]: + """ + Match filter header value to flows filter. + """ + if header_filter in FILTERS: + return header_filter + + values = max(header_filter.lower().split(' '), header_filter.lower().split('.'), key=len) + + + filters_keys_lower = [str(key).lower() for key in FILTERS.keys()] + for value in values: + if value in filters_keys_lower: + if FILTERS.get(value) is not None: + return value + return value.upper() + + + filters_keys_lower = [clean_value(str(key)) for key in COMMON_FILTERS.keys()] + for value in values: + clean = clean_value(value) + if clean in filters_keys_lower: + return COMMON_FILTERS_LOWER.get(clean) + + return None From c6ba0dd5db60c77808311dcccc2de2c73b165e3a Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Sun, 28 Aug 2022 19:33:05 +0200 Subject: [PATCH 05/21] ignore VScode typing stub --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index c3614fb..b7b052d 100644 --- a/.gitignore +++ b/.gitignore @@ -142,3 +142,5 @@ dmypy.json # Ignore test output tests/output/ + +typings/astropy/units/ From bc9b2690e70533216308a38799f678fffcf71fc1 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Mon, 29 Aug 2022 09:13:53 +0200 Subject: [PATCH 06/21] ignore all VScode typing clutter --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index b7b052d..98f3808 100644 --- a/.gitignore +++ b/.gitignore @@ -143,4 +143,5 @@ dmypy.json # Ignore test output tests/output/ -typings/astropy/units/ +# Ignore VSCode typing clutter +typings/ From aa696a4930e5aebbb4c524f64f4eeef673170635 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Thu, 5 Jan 2023 14:03:32 +0100 Subject: [PATCH 07/21] ensure tests passing --- dev_requirements.txt | 4 +- flows/instruments/base_instrument.py | 6 +- flows/instruments/instruments.py | 17 ++++-- flows/load_image.py | 25 
+++++--- setup.cfg | 4 +- tests/test_instruments.py | 9 ++- tests/test_tns.py | 11 ++-- typings/astropy/__init__.pyi | 88 +++++++++++++++++++++++++++- 8 files changed, 137 insertions(+), 27 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 5871639..4cd67cd 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,5 @@ pytest flake8 -flake8-tabs >= 2.3.2 flake8-builtins -flake8-logging-format \ No newline at end of file +flake8-logging-format +autopep8 \ No newline at end of file diff --git a/flows/instruments/base_instrument.py b/flows/instruments/base_instrument.py index deab018..8e8ad4c 100644 --- a/flows/instruments/base_instrument.py +++ b/flows/instruments/base_instrument.py @@ -17,7 +17,7 @@ class AbstractInstrument(ABC): - peakmax: int = None + peakmax: int @abstractmethod def __init__(self): @@ -45,8 +45,8 @@ def process_image(self): class Instrument(AbstractInstrument): - peakmax: int = None - siteid: int = None + peakmax: int = int(1e18) + siteid: int = -99 telescope: str = '' # Fits Header name of TELESCOP instrument: str = '' # Fits Header name of Instrument (can be partial) origin: str = '' # Fits Header value of ORIGIN (if relevant) diff --git a/flows/instruments/instruments.py b/flows/instruments/instruments.py index c5024cc..2817648 100644 --- a/flows/instruments/instruments.py +++ b/flows/instruments/instruments.py @@ -30,22 +30,26 @@ """ # Standard lib from __future__ import annotations -import sys + import inspect -from typing import Tuple, Union, Optional +import sys +from typing import List, Optional, Tuple, Union + +import astropy.coordinates as coords +import astropy.units as u # Third party import numpy as np -import astropy.units as u -import astropy.coordinates as coords from astropy.io import fits from astropy.time import Time from astropy.wcs import WCS # First party from tendrils import api + from flows.filters import FILTERS from flows.image import FlowsImage from flows.instruments.base_instrument import Instrument from flows.utilities import create_logger + logger = create_logger() @@ -583,5 +587,6 @@ def get_photfilter(self): - - +INSTRUMENTS: list[tuple[str, Instrument]] = inspect.getmembers(sys.modules[__name__], + lambda member: inspect.isclass(member) and member.__module__ == __name__) + \ No newline at end of file diff --git a/flows/load_image.py b/flows/load_image.py index 5151a1a..c00a0e8 100644 --- a/flows/load_image.py +++ b/flows/load_image.py @@ -1,18 +1,26 @@ """ Load image code. """ +# pyright: reportMissingTypeStubs=true from __future__ import annotations -import numpy as np -from typing import Union, Tuple + +from typing import Any, Tuple, Union + +import astropy import astropy.coordinates as coords +import numpy as np from astropy.io import fits +from astropy.io.fits import Header, PrimaryHDU from astropy.time import Time -from .instruments import INSTRUMENTS, verify_coordinates from .image import FlowsImage +from .instruments import INSTRUMENTS, verify_coordinates from .utilities import create_logger + logger = create_logger() +astropy.__version__ + def load_image(filename: str, target_coord: Union[coords.SkyCoord, Tuple[float, float]] = None): """ @@ -31,10 +39,11 @@ def load_image(filename: str, target_coord: Union[coords.SkyCoord, Tuple[float, ext = 0 # Default extension is 0, individual instruments may override this. # Read fits image, Structural Pattern Match to specific instrument. 
with fits.open(filename, mode='readonly') as hdul: - hdr = hdul[ext].header - origin = hdr.get('ORIGIN', '') - telescope = hdr.get('TELESCOP', '') - instrument = hdr.get('INSTRUME', '') + hdu: PrimaryHDU = hdul[ext] + hdr: Header = hdu.header + origin = str(hdr.get('ORIGIN', '')) + telescope = str(hdr.get('TELESCOP', '')) + instrument = str(hdr.get('INSTRUME', '')) for inst_name, inst_cls in INSTRUMENTS: if inst_cls.identifier(telescope, origin, instrument, hdr): @@ -44,7 +53,7 @@ def load_image(filename: str, target_coord: Union[coords.SkyCoord, Tuple[float, mask = inst_cls.get_mask(hdul) # Default = None is to only mask all non-finite values, override here is additive. - image = FlowsImage(image=np.asarray(hdul[ext].data, dtype='float64'), + image = FlowsImage(image=np.asarray(hdu.data, dtype='float64'), header=hdr, mask=mask) current_instrument = inst_cls(image) clean_image = current_instrument.process_image() diff --git a/setup.cfg b/setup.cfg index cc5e6fe..e7e4646 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [flake8] -exclude = .git,__pycache__,notes +exclude = .git,__pycache__,notes,.vscode,.pytest_cache,.idea,.github,.coverage,typings # To be compliant with black max-line-length = 120 #To be compliant with black @@ -45,7 +45,7 @@ ignore = [tool:pytest] addopts = --strict-markers --durations=10 -s testpaths = tests -xfail_strict = True +xfail_strict = False log_cli = True [coverage:run] diff --git a/tests/test_instruments.py b/tests/test_instruments.py index 493fe57..7790404 100644 --- a/tests/test_instruments.py +++ b/tests/test_instruments.py @@ -1,7 +1,10 @@ """ Test instruments module """ +import logging + import pytest + from flows.instruments import INSTRUMENTS @@ -9,7 +12,11 @@ def test_instruments(): for instrument_name, instrument_class in INSTRUMENTS: instrument = instrument_class() site = instrument.get_site() - assert site['siteid'] == instrument.siteid + logging.debug(f"{instrument_name}, site:{site['siteid']}, {instrument.siteid}") # set log_cli_level=10 to show. + if not instrument_name == "LCOGT": + assert site['siteid'] == instrument.siteid + else: + assert site['siteid'] == instrument.image if __name__ == '__main__': diff --git a/tests/test_tns.py b/tests/test_tns.py index 42dab79..3e83d54 100644 --- a/tests/test_tns.py +++ b/tests/test_tns.py @@ -6,9 +6,11 @@ .. 
codeauthor:: Rasmus Handberg """ -import pytest -import os import datetime +import logging +import os + +import pytest from astropy.coordinates import SkyCoord from tendrils.utils import tns @@ -51,8 +53,8 @@ def test_tns_get_obj_noexist(): @pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") @pytest.mark.parametrize('date_begin,date_end', - [('2019-01-01', '2019-02-01'), (datetime.date(2019, 1, 1), datetime.date(2019, 2, 1)), - (datetime.datetime(2019, 1, 1, 12, 0), datetime.datetime(2019, 2, 1, 12, 0))]) + [('2019-01-01', '2019-01-10'), (datetime.date(2019, 1, 1), datetime.date(2019, 1, 10)), + (datetime.datetime(2019, 1, 1, 12, 0), datetime.datetime(2019, 1, 10, 12, 0))]) def test_tns_getnames(date_begin, date_end): names = tns.tns_getnames(date_begin=date_begin, date_end=date_end, zmin=0, zmax=0.105, objtype=3) @@ -61,6 +63,7 @@ def test_tns_getnames(date_begin, date_end): for n in names: assert isinstance(n, str), "Each element should be a string" assert n.startswith('SN'), "All names should begin with 'SN'" + logging.debug(f"obtained names: {names}") assert 'SN2019A' in names, "SN2019A should be in the list" diff --git a/typings/astropy/__init__.pyi b/typings/astropy/__init__.pyi index 812cf75..21b275d 100644 --- a/typings/astropy/__init__.pyi +++ b/typings/astropy/__init__.pyi @@ -1 +1,87 @@ -def __getattr__(name: str) -> Any: ... \ No newline at end of file +""" +This type stub file was generated by pyright. +""" + +import os +import sys +from .version import version as __version__ +from . import config as _config +from .utils.state import ScienceState +from .tests.runner import TestRunner +from .logger import _init_log, _teardown_log +from .utils.misc import find_api_page +from types import ModuleType as __module_type__ + +""" +Astropy is a package intended to contain core functionality and some +common tools needed for performing astronomy and astrophysics research with +Python. It also provides an index for other astronomy packages and tools for +managing them. +""" +if 'dev' in __version__: + online_docs_root = ... +else: + online_docs_root = ... +class Conf(_config.ConfigNamespace): + """ + Configuration parameters for `astropy`. + """ + unicode_output = ... + use_color = ... + max_lines = ... + max_width = ... + + +conf = ... +class base_constants_version(ScienceState): + """ + Base class for the real version-setters below + """ + _value = ... + _versions = ... + @classmethod + def validate(cls, value): # -> str: + ... + + @classmethod + def set(cls, value): # -> _ScienceStateContext: + """ + Set the current constants value. + """ + ... + + + +class physical_constants(base_constants_version): + """ + The version of physical constants to use + """ + _value = ... + _versions = ... + + +class astronomical_constants(base_constants_version): + """ + The version of astronomical constants to use + """ + _value = ... + _versions = ... + + +test = ... +__citation__ = ... +log = ... +def online_help(query): # -> None: + """ + Search the online Astropy documentation for the given query. + Opens the results in the default web browser. Requires an active + Internet connection. + + Parameters + ---------- + query : str + The search query. + """ + ... + +__dir_inc__ = ... 
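Note on the INSTRUMENTS registry introduced in the patch above: it is built with inspect.getmembers over the classes defined in flows/instruments/instruments.py, and load_image.py then iterates the resulting (name, class) pairs to find a matching identifier. The following standalone sketch only illustrates that collection pattern; the Demo* classes and the collect_local_classes helper are hypothetical names, not part of the flows package.

import inspect
import sys


class DemoInstrumentA:
    siteid = 1


class DemoInstrumentB:
    siteid = 2


def collect_local_classes() -> list:
    # Gather (name, class) pairs for classes defined in this module only,
    # mirroring how INSTRUMENTS is populated in the patch above.
    return inspect.getmembers(
        sys.modules[__name__],
        lambda member: inspect.isclass(member) and member.__module__ == __name__,
    )


if __name__ == "__main__":
    for name, cls in collect_local_classes():
        print(name, cls.siteid)  # DemoInstrumentA 1, DemoInstrumentB 2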
From 0b79d640c480b5a34eeb70d18bdefc11e0b9adc4 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Thu, 5 Jan 2023 14:04:28 +0100 Subject: [PATCH 08/21] remove explicit none peakmax from TNG --- flows/instruments/instruments.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flows/instruments/instruments.py b/flows/instruments/instruments.py index 2817648..bab2f99 100644 --- a/flows/instruments/instruments.py +++ b/flows/instruments/instruments.py @@ -548,7 +548,6 @@ def get_photfilter(self): class TNG(Instrument): siteid = 5 # same as NOT - peakmax = None # Lluis did not provide this so it is in header?? #instrument = 'LRS' #telescope = 'TNG' From 2ab39d7167e13795d485179bbf4022c238c8c4d9 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 16:43:52 +0100 Subject: [PATCH 09/21] refactor image.py masking. --- flows/image.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/flows/image.py b/flows/image.py index a52b2cb..17d9cd8 100644 --- a/flows/image.py +++ b/flows/image.py @@ -2,12 +2,13 @@ from enum import Enum import numpy as np +from numpy.typing import NDArray from dataclasses import dataclass import warnings from typing import Union from astropy.time import Time from astropy.wcs import WCS, FITSFixedWarning -from typing import Tuple, Dict, Any, Optional +from typing import Tuple, Dict, Any, Optional, TypeGuard from .utilities import create_logger logger = create_logger() @@ -43,38 +44,52 @@ class FlowsImage: subclean: Optional[np.ma.MaskedArray] = None error: Optional[np.ma.MaskedArray] = None - def __post_init__(self): + def __post_init__(self) -> None: self.shape = self.image.shape self.wcs = self.create_wcs() - # Make empty mask - if self.mask is None: - self.mask = np.zeros_like(self.image, dtype='bool') - self.check_finite() + # Create mask + self.initialize_mask() + + def initialize_mask(self) -> None: + self.update_mask(self.mask) - def check_finite(self): - self.mask |= ~np.isfinite(self.image) + def check_finite(self) -> None: + if self.ensure_mask(self.mask): + self.mask |= ~np.isfinite(self.image) + + def mask_non_linear(self) -> None: + if self.peakmax is None: + return + if self.ensure_mask(self.mask): + self.mask |= self.image >= self.peakmax + + def ensure_mask(self, mask: Optional[np.ndarray]) -> TypeGuard[NDArray[np.bool_]]: + if mask is None: + self.mask = np.zeros_like(self.image, dtype='bool') + return True - def update_mask(self, mask): + def update_mask(self, mask) -> None: self.mask = mask self.check_finite() + self.mask_non_linear() def create_wcs(self) -> WCS: with warnings.catch_warnings(): warnings.simplefilter('ignore', category=FITSFixedWarning) return WCS(header=self.header, relax=True) - def create_masked_image(self): + def create_masked_image(self) -> None: """Warning: this is destructive and will overwrite image data setting masked values to NaN""" self.image[self.mask] = np.NaN self.clean = np.ma.masked_array(data=self.image, mask=self.mask, copy=False) - def set_edge_rows_to_value(self, y: Tuple[float] = None, value: Union[int, float, np.float64] = 0): + def set_edge_rows_to_value(self, y: Tuple[float] = None, value: Union[int, float, np.float64] = 0) -> None: if y is None: pass for row in y: self.image[row] = value - def set_edge_columns_to_value(self, x: Tuple[float] = None, value: Union[int, float, np.float64] = 0): + def set_edge_columns_to_value(self, x: Tuple[float] = None, value: Union[int, float, np.float64] = 0) -> None: if x is None: pass for col in x: From 
c86499aa45c22c6af1fe277f0209be8f2ba05392 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 21:29:19 +0100 Subject: [PATCH 10/21] ignore .config files --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 98f3808..16e998e 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,5 @@ tests/output/ # Ignore VSCode typing clutter typings/ + +*.config From 996a992e1e7accb63819cdf245ae3bca5930d2cc Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 21:29:55 +0100 Subject: [PATCH 11/21] add env lookup for AADC_DB username and pass --- flows/aadc_db.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flows/aadc_db.py b/flows/aadc_db.py index 8d25ab9..a622c98 100644 --- a/flows/aadc_db.py +++ b/flows/aadc_db.py @@ -9,10 +9,11 @@ .. codeauthor:: Rasmus Handberg """ +import getpass +import os import psycopg2 as psql from psycopg2.extras import DictCursor -import getpass from tendrils.utils import load_config @@ -41,7 +42,7 @@ def __init__(self, username=None, password=None): config = load_config() if username is None: - username = config.get('database', 'username', fallback=None) + username = config.get('database', 'username', fallback=os.environ.get("AUDBUsername", None)) if username is None: default_username = getpass.getuser() username = input('Username [%s]: ' % default_username) @@ -49,7 +50,7 @@ def __init__(self, username=None, password=None): username = default_username if password is None: - password = config.get('database', 'password', fallback=None) + password = config.get('database', 'password', fallback=os.environ.get("AUDBPassword", None)) if password is None: password = getpass.getpass('Password: ') From b8c1fb1caecd8b0fa75d6c5d6ff81888c10850ee Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 21:30:49 +0100 Subject: [PATCH 12/21] fix astropy SDSS query, catalog tests --- flows/catalogs.py | 139 +++++++++++++++++++++-------------------- requirements.txt | 2 +- tests/test_catalogs.py | 19 +++--- 3 files changed, 79 insertions(+), 81 deletions(-) diff --git a/flows/catalogs.py b/flows/catalogs.py index a25b7e0..3352ddf 100644 --- a/flows/catalogs.py +++ b/flows/catalogs.py @@ -4,25 +4,28 @@ .. codeauthor:: Rasmus Handberg """ - import logging +import os import os.path -import subprocess import shlex -import requests +import subprocess import warnings from io import BytesIO + import numpy as np -from bottleneck import anynan -from astropy.time import Time -from astropy.coordinates import SkyCoord, Angle +import requests from astropy import units as u -from astropy.table import Table, MaskedColumn -from astroquery.sdss import SDSS +from astropy.coordinates import Angle, SkyCoord +from astropy.table import MaskedColumn, Table +from astropy.time import Time +from astroquery import sdss from astroquery.simbad import Simbad +from bottleneck import anynan from tendrils.utils import load_config, query_ztf_id + from .aadc_db import AADC_DB +logger = logging.getLogger(__name__) # -------------------------------------------------------------------------------------------------- class CasjobsError(RuntimeError): @@ -55,9 +58,9 @@ def configure_casjobs(overwrite=False): .. 
codeauthor:: Rasmus Handberg """ - __dir__ = os.path.dirname(os.path.realpath(__file__)) casjobs_config = os.path.join(__dir__, 'casjobs', 'CasJobs.config') + logger.debug(",".join([casjobs_config,__dir__,os.path.realpath(__file__)])) if os.path.isfile(casjobs_config) and not overwrite: return @@ -100,8 +103,6 @@ def query_casjobs_refcat2(coo_centre, radius=24 * u.arcmin): .. codeauthor:: Rasmus Handberg """ - - logger = logging.getLogger(__name__) if isinstance(radius, (float, int)): radius *= u.deg @@ -128,7 +129,6 @@ def query_casjobs_refcat2(coo_centre, radius=24 * u.arcmin): # -------------------------------------------------------------------------------------------------- def _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius): - logger = logging.getLogger(__name__) # Just put in a stop criterion to avoid infinite recursion: if radius < 0.04 * u.deg: @@ -172,7 +172,6 @@ def _query_casjobs_refcat2(coo_centre, radius=24 * u.arcmin): .. codeauthor:: Rasmus Handberg """ - logger = logging.getLogger(__name__) if isinstance(radius, (float, int)): radius *= u.deg @@ -298,7 +297,9 @@ def query_sdss(coo_centre, radius=24 * u.arcmin, dr=16, clean=True): if isinstance(radius, (float, int)): radius *= u.deg - AT_sdss = SDSS.query_region(coo_centre, photoobj_fields=['type', 'clean', 'ra', 'dec', 'psfMag_u'], data_release=dr, + #SDSS.MAX_CROSSID_RADIUS = radius + 1 * u.arcmin + sdss.conf.skyserver_baseurl = sdss.conf.skyserver_baseurl.replace("http://","https://") + AT_sdss = sdss.SDSS.query_region(coo_centre, photoobj_fields=['type', 'clean', 'ra', 'dec', 'psfMag_u'], data_release=dr, timeout=600, radius=radius) if AT_sdss is None: @@ -316,9 +317,9 @@ def query_sdss(coo_centre, radius=24 * u.arcmin, dr=16, clean=True): return None, None # Create SkyCoord object with the coordinates: - sdss = SkyCoord(ra=AT_sdss['ra'], dec=AT_sdss['dec'], unit=u.deg, frame='icrs') + sdss_coord = SkyCoord(ra=AT_sdss['ra'], dec=AT_sdss['dec'], unit=u.deg, frame='icrs') - return AT_sdss, sdss + return AT_sdss, sdss_coord # -------------------------------------------------------------------------------------------------- @@ -561,8 +562,6 @@ def download_catalog(target=None, radius=24 * u.arcmin, radius_ztf=3 * u.arcsec, .. codeauthor:: Rasmus Handberg """ - logger = logging.getLogger(__name__) - with AADC_DB() as db: # Get the information about the target from the database: @@ -623,54 +622,56 @@ def download_catalog(target=None, radius=24 * u.arcmin, radius_ztf=3 * u.arcsec, else: on_conflict = 'DO NOTHING' - try: - db.cursor.executemany("""INSERT INTO flows.refcat2 ( - starid, - ra, - decl, - pm_ra, - pm_dec, - gaia_mag, - gaia_bp_mag, - gaia_rp_mag, - gaia_variability, - u_mag, - g_mag, - r_mag, - i_mag, - z_mag, - "J_mag", - "H_mag", - "K_mag", - "V_mag", - "B_mag") - VALUES ( - %(starid)s, - %(ra)s, - %(decl)s, - %(pm_ra)s, - %(pm_dec)s, - %(gaia_mag)s, - %(gaia_bp_mag)s, - %(gaia_rp_mag)s, - %(gaia_variability)s, - %(u_mag)s, - %(g_mag)s, - %(r_mag)s, - %(i_mag)s, - %(z_mag)s, - %(J_mag)s, - %(H_mag)s, - %(K_mag)s, - %(V_mag)s, - %(B_mag)s) - ON CONFLICT """ + on_conflict + ";", results) - logger.info("%d catalog entries inserted for %s.", db.cursor.rowcount, target_name) - - # Mark the target that the catalog has been downloaded: - db.cursor.execute("UPDATE flows.targets SET catalog_downloaded=TRUE,ztf_id=%s WHERE targetid=%s;", - (ztf_id, targetid)) - db.conn.commit() - except: # noqa: E722, pragma: no cover - db.conn.rollback() - raise + # Avoid testing "ON CONFLICT" of postgres. 
Only test update/insert. + if update_existing: + try: + db.cursor.executemany("""INSERT INTO flows.refcat2 ( + starid, + ra, + decl, + pm_ra, + pm_dec, + gaia_mag, + gaia_bp_mag, + gaia_rp_mag, + gaia_variability, + u_mag, + g_mag, + r_mag, + i_mag, + z_mag, + "J_mag", + "H_mag", + "K_mag", + "V_mag", + "B_mag") + VALUES ( + %(starid)s, + %(ra)s, + %(decl)s, + %(pm_ra)s, + %(pm_dec)s, + %(gaia_mag)s, + %(gaia_bp_mag)s, + %(gaia_rp_mag)s, + %(gaia_variability)s, + %(u_mag)s, + %(g_mag)s, + %(r_mag)s, + %(i_mag)s, + %(z_mag)s, + %(J_mag)s, + %(H_mag)s, + %(K_mag)s, + %(V_mag)s, + %(B_mag)s) + ON CONFLICT """ + on_conflict + ";", results) + logger.info("%d catalog entries inserted for %s.", db.cursor.rowcount, target_name) + + # Mark the target that the catalog has been downloaded: + db.cursor.execute("UPDATE flows.targets SET catalog_downloaded=TRUE,ztf_id=%s WHERE targetid=%s;", + (ztf_id, targetid)) + db.conn.commit() + except: # noqa: E722, pragma: no cover + db.conn.rollback() + raise diff --git a/requirements.txt b/requirements.txt index bfa5c83..723e418 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,5 +17,5 @@ pytz sep astroalign > 2.3 networkx -astroquery >= 0.4.2 +astroquery >= 0.4.7dev8479 tendrils >= 0.1.5 diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py index 9771d53..5d64614 100644 --- a/tests/test_catalogs.py +++ b/tests/test_catalogs.py @@ -6,15 +6,17 @@ .. codeauthor:: Rasmus Handberg """ -import pytest +import logging + +import conftest # noqa: F401 import numpy as np +import pytest from astropy.coordinates import SkyCoord from astropy.table import Table -import conftest # noqa: F401 + from flows import catalogs -# -------------------------------------------------------------------------------------------------- def test_query_simbad(): # Coordinates around test-object (2019yvr): coo_centre = SkyCoord(ra=256.727512, dec=30.271482, unit='deg', frame='icrs') @@ -27,7 +29,6 @@ def test_query_simbad(): results.pprint_all(50) -# -------------------------------------------------------------------------------------------------- def test_query_skymapper(): # Coordinates around test-object (2021aess): coo_centre = SkyCoord(ra=53.4505, dec=-19.495725, unit='deg', frame='icrs') @@ -40,12 +41,8 @@ def test_query_skymapper(): results.pprint_all(50) -# -------------------------------------------------------------------------------------------------- - -@pytest.mark.parametrize('ra,dec', [[256.727512, 30.271482], # 2019yvr - [58.59512, -19.18172], # 2009D - ]) -def test_download_catalog(SETUP_CONFIG, ra, dec): +def test_download_catalog(caplog, SETUP_CONFIG, ra: float = 256.727512, dec: float = 30.271482) -> None: + caplog.set_level(logging.DEBUG) # Check if CasJobs have been configured, and skip the entire test if it isn't. # This has to be done like this, to avoid problems when config.ini doesn't exist. 
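    # A minimal illustration of what the new caplog fixture provides here
    # (an assumption about standard pytest behaviour, not part of this diff):
    # caplog.set_level(logging.DEBUG) lowers the capture threshold for this
    # test only, so DEBUG records (for example the CasJobs.config path logged
    # by configure_casjobs() above) end up in caplog.text and could be
    # asserted on, e.g.
    #
    #     assert "CasJobs.config" in caplog.text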
try: @@ -57,7 +54,7 @@ def test_download_catalog(SETUP_CONFIG, ra, dec): coo_centre = SkyCoord(ra=ra, dec=dec, unit='deg', frame='icrs') tab = catalogs.query_all(coo_centre) - print(tab) + logging.debug(tab) assert isinstance(tab, Table), "Should return a Table" results = catalogs.convert_table_to_dict(tab) From a3da47f9e082015c39131f7b7acc855465ef538f Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 21:31:18 +0100 Subject: [PATCH 13/21] test all instruments added to INSTRUMENTS --- flows/instruments/instruments.py | 36 ++++++++------------------------ tests/test_instruments.py | 10 +++++++-- 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/flows/instruments/instruments.py b/flows/instruments/instruments.py index bab2f99..2689d56 100644 --- a/flows/instruments/instruments.py +++ b/flows/instruments/instruments.py @@ -12,7 +12,7 @@ your instrument. See: ``` self.image.peakmax = self.peakmax - self.image.site = self.get_site() + self.image.site = self.get_site() self.image.exptime = self.get_exptime() self.image.obstime = self.get_obstime() self.image.photfilter = self.get_photfilter() @@ -20,8 +20,8 @@ Identifying the instrument for an image: -Each instrument can define (one or many) of `origin`, -`telescope`, `instrument` fields correspinding to the +Each instrument can define (one or many) of `origin`, +`telescope`, `instrument` fields correspinding to the standard fits headers to help uniquely identify itself. More advanced logic is possible using `unique_headers` field as a dict of key,value pairs in the header. Ex: @@ -33,7 +33,7 @@ import inspect import sys -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type import astropy.coordinates as coords import astropy.units as u @@ -152,7 +152,6 @@ def get_ext(hdul: fits.HDUList, target_coords: coords.SkyCoord = None, raise RuntimeError(f"Could not find image extension that target is on!") - class ALFOSC(Instrument): # Obtained from http://www.not.iac.es/instruments/detectors/CCD14/LED-linearity/20181026-200-1x1.pdf peakmax = 80000 # For ALFOSC D, 1x1, 200; the standard for SNe. @@ -459,7 +458,6 @@ class TJO_MEIA2(Instrument): telescope = 'TJO' instrument = 'MEIA2' - def get_obstime(self): obstime = super().get_obstime() obstime += 0.5 * self.image.exptime * u.second @@ -523,9 +521,9 @@ def get_photfilter(self): class Schmidt(Instrument): siteid = 26 peakmax = 56_000 - telescope = '67/91 Schmidt Telescope' + telescope = '67/91 Schmidt Telescope' instrument = 'Moravian G4-16000LC' - origin = '' + origin = '' def get_obstime(self): obstime = Time(self.image.header['DATE-OBS'], format='isot', scale='utc', @@ -550,18 +548,16 @@ class TNG(Instrument): siteid = 5 # same as NOT #instrument = 'LRS' #telescope = 'TNG' - - unique_headers = {'TELESCOP':'TNG', 'INSTRUME':'LRS'} # assume we use unique headers? + unique_headers = {'TELESCOP':'TNG', 'INSTRUME':'LRS'} # assume we use unique headers? 
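    # A minimal sketch of how such a mapping can be checked against a FITS
    # header (this helper is an assumption for illustration, not code from
    # this diff): every key/value pair must be present and equal, so TNG/LRS
    # frames need both TELESCOP == 'TNG' and INSTRUME == 'LRS'.
    #
    #     @classmethod
    #     def matches_header(cls, header: fits.Header) -> bool:
    #         return all(header.get(key) == value
    #                    for key, value in cls.unique_headers.items())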
def get_obstime(self): return Time(self.image.header['DATE-OBS'], format='isot', scale='utc', location=self.image.site['EarthLocation']) - def get_exptime(self): exptime = super().get_exptime() - exptime *= int(self.image.header['EXPTIME']) + exptime *= int(self.image.header['EXPTIME']) return exptime def get_photfilter(self): @@ -571,21 +567,7 @@ def get_photfilter(self): return {'B_John_10': 'B', 'g_sdss_30':'g', 'r':'r_sdss_31', 'i_sdss_32':'i', 'u_sdss_29':'u', 'V_John_11':'V_John_11' }.get(ratir_filt) return ratir_filt - - - -INSTRUMENTS = inspect.getmembers(sys.modules[__name__], - lambda member: inspect.isclass(member) and member.__module__ == __name__) - -# instruments = {'LCOGT': LCOGT, 'HAWKI': HAWKI, 'ALFOSC': ALFOSC, 'NOTCAM': NOTCAM, 'PS1': PS1, 'Liverpool': Liverpool, -# 'Omega2000': Omega2000, 'Swope': Swope, 'Swope_newheader':Swope_newheader, 'Dupont': Dupont, 'Retrocam': -# RetroCam, 'Baade': Baade, -# 'Sofi': Sofi, 'EFOSC': EFOSC, 'AstroNIRCam': AstroNIRCam, 'OmegaCam': OmegaCam, 'AndiCam': AndiCam, -# 'PairTel': PairTel, 'TJO_Meia2': TJO_MEIA2, 'TJO_Meia3': TJO_MEIA3, 'RATIR': RATIR, "Schmidt": Schmidt, "AFOSC": AFOSC} - - -INSTRUMENTS: list[tuple[str, Instrument]] = inspect.getmembers(sys.modules[__name__], +INSTRUMENTS: List[Tuple[str, Type[Instrument]]] = inspect.getmembers(sys.modules[__name__], lambda member: inspect.isclass(member) and member.__module__ == __name__) - \ No newline at end of file diff --git a/tests/test_instruments.py b/tests/test_instruments.py index 7790404..f66e72d 100644 --- a/tests/test_instruments.py +++ b/tests/test_instruments.py @@ -1,15 +1,16 @@ """ Test instruments module """ +import inspect import logging import pytest -from flows.instruments import INSTRUMENTS +import flows.instruments def test_instruments(): - for instrument_name, instrument_class in INSTRUMENTS: + for instrument_name, instrument_class in flows.instruments.INSTRUMENTS: instrument = instrument_class() site = instrument.get_site() logging.debug(f"{instrument_name}, site:{site['siteid']}, {instrument.siteid}") # set log_cli_level=10 to show. 
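For reference, the INSTRUMENTS registry iterated above is built with inspect.getmembers over the defining module, which is also what the new test in the next hunk re-derives to check that no instrument class is missing. A self-contained sketch of the same pattern (the Spectrograph and Imager classes are invented purely for illustration):

```
import inspect
import sys

class Instrument:                # stand-in base class for the sketch
    siteid = None

class Spectrograph(Instrument):
    siteid = 5

class Imager(Instrument):
    siteid = 26

# Same predicate as flows.instruments.instruments: collect every class
# defined in this module as (name, class) pairs, sorted by name.
INSTRUMENTS = inspect.getmembers(
    sys.modules[__name__],
    lambda member: inspect.isclass(member) and member.__module__ == __name__,
)
print([name for name, _ in INSTRUMENTS])
# ['Imager', 'Instrument', 'Spectrograph']  (note the base class is included)
```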
@@ -19,5 +20,10 @@ def test_instruments(): assert site['siteid'] == instrument.image +def test_all_instruments_present(): + ins = inspect.getmembers(flows.instruments.instruments, + lambda member: inspect.isclass(member) and member.__module__ == flows.instruments.instruments.__name__) + assert len(flows.instruments.INSTRUMENTS) == len(ins) + if __name__ == '__main__': pytest.main([__file__]) From afe91406fca608c1b32205e1ea81d7e3cf97723d Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 21:34:10 +0100 Subject: [PATCH 14/21] comment out test print clutter --- tests/test_catalogs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py index 5d64614..f309eaa 100644 --- a/tests/test_catalogs.py +++ b/tests/test_catalogs.py @@ -26,7 +26,7 @@ def test_query_simbad(): assert isinstance(results, Table) assert isinstance(simbad, SkyCoord) assert len(results) > 0 - results.pprint_all(50) + #results.pprint_all(50) def test_query_skymapper(): @@ -38,11 +38,11 @@ def test_query_skymapper(): assert isinstance(results, Table) assert isinstance(skymapper, SkyCoord) assert len(results) > 0 - results.pprint_all(50) + #results.pprint_all(50) -def test_download_catalog(caplog, SETUP_CONFIG, ra: float = 256.727512, dec: float = 30.271482) -> None: - caplog.set_level(logging.DEBUG) +def test_download_catalog(SETUP_CONFIG, ra: float = 256.727512, dec: float = 30.271482) -> None: + # caplog.set_level(logging.DEBUG) # Check if CasJobs have been configured, and skip the entire test if it isn't. # This has to be done like this, to avoid problems when config.ini doesn't exist. try: From 5217d9bc79af9cf0009dc32716d4774d3fc27b49 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 21:39:31 +0100 Subject: [PATCH 15/21] add environment vars for casjobs --- flows/catalogs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flows/catalogs.py b/flows/catalogs.py index 3352ddf..90f52f9 100644 --- a/flows/catalogs.py +++ b/flows/catalogs.py @@ -7,6 +7,7 @@ import logging import os import os.path +from pickle import NONE import shlex import subprocess import warnings @@ -65,8 +66,8 @@ def configure_casjobs(overwrite=False): return config = load_config() - wsid = config.get('casjobs', 'wsid', fallback=None) - passwd = config.get('casjobs', 'password', fallback=None) + wsid = config.get('casjobs', 'wsid', fallback=os.environ.get("CASJOBS_WSID", None)) + passwd = config.get('casjobs', 'password', fallback=os.environ.get("CASJOBS_PASSWORD", None)) if wsid is None or passwd is None: raise CasjobsError("CasJobs WSID and PASSWORD not in config.ini") From ba4f4c8a12176b51c08c49f0fec58424b004206c Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 21:46:58 +0100 Subject: [PATCH 16/21] update tests and undo version bump for GA --- .github/workflows/tests.yml | 12 ++++++++---- VERSION | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cf137d6..2a9e24f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -47,15 +47,13 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - python-version: [3.9, '3.10'] + os: [ubuntu-latest, macos-latest] + python-version: ['3.10'] include: - os: ubuntu-latest pippath: ~/.cache/pip - os: macos-latest pippath: ~/Library/Caches/pip - - os: windows-latest - pippath: ~\AppData\Local\pip\Cache name: Python ${{ 
matrix.python-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} @@ -102,6 +100,8 @@ jobs: env: FLOWS_CONFIG: ${{ secrets.FLOWS_CONFIG }} FLOWS_API_TOKEN: ${{ secrets.FLOWS_API_TOKEN }} + CASSJOBS_WSID: ${{ secrets.CASSJOBS_WSID }} + CASSJOBS_PASSWORD: ${{ secrets.CASSJOBS_PASSWORD }} run: | python -m pip install --upgrade pip wheel pip install -r requirements.txt @@ -113,6 +113,8 @@ jobs: env: FLOWS_CONFIG: ${{ secrets.FLOWS_CONFIG }} FLOWS_API_TOKEN: ${{ secrets.FLOWS_API_TOKEN }} + CASSJOBS_WSID: ${{ secrets.CASSJOBS_WSID }} + CASSJOBS_PASSWORD: ${{ secrets.CASSJOBS_PASSWORD }} run: pytest --cov - name: Upload coverage @@ -165,6 +167,8 @@ jobs: env: FLOWS_API_TOKEN: ${{ secrets.FLOWS_API_TOKEN }} FLOWS_CONFIG: ${{ secrets.FLOWS_CONFIG }} + CASSJOBS_WSID: ${{ secrets.CASSJOBS_WSID }} + CASSJOBS_PASSWORD: ${{ secrets.CASSJOBS_PASSWORD }} run: | python -m pip install --upgrade pip wheel pip install -r requirements.txt diff --git a/VERSION b/VERSION index ae624fb..bb251ec 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -master-v1.0.1 +master-v1.0.0 From 4527e976eea9ce78f28a5d305478e3a139ba69f1 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 23:19:10 +0100 Subject: [PATCH 17/21] add test for new site generation, flows instrument sites. --- flows/instruments/__init__.py | 3 ++ flows/instruments/sites.py | 27 +++++++++------- tests/test_sites.py | 60 +++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 12 deletions(-) create mode 100644 tests/test_sites.py diff --git a/flows/instruments/__init__.py b/flows/instruments/__init__.py index dd2eb5b..aaa90ab 100644 --- a/flows/instruments/__init__.py +++ b/flows/instruments/__init__.py @@ -1 +1,4 @@ from .instruments import INSTRUMENTS, Instrument, verify_coordinates +from .sites import Site + +__all__ = ["INSTRUMENTS", "Instrument", "verify_coordinates", "Site"] diff --git a/flows/instruments/sites.py b/flows/instruments/sites.py index 4650b86..f70edcd 100644 --- a/flows/instruments/sites.py +++ b/flows/instruments/sites.py @@ -1,9 +1,11 @@ from dataclasses import dataclass -from astropy.coordinates import EarthLocation -import astropy.units as u from typing import Optional + +import astropy.units as u +from astropy.coordinates import EarthLocation from tendrils import api + @dataclass class Site: siteid: int @@ -13,9 +15,9 @@ class Site: elevation: float earth_location: Optional[EarthLocation] = None site_keyword: Optional[str] = None - + def __post_init__(self): - if self.earth_location is None: + if self.earth_location is not None: return self.earth_location = EarthLocation( lat=self.latitude*u.deg, @@ -25,23 +27,24 @@ def __post_init__(self): @classmethod def from_flows(cls, siteid: int) -> 'Site': - return cls(**api.get_site(siteid)) - + site_dict = api.get_site(siteid) + site_dict['earth_location'] = site_dict.pop('EarthLocation') + return cls(**site_dict) + @classmethod def from_astropy(cls, sitename: str) -> 'Site': loc = EarthLocation.of_site(sitename) - return cls(siteid=999, sitename=sitename, - longitude=loc.long.value, latitude=loc.lat.value, + return cls(siteid=999, sitename=sitename, + longitude=loc.lon.value, latitude=loc.lat.value, elevation=loc.height.value, earth_location=loc) - + @classmethod def from_query(cls) -> 'Site': sitename = input('Enter a site name for logging: ') longitude = float(input('Enter longitude in degrees: ')) lat = float(input('Enter latitude in degrees: ')) elevation = float(input('Enter elevation in meters: ')) - siteid = 1000 # hardcoded for user defined 
site - return cls(siteid=999, sitename=sitename, + siteid = 999 # hardcoded for user defined site + return cls(siteid=siteid, sitename=sitename, longitude=longitude, latitude=lat, elevation=elevation) - \ No newline at end of file diff --git a/tests/test_sites.py b/tests/test_sites.py new file mode 100644 index 0000000..9369502 --- /dev/null +++ b/tests/test_sites.py @@ -0,0 +1,60 @@ +from typing import Optional + +import pytest +from tendrils import api + +from flows.instruments import INSTRUMENTS, Instrument, Site + + +def test_flows_sites(siteid: int = 9): + + site = Site.from_flows(siteid) + + instrument: Optional[Instrument] = None + for instrument_name, instrument_class in INSTRUMENTS: + if instrument_class.siteid == siteid: + instrument = instrument_class() + if instrument is None: + raise ValueError(f"Expected to find site and instrument with siteid {siteid} but found None in INSTRUMENTS") + + assert instrument.telescope == 'CA 3.5m' + assert site.siteid == siteid + assert site.sitename == site.sitename + assert site.site_keyword == site.site_keyword + + +def test_site_from_astropy_vs_flows(sitename: str = "paranal", siteid: int = 2): + site = Site.from_astropy(sitename) + assert site.sitename == sitename + + flows_site = Site.from_flows(siteid) + assert int(site.latitude) == int(flows_site.latitude) + assert int(site.longitude) == int(flows_site.longitude) + +def test_user_site(monkeypatch): + + # provided inputs + sitename = 'test' + longitude = 11.5 + lat = 12.5 + elevation = 1200 + + # creating iterator object + answers = iter([sitename, str(longitude), str(lat), str(elevation)]) + + monkeypatch.setattr('builtins.input', lambda name: next(answers)) + site = Site.from_query() + + assert site.sitename == 'test' + assert site.siteid == 999 + assert int(site.longitude) == int(longitude) + assert int(site.latitude) == int(lat) + assert int(elevation) == int(elevation) + + if site.earth_location is None: + raise ValueError(f"Expected to find site with earth location!") + assert int(site.earth_location.lat.value) == int(lat) + + +if __name__ == '__main__': + pytest.main([__file__]) From d5fcb37feccdaeb2bb0221f47402c4e4d878f516 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Wed, 8 Feb 2023 23:22:21 +0100 Subject: [PATCH 18/21] Fix astroquery version until changes are in. 
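Context for the pin: the released astroquery 0.4.6 presumably does not yet default the SDSS SkyServer to https, which is why the earlier catalogs.py change in this series rewrites the base URL at runtime instead of depending on a development build. A minimal sketch of that workaround, assuming only the astroquery.sdss.conf.skyserver_baseurl setting and the query_region call already used in query_sdss():

```
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery import sdss

# Runtime workaround kept in flows.catalogs.query_sdss: force https on the
# SkyServer base URL rather than requiring an unreleased astroquery version.
sdss.conf.skyserver_baseurl = sdss.conf.skyserver_baseurl.replace("http://", "https://")

coo_centre = SkyCoord(ra=256.727512, dec=30.271482, unit='deg', frame='icrs')
sdss_tab = sdss.SDSS.query_region(coo_centre, radius=24 * u.arcmin, data_release=16,
                                  photoobj_fields=['type', 'clean', 'ra', 'dec', 'psfMag_u'],
                                  timeout=600)
```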
--- requirements.txt | 2 +- tests/test_sites.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 723e418..05dd931 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,5 +17,5 @@ pytz sep astroalign > 2.3 networkx -astroquery >= 0.4.7dev8479 +astroquery == 0.4.6 tendrils >= 0.1.5 diff --git a/tests/test_sites.py b/tests/test_sites.py index 9369502..edb5916 100644 --- a/tests/test_sites.py +++ b/tests/test_sites.py @@ -1,7 +1,6 @@ from typing import Optional import pytest -from tendrils import api from flows.instruments import INSTRUMENTS, Instrument, Site From f3bc6e8f1fd79566917b6f02cfe008401648f202 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Thu, 9 Feb 2023 09:55:37 +0000 Subject: [PATCH 19/21] Add target.from_tid, visibility tests --- flows/target.py | 9 ++++++++- flows/visibility.py | 4 +++- tests/test_sites.py | 32 +++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/flows/target.py b/flows/target.py index 0cfaf6b..3251649 100644 --- a/flows/target.py +++ b/flows/target.py @@ -58,9 +58,16 @@ def from_fid(cls, fid: int, datafile: Optional[Dict] = None) -> 'Target': """ Create target from fileid. """ - + datafile = datafile or api.get_datafile(fid) if datafile is None: raise ValueError(f'No datafile found for fid={fid}') d = api.get_target(datafile['target_name']) | datafile return cls.from_dict(d) + + @classmethod + def from_tid(cls, target_id: int) -> 'Target': + """ + Create target from target id. + """ + return cls.from_dict(api.get_target(target_id)) diff --git a/flows/visibility.py b/flows/visibility.py index bb6e108..5570fd6 100644 --- a/flows/visibility.py +++ b/flows/visibility.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Target visibility plotting. +Target visibility plotting. @TODO: Move to flows-tools. .. 
codeauthor:: Rasmus Handberg @@ -20,6 +20,7 @@ from tendrils import api from .target import Target from typing import Optional +import warnings # -------------------------------------------------------------------------------------------------- @@ -41,6 +42,7 @@ def visibility(target: Target, siteid: Optional[int] = None, date=None, output=N """ logger = logging.getLogger(__name__) + warnings.warn(DeprecationWarning("This module is moved to SNFLOWS/flows-tools.")) if date is None: date = datetime.utcnow() diff --git a/tests/test_sites.py b/tests/test_sites.py index edb5916..65a7361 100644 --- a/tests/test_sites.py +++ b/tests/test_sites.py @@ -1,8 +1,18 @@ -from typing import Optional +from typing import Optional, List, Dict, Any import pytest +from matplotlib import pyplot as plt +import tempfile +from tendrils import api from flows.instruments import INSTRUMENTS, Instrument, Site +from flows.visibility import visibility +from flows.target import Target + + +@pytest.fixture(scope='session') +def flows_sites() -> List[Dict[str, Any]]: + return api.get_all_sites() # type: ignore def test_flows_sites(siteid: int = 9): @@ -30,6 +40,7 @@ def test_site_from_astropy_vs_flows(sitename: str = "paranal", siteid: int = 2): assert int(site.latitude) == int(flows_site.latitude) assert int(site.longitude) == int(flows_site.longitude) + def test_user_site(monkeypatch): # provided inputs @@ -55,5 +66,24 @@ def test_user_site(monkeypatch): assert int(site.earth_location.lat.value) == int(lat) +# Very basic due to being moved to flows tools +def test_site_visibility(flows_sites): + target = Target.from_tid(8) + with pytest.deprecated_call(): + with plt.ioff(): + ax = visibility(target, siteid=2) + assert isinstance(ax, plt.Axes) + + with tempfile.TemporaryDirectory() as tempdir: + with pytest.deprecated_call(): + with plt.ioff(): + plotpaths = visibility( + target, date="2023-01-01", output=tempdir + ) + + assert not isinstance(plotpaths, plt.Axes) + assert len(flows_sites) == len(plotpaths) + + if __name__ == '__main__': pytest.main([__file__]) From 7f89e4ea6b2417d83c5428cb986370772d9314f9 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Thu, 9 Feb 2023 14:00:04 +0000 Subject: [PATCH 20/21] adjust coverage report --- .github/codecov.yml | 9 +++++++++ .github/workflows/tests.yml | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 .github/codecov.yml diff --git a/.github/codecov.yml b/.github/codecov.yml new file mode 100644 index 0000000..4c2d04e --- /dev/null +++ b/.github/codecov.yml @@ -0,0 +1,9 @@ +coverage: + status: + project: + default: + # basic + threshold: 5% + patch: + default: + threshold: 5% \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2a9e24f..c3c0c20 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -119,8 +119,9 @@ jobs: - name: Upload coverage continue-on-error: true - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 with: + fail_ci_if_error: true env_vars: OS,PYTHON verbose: true From 31d6962227e8488189463bfe8c4bae2fb3244cd1 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Fri, 10 Feb 2023 19:58:00 +0100 Subject: [PATCH 21/21] ensure tests passing --- flows/target.py | 10 +++++++--- flows/visibility.py | 22 +++++++++++++--------- tests/test_sites.py | 6 +++--- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/flows/target.py b/flows/target.py index 3251649..25902c7 100644 --- a/flows/target.py +++ b/flows/target.py @@ -1,12 +1,13 @@ 
from dataclasses import dataclass -from typing import Optional, Dict +from typing import Dict, Optional import numpy as np -from numpy.typing import NDArray from astropy.coordinates import SkyCoord from astropy.wcs import WCS +from numpy.typing import NDArray from tendrils import api + @dataclass class Target: ra: float @@ -70,4 +71,7 @@ def from_tid(cls, target_id: int) -> 'Target': """ Create target from target id. """ - return cls.from_dict(api.get_target(target_id)) + target_pars = api.get_target(target_id) + return cls( + ra=target_pars['ra'], dec=target_pars['decl'], + name=target_pars['target_name'], id=target_pars['targetid']) diff --git a/flows/visibility.py b/flows/visibility.py index 5570fd6..5c99faf 100644 --- a/flows/visibility.py +++ b/flows/visibility.py @@ -9,18 +9,20 @@ import logging import os.path -import numpy as np -import matplotlib.pyplot as plt -from matplotlib.dates import DateFormatter +import warnings +from datetime import datetime +from typing import Optional + import astropy.units as u +import matplotlib.pyplot as plt +import numpy as np +from astropy.coordinates import AltAz, SkyCoord, get_moon, get_sun from astropy.time import Time -from datetime import datetime -from astropy.coordinates import SkyCoord, AltAz, get_sun, get_moon from astropy.visualization import quantity_support +from matplotlib.dates import DateFormatter from tendrils import api + from .target import Target -from typing import Optional -import warnings # -------------------------------------------------------------------------------------------------- @@ -73,9 +75,11 @@ def visibility(target: Target, siteid: Optional[int] = None, date=None, output=N if not overwrite and os.path.exists(plotpath): logger.info("File already exists: %s", plotpath) continue - + logger.debug(f"{site}, type: {type(site)}") # Observatory: - observatory = site['EarthLocation'] + observatory = site.get('EarthLocation', None) + if observatory is None: + observatory = site.get('earth_location') utcoffset = (site['longitude'] * u.deg / (360 * u.deg)) * 24 * u.hour # Create timestamps to calculate for: diff --git a/tests/test_sites.py b/tests/test_sites.py index 65a7361..e0eba5e 100644 --- a/tests/test_sites.py +++ b/tests/test_sites.py @@ -1,13 +1,13 @@ -from typing import Optional, List, Dict, Any +import tempfile +from typing import Any, Dict, List, Optional import pytest from matplotlib import pyplot as plt -import tempfile from tendrils import api from flows.instruments import INSTRUMENTS, Instrument, Site -from flows.visibility import visibility from flows.target import Target +from flows.visibility import visibility @pytest.fixture(scope='session')
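Taken together, this last patch lets a Target be built straight from its id and passed to the deprecated visibility() helper, with the site's observatory looked up under either the old 'EarthLocation' key or the new 'earth_location' one. A rough usage sketch, assuming a configured tendrils API token and that a target with id 8 exists on the FLOWS server, as in the test above:

```
from flows.target import Target
from flows.visibility import visibility

# API field names ('decl', 'target_name', 'targetid') differ from the Target
# attributes (dec, name, id), hence the explicit mapping in Target.from_tid.
target = Target.from_tid(8)

# visibility() now emits a DeprecationWarning (it is moving to flows-tools),
# so the tests wrap calls in pytest.deprecated_call(); direct use still works.
ax = visibility(target, siteid=2)
```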