diff --git a/fafbseg/flywire/annotations.py b/fafbseg/flywire/annotations.py index a9cdac7..af16490 100644 --- a/fafbseg/flywire/annotations.py +++ b/fafbseg/flywire/annotations.py @@ -473,8 +473,8 @@ def list_cave_tables(*, dataset=None): @inject_dataset(disallowed=['flat_630', 'flat_571']) def get_cave_table_info(table_name: str, - *, - dataset=None): + *, + dataset=None): """Get info for given CAVE table. Parameters @@ -501,6 +501,7 @@ def get_cave_table_info(table_name: str, def get_cave_table(table_name: str, materialization='latest', split_positions: bool = False, + fill_user_info: bool = True, drop_invalid: bool = True, *, dataset: Optional[str] = None, @@ -519,6 +520,11 @@ def get_cave_table(table_name: str, this function will search all of them and concatenate the results (no deduplication). Set to ``False`` to fetch the non-materialized version. + fill_user_info : bool | full + Whether to fill in user information for the table. Only + relevant if table has a `user_id` column. If True, + will add a `user_name` column. If "full", will add + also add a `user_pi` column. split_positions : bool Whether to split x/y/z positions into separate columns. drop_invalid : bool @@ -569,11 +575,24 @@ def get_cave_table(table_name: str, else: raise ValueError('It is currently not possible to query the non-' 'materialized tables.') + + if fill_user_info and 'user_id' in data.columns: + user_info = get_user_information(data.user_id.unique(), dataset=dataset, raise_missing=False) + user_info = {r['id']: r for r in user_info if 'id' in r} + data['user_name'] = data.user_id.map(lambda x: user_info.get(x, {}).get('name', None)) + if fill_user_info == 'full': + data['user_pi'] = data.user_id.map(lambda x: user_info.get(x, {}).get('pi', None)) if drop_invalid and 'valid' in data.columns: data = data[data.valid == 't'].copy() data.drop('valid', axis=1, inplace=True) + # There is some weird interaction with pandas and the .attrs if the attrs contain numpy arrays + if getattr(data, 'attrs', None): + for k, v in data.attrs.items(): + if isinstance(v, np.ndarray): + data.attrs[k] = v.tolist() + return data