Skip to content

Commit

Permalink
flywire.get_cave_table: add option to fill in user info
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Mar 15, 2024
1 parent 9cfa11f commit 6ca528c
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions fafbseg/flywire/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,8 @@ def list_cave_tables(*, dataset=None):

@inject_dataset(disallowed=['flat_630', 'flat_571'])
def get_cave_table_info(table_name: str,
*,
dataset=None):
*,
dataset=None):
"""Get info for given CAVE table.
Parameters
Expand All @@ -501,6 +501,7 @@ def get_cave_table_info(table_name: str,
def get_cave_table(table_name: str,
materialization='latest',
split_positions: bool = False,
fill_user_info: bool = True,
drop_invalid: bool = True,
*,
dataset: Optional[str] = None,
Expand All @@ -519,6 +520,11 @@ def get_cave_table(table_name: str,
this function will search all of them and concatenate
the results (no deduplication).
Set to ``False`` to fetch the non-materialized version.
fill_user_info : bool | full
Whether to fill in user information for the table. Only
relevant if table has a `user_id` column. If True,
will add a `user_name` column. If "full", will add
also add a `user_pi` column.
split_positions : bool
Whether to split x/y/z positions into separate columns.
drop_invalid : bool
Expand Down Expand Up @@ -569,11 +575,24 @@ def get_cave_table(table_name: str,
else:
raise ValueError('It is currently not possible to query the non-'
'materialized tables.')

if fill_user_info and 'user_id' in data.columns:
user_info = get_user_information(data.user_id.unique(), dataset=dataset, raise_missing=False)
user_info = {r['id']: r for r in user_info if 'id' in r}
data['user_name'] = data.user_id.map(lambda x: user_info.get(x, {}).get('name', None))
if fill_user_info == 'full':
data['user_pi'] = data.user_id.map(lambda x: user_info.get(x, {}).get('pi', None))

if drop_invalid and 'valid' in data.columns:
data = data[data.valid == 't'].copy()
data.drop('valid', axis=1, inplace=True)

# There is some weird interaction with pandas and the .attrs if the attrs contain numpy arrays
if getattr(data, 'attrs', None):
for k, v in data.attrs.items():
if isinstance(v, np.ndarray):
data.attrs[k] = v.tolist()

return data


Expand Down

0 comments on commit 6ca528c

Please sign in to comment.