Merge pull request #393 from MannLabs/chat_optimization
Chat optimization around token usage and keeping important messages in the chat history
JuliaS92 authored Jan 22, 2025
2 parents 3c5ac58 + 92e368c commit 57de2c7
Showing 5 changed files with 388 additions and 72 deletions.
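
Only two of the five changed files appear in this excerpt; the context-handling logic itself lives in alphastats/llm/llm_integration.py, which is not shown here. As a rough, hypothetical sketch of the idea the GUI code below relies on (truncate_history, the plain string keys "pinned"/"in_context", and the token_counts argument are illustrative, not the project's actual API): pinned messages are always kept, and the oldest unpinned messages are dropped from the context until the history fits the token budget.

from typing import Dict, List


def truncate_history(
    messages: List[Dict], token_counts: List[int], max_tokens: int
) -> List[Dict]:
    """Illustrative only: keep pinned messages, drop the oldest unpinned ones."""
    for message in messages:
        message["in_context"] = True  # start by assuming every message fits

    total = sum(token_counts)
    for message, tokens in zip(messages, token_counts):  # walk from oldest to newest
        if total <= max_tokens:
            break
        if not message.get("pinned", False):
            message["in_context"] = False  # excluded from the next completion call
            total -= tokens
    return messages

Messages flagged with in_context=False would stay in the printed history (marked with :x: in the chat below) but would no longer be sent to the model.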
83 changes: 71 additions & 12 deletions alphastats/gui/pages/06_LLM.py
@@ -24,7 +24,7 @@
init_session_state,
sidebar_info,
)
from alphastats.llm.llm_integration import LLMIntegration, Models
from alphastats.llm.llm_integration import LLMIntegration, MessageKeys, Models, Roles
from alphastats.llm.prompts import get_initial_prompt, get_system_message
from alphastats.plots.plot_utils import PlotlyObject

@@ -86,6 +86,14 @@ def llm_config():
else:
st.error(f"Connection to {model_name} failed: {str(error)}")

st.number_input(
"Maximal number of tokens",
value=st.session_state[StateKeys.MAX_TOKENS],
min_value=2000,
max_value=128000, # TODO: set this automatically based on the selected model
key=StateKeys.MAX_TOKENS,
)

if current_model != st.session_state[StateKeys.MODEL_NAME]:
st.rerun(scope="app")

@@ -261,6 +269,7 @@ def llm_config():
base_url=OLLAMA_BASE_URL,
dataset=st.session_state[StateKeys.DATASET],
genes_of_interest=list(regulated_genes_dict.keys()),
max_tokens=st.session_state[StateKeys.MAX_TOKENS],
)

st.session_state[StateKeys.LLM_INTEGRATION][model_name] = llm_integration
@@ -271,7 +280,7 @@ def llm_config():
)

with st.spinner("Processing initial prompt..."):
llm_integration.chat_completion(initial_prompt)
llm_integration.chat_completion(initial_prompt, pin_message=True)

st.rerun(scope="app")
except AuthenticationError:
@@ -282,7 +291,11 @@ def llm_config():


@st.fragment
def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
def llm_chat(
llm_integration: LLMIntegration,
show_all: bool = False,
show_individual_tokens: bool = False,
):
"""The chat interface for the LLM analysis."""

# TODO dump to file -> static file name, plus button to do so
@@ -291,10 +304,30 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
# Alternatively write it all in one pdf report using e.g. pdfrw and reportlab (I have code for that combo).

# no. tokens spent
total_tokens = 0
pinned_tokens = 0
for message in llm_integration.get_print_view(show_all=show_all):
with st.chat_message(message["role"]):
st.markdown(message["content"])
for artifact in message["artifacts"]:
with st.chat_message(message[MessageKeys.ROLE]):
st.markdown(message[MessageKeys.CONTENT])
tokens = llm_integration.estimate_tokens([message])
if message[MessageKeys.IN_CONTEXT]:
total_tokens += tokens
if message[MessageKeys.PINNED]:
pinned_tokens += tokens
if (
message[MessageKeys.PINNED]
or not message[MessageKeys.IN_CONTEXT]
or show_individual_tokens
):
token_message = ""
if message[MessageKeys.PINNED]:
token_message += ":pushpin: "
if not message[MessageKeys.IN_CONTEXT]:
token_message += ":x: "
if show_individual_tokens:
token_message += f"*tokens: {str(tokens)}*"
st.markdown(token_message)
for artifact in message[MessageKeys.ARTIFACTS]:
if isinstance(artifact, pd.DataFrame):
st.dataframe(artifact)
elif isinstance(
@@ -305,7 +338,17 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
st.warning("Don't know how to display artifact:")
st.write(artifact)

st.markdown(
f"*total tokens used: {str(total_tokens)}, tokens used for pinned messages: {str(pinned_tokens)}*"
)

if prompt := st.chat_input("Say something"):
with st.chat_message(Roles.USER):
st.markdown(prompt)
if show_individual_tokens:
st.markdown(
f"*tokens: {str(llm_integration.estimate_tokens([{MessageKeys.CONTENT:prompt}]))}*"
)
with st.spinner("Processing prompt..."):
llm_integration.chat_completion(prompt)
st.rerun(scope="fragment")
@@ -317,11 +360,27 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
"text/plain",
)

st.markdown(
"*icons: :pushpin: pinned message, :x: message no longer in context due to token limitations*"
)


show_all = st.checkbox(
"Show system messages",
key="show_system_messages",
help="Show all messages in the chat interface.",
)
c1, c2 = st.columns((1, 2))
with c1:
show_all = st.checkbox(
"Show system messages",
key="show_system_messages",
help="Show all messages in the chat interface.",
)
with c2:
show_individual_tokens = st.checkbox(
"Show individual token estimates",
key="show_individual_tokens",
help="Show individual token estimates for each message.",
)

llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][model_name], show_all)
llm_chat(
st.session_state[StateKeys.LLM_INTEGRATION][model_name],
show_all,
show_individual_tokens,
)
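
The llm_chat fragment above calls llm_integration.estimate_tokens([...]) for every message; that method is part of alphastats/llm/llm_integration.py and is not included in this excerpt. A minimal sketch of how such an estimate could be computed for OpenAI-style models, assuming the tiktoken package (the default model name and the cl100k_base fallback are assumptions, not necessarily what the project uses):

from typing import Dict, List

import tiktoken


def estimate_tokens(messages: List[Dict], model: str = "gpt-4o") -> int:
    """Sketch only: roughly count the tokens of the message contents."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")  # generic fallback encoding
    return sum(len(encoding.encode(message.get("content", ""))) for message in messages)

Counting only the content field slightly undercounts real usage, since role markers and message framing also consume a few tokens per message.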
4 changes: 4 additions & 0 deletions alphastats/gui/utils/ui_helper.py
@@ -142,6 +142,9 @@ def init_session_state() -> None:
DefaultStates.SELECTED_UNIPROT_FIELDS.copy()
)

if StateKeys.MAX_TOKENS not in st.session_state:
st.session_state[StateKeys.MAX_TOKENS] = 10000


class StateKeys(metaclass=ConstantsClass):
USER_SESSION_ID = "user_session_id"
@@ -160,6 +163,7 @@ class StateKeys(metaclass=ConstantsClass):
SELECTED_GENES_UP = "selected_genes_up"
SELECTED_GENES_DOWN = "selected_genes_down"
SELECTED_UNIPROT_FIELDS = "selected_uniprot_fields"
MAX_TOKENS = "max_tokens"

ORGANISM = "organism" # TODO this is essentially a constant
