Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dynamically choosing proteins from volcano plot #389

Merged
merged 14 commits into from
Jan 16, 2025
Merged
7 changes: 2 additions & 5 deletions alphastats/gui/pages/05_Analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from alphastats.gui.utils.analysis_helper import (
display_analysis_result_with_buttons,
gather_parameters_and_do_analysis,
gather_uniprot_data,
get_regulated_features,
)
from alphastats.gui.utils.ui_helper import (
StateKeys,
Expand Down Expand Up @@ -93,10 +91,9 @@ def show_start_llm_button(analysis_method: str) -> None:
if submitted:
if StateKeys.LLM_INTEGRATION in st.session_state:
del st.session_state[StateKeys.LLM_INTEGRATION]
st.session_state[StateKeys.SELECTED_GENES_UP] = None
st.session_state[StateKeys.SELECTED_GENES_DOWN] = None
st.session_state[StateKeys.LLM_INPUT] = (analysis_object, parameters)
regulated_features = get_regulated_features(analysis_object)
# TODO: Add confirmation prompt if an excessive number of proteins is to be looked up.
gather_uniprot_data(regulated_features)

st.toast("LLM analysis created!", icon="✅")
st.page_link("pages/06_LLM.py", label="=> Go to LLM page..")
Expand Down
130 changes: 85 additions & 45 deletions alphastats/gui/pages/06_LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from alphastats.dataset.plotting import plotly_object
from alphastats.gui.utils.analysis_helper import (
display_figure,
gather_uniprot_data,
)
from alphastats.gui.utils.llm_helper import (
display_uniprot,
get_display_proteins_html,
get_df_for_protein_selector,
llm_connection_test,
protein_selector,
set_api_key,
)
from alphastats.gui.utils.ui_helper import (
Expand Down Expand Up @@ -99,57 +101,85 @@ def llm_config():
volcano_plot, plot_parameters = st.session_state[StateKeys.LLM_INPUT]

st.markdown(f"Parameters used for analysis: `{plot_parameters}`")
c1, c2 = st.columns((1, 2))

with c2:
c1, c2, c3 = st.columns((1, 1, 1))

with c3:
st.markdown("##### Volcano plot")
display_figure(volcano_plot.plot)

regulated_genes_df = volcano_plot.res[volcano_plot.res["label"] != ""]
regulated_genes_dict = dict(
zip(regulated_genes_df[Cols.INDEX], regulated_genes_df["color"].tolist())
)

if not regulated_genes_dict:
st.text("No genes of interest found.")
st.stop()

# Separate upregulated and downregulated genes
upregulated_genes = [
key for key in regulated_genes_dict if regulated_genes_dict[key] == "up"
]
downregulated_genes = [
key for key in regulated_genes_dict if regulated_genes_dict[key] == "down"
]

# Create dataframes with checkboxes for selection
if st.session_state[StateKeys.SELECTED_GENES_UP] is None:
st.session_state[StateKeys.SELECTED_GENES_UP] = upregulated_genes
upregulated_genes_df = get_df_for_protein_selector(
upregulated_genes, st.session_state[StateKeys.SELECTED_GENES_UP]
)

if st.session_state[StateKeys.SELECTED_GENES_DOWN] is None:
st.session_state[StateKeys.SELECTED_GENES_DOWN] = downregulated_genes
downregulated_genes_df = get_df_for_protein_selector(
downregulated_genes, st.session_state[StateKeys.SELECTED_GENES_DOWN]
)


with c1:
regulated_genes_df = volcano_plot.res[volcano_plot.res["label"] != ""]
regulated_genes_dict = dict(
zip(regulated_genes_df[Cols.INDEX], regulated_genes_df["color"].tolist())
st.markdown("##### Genes of interest")
st.session_state[StateKeys.SELECTED_GENES_UP] = protein_selector(
upregulated_genes_df,
"Upregulated Proteins",
state_key=StateKeys.SELECTED_GENES_UP,
)

if not regulated_genes_dict:
st.text("No genes of interest found.")
st.stop()
with c2:
st.markdown("##### ")
st.session_state[StateKeys.SELECTED_GENES_DOWN] = protein_selector(
downregulated_genes_df,
"Downregulated Proteins",
state_key=StateKeys.SELECTED_GENES_DOWN,
)

upregulated_genes = [
key for key in regulated_genes_dict if regulated_genes_dict[key] == "up"
]
downregulated_genes = [
key for key in regulated_genes_dict if regulated_genes_dict[key] == "down"
]
# Combine the selected genes into a new regulated_genes_dict
selected_genes = (
st.session_state[StateKeys.SELECTED_GENES_UP]
+ st.session_state[StateKeys.SELECTED_GENES_DOWN]
)
regulated_genes_dict = {
gene: "up" if gene in st.session_state[StateKeys.SELECTED_GENES_UP] else "down"
for gene in selected_genes
}

# If no genes are selected, stop the script
if not regulated_genes_dict:
st.text("No genes selected for analysis.")
st.stop()

st.markdown("##### Genes of interest")
c11, c12 = st.columns((1, 2), gap="medium")
with c11:
st.write("Upregulated genes")
st.markdown(
get_display_proteins_html(
upregulated_genes,
True,
annotation_store=st.session_state[StateKeys.ANNOTATION_STORE],
feature_to_repr_map=st.session_state[
StateKeys.DATASET
]._feature_to_repr_map,
),
unsafe_allow_html=True,
)
if st.button("Gather UniProt data for selected proteins"):
gather_uniprot_data(selected_genes)

with c12:
st.write("Downregulated genes")
st.markdown(
get_display_proteins_html(
downregulated_genes,
False,
annotation_store=st.session_state[StateKeys.ANNOTATION_STORE],
feature_to_repr_map=st.session_state[
StateKeys.DATASET
]._feature_to_repr_map,
),
unsafe_allow_html=True,
)
if any(
feature not in st.session_state[StateKeys.ANNOTATION_STORE]
for feature in selected_genes
):
st.info(
"No UniProt data stored for some proteins. Please run UniProt data fetching first to ensure correct annotation from Protein IDs instead of gene names."
)


model_name = st.session_state[StateKeys.MODEL_NAME]
Expand Down Expand Up @@ -181,8 +211,18 @@ def llm_config():
"",
value=get_initial_prompt(
plot_parameters,
list(map(feature_to_repr_map.get, upregulated_genes)),
list(map(feature_to_repr_map.get, downregulated_genes)),
list(
map(
feature_to_repr_map.get,
st.session_state[StateKeys.SELECTED_GENES_UP],
)
),
list(
map(
feature_to_repr_map.get,
st.session_state[StateKeys.SELECTED_GENES_DOWN],
)
),
),
height=200,
disabled=llm_integration_set_for_model,
Expand Down
6 changes: 3 additions & 3 deletions alphastats/gui/utils/analysis_helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import io
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
import streamlit as st
Expand Down Expand Up @@ -202,14 +202,14 @@ def gather_parameters_and_do_analysis(
raise ValueError(f"Analysis method {analysis_method} not found.")


def gather_uniprot_data(features: list) -> None:
def gather_uniprot_data(features: List[str]) -> None:
"""
Gathers UniProt data for a list of features and stores it in the session state.

Features that are already in the session state are skipped.

Args:
features (list): A list of features for which UniProt data needs to be gathered.
features (List[str]): A list of features for which UniProt data needs to be gathered.
Returns:
None
"""
Expand Down
87 changes: 80 additions & 7 deletions alphastats/gui/utils/llm_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import List, Optional

import pandas as pd
import streamlit as st

from alphastats.gui.utils.ui_helper import DefaultStates, StateKeys
Expand All @@ -11,6 +12,69 @@
)


@st.fragment
def protein_selector(df: pd.DataFrame, title: str, state_key: str) -> List[str]:
"""Creates a data editor for protein selection and returns the selected proteins.

Args:
df: DataFrame containing protein data with 'Gene', 'Selected', 'Protein' columns
title: Title to display above the editor

Returns:
selected_proteins (List[str]): A list of selected proteins.
"""
st.write(title)
c1, c2 = st.columns([1, 1])
if c1.button("Select all", help=f"Select all {title} for analysis"):
st.session_state[state_key] = df["Protein"].tolist()
st.rerun()
if c2.button("Select none", help=f"Select no {title} for analysis"):
st.session_state[state_key] = []
st.rerun()
edited_df = st.data_editor(
df,
column_config={
"Selected": st.column_config.CheckboxColumn(
"Include?",
help="Check to include this gene in analysis",
default=True,
),
"Gene": st.column_config.TextColumn(
"Gene",
help="The gene name to be included in the analysis",
width="medium",
),
},
disabled=["Gene"],
hide_index=True,
)
# Extract the selected genes
return edited_df.loc[edited_df["Selected"], "Protein"].tolist()


def get_df_for_protein_selector(
proteins: List[str], selected: List[str]
) -> pd.DataFrame:
"""Create a DataFrame for the protein selector.

Args:
proteins (List[str]): A list of proteins.

Returns:
pd.DataFrame: A DataFrame with 'Gene', 'Selected', 'Protein' columns.
"""
return pd.DataFrame(
{
"Gene": [
st.session_state[StateKeys.DATASET]._feature_to_repr_map[protein]
for protein in proteins
],
"Selected": [protein in selected for protein in proteins],
"Protein": proteins,
}
)


def get_display_proteins_html(
protein_ids: List[str], is_upregulated: True, annotation_store, feature_to_repr_map
) -> str:
Expand Down Expand Up @@ -160,13 +224,22 @@ def display_uniprot(regulated_genes_dict, feature_to_repr_map, disabled=False):
# TODO: Fix desync on rerun (widget state not updated on rerun, value becomes ind0)
preview_feature = st.selectbox(
"Feature id",
options=list(regulated_genes_dict.keys()),
options=[
feature
for feature in regulated_genes_dict
if feature in st.session_state[StateKeys.ANNOTATION_STORE]
],
format_func=lambda x: feature_to_repr_map[x],
)
st.markdown(f"Text generated from feature id {preview_feature}:")
st.markdown(
format_uniprot_annotation(
st.session_state[StateKeys.ANNOTATION_STORE][preview_feature],
fields=st.session_state[StateKeys.SELECTED_UNIPROT_FIELDS],
if preview_feature is not None:
uniprot_url = "https://www.uniprot.org/uniprotkb/"
st.markdown(
f"[Open in Uniprot ...]({uniprot_url + st.session_state[StateKeys.ANNOTATION_STORE][preview_feature]['primaryAccession']})"
)
st.markdown(f"Text generated from feature id {preview_feature}:")
st.markdown(
format_uniprot_annotation(
st.session_state[StateKeys.ANNOTATION_STORE][preview_feature],
fields=st.session_state[StateKeys.SELECTED_UNIPROT_FIELDS],
)
)
)
8 changes: 8 additions & 0 deletions alphastats/gui/utils/ui_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ def init_session_state() -> None:
if StateKeys.ANNOTATION_STORE not in st.session_state:
st.session_state[StateKeys.ANNOTATION_STORE] = {}

if StateKeys.SELECTED_GENES_UP not in st.session_state:
st.session_state[StateKeys.SELECTED_GENES_UP] = None

if StateKeys.SELECTED_GENES_DOWN not in st.session_state:
st.session_state[StateKeys.SELECTED_GENES_DOWN] = None

JuliaS92 marked this conversation as resolved.
Show resolved Hide resolved
if StateKeys.SELECTED_UNIPROT_FIELDS not in st.session_state:
st.session_state[StateKeys.SELECTED_UNIPROT_FIELDS] = (
DefaultStates.SELECTED_UNIPROT_FIELDS.copy()
Expand All @@ -151,6 +157,8 @@ class StateKeys(metaclass=ConstantsClass):
LLM_INPUT = "llm_input"
LLM_INTEGRATION = "llm_integration"
ANNOTATION_STORE = "annotation_store"
SELECTED_GENES_UP = "selected_genes_up"
JuliaS92 marked this conversation as resolved.
Show resolved Hide resolved
SELECTED_GENES_DOWN = "selected_genes_down"
SELECTED_UNIPROT_FIELDS = "selected_uniprot_fields"

ORGANISM = "organism" # TODO this is essentially a constant
Loading