Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP:RAG #126

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions demo/rag_test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d8a54001-ef13-412d-95e1-2535e67e3ea6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from bia_bob._rag import HintVectorStore\n",
"from bia_bob._machinery import Context\n",
"from bia_bob import bob"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cc9188d2-ab7a-4bbd-8b21-3bae2885a705",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hint lines in System prompt: 191\n"
]
},
{
"data": {
"text/markdown": [
"I will load the `blobs.tif` image using `aicsimageio`, segment the bright blobs, and display the result.\n",
"\n"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"bob(\"\"\"\n",
"Load blobs.tif using aicsimageio\n",
"segment the bright blobs\n",
"show the result\n",
"\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c24dedd1-8c41-411d-84df-5880636b59c1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"Context.hint_store = HintVectorStore()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "781c6beb-4239-4a13-bc58-7d55cf077d52",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hint lines in System prompt: 26\n"
]
},
{
"data": {
"text/markdown": [
"I will load the `blobs.tif` image using `aicsimageio`, segment the bright blobs, and display the result.\n",
"\n"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"bob(\"\"\"\n",
"Load blobs.tif using aicsimageio\n",
"segment the bright blobs\n",
"show the result\n",
"\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bdb73805-6934-406a-8191-40fd8087778a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions src/bia_bob/_machinery.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class Context:
temperature = None # openai only
endpoint = None
api_key = None
hint_store = None

libraries = keep_available_packages([
"scikit-image",
Expand Down
121 changes: 121 additions & 0 deletions src/bia_bob/_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
class HintVectorStore():
    """Store of plugin-provided hints, searchable by embedding similarity.

    On construction the store loads a YAML cache from
    ``~/.cache/bia-bob/bia_bob_vectore_store.yaml`` (if present), embeds the
    hints of every ``bia_bob_plugins`` entry point whose ``package==version``
    is not cached yet, writes the cache back, and keeps an in-memory
    ``{vector: hint}`` dictionary restricted to packages that are installed.
    """

    def __init__(self):
        import os
        import importlib_metadata
        from importlib.metadata import entry_points
        import pandas as pd
        import yaml

        columns = ["package", "hint", "vector"]

        # load cache from disk if it exists
        home_dir = os.path.expanduser('~')
        store_filename = os.path.join(home_dir, ".cache", "bia-bob", "bia_bob_vectore_store.yaml")
        os.makedirs(os.path.dirname(store_filename), exist_ok=True)
        if os.path.exists(store_filename):
            # NOTE(review): the vectors are stored as tuples; yaml.dump presumably
            # writes them with Python-specific tags, which would be why full_load
            # (not safe_load) is used here - confirm before switching loaders.
            with open(store_filename, mode="rt", encoding="utf-8") as cache_file:
                loaded = yaml.full_load(cache_file) or {}
            records = loaded.get('vectorstore') or []
            # Passing the column names explicitly keeps the frame well-formed
            # when the cache holds no records (e.g. it was written before any
            # plugin was installed); otherwise the "package" lookups below
            # would raise KeyError.
            df_dict = pd.DataFrame(records, columns=columns)
            db_dict = df_dict.to_dict(orient="list")
        else:
            # create empty cache
            db_dict = {name: [] for name in columns}
            df_dict = pd.DataFrame(db_dict)

        # scan installed modules that are compatible plugins for hints
        try:
            bia_bob_plugins = entry_points(group='bia_bob_plugins')
        except TypeError:
            # entry_points() does not accept the `group` keyword here;
            # fall back to indexing the full result by group name.
            all_plugins = entry_points()
            try:
                bia_bob_plugins = all_plugins['bia_bob_plugins']
            except KeyError:
                bia_bob_plugins = []

        # Snapshot of what was cached on disk. It is deliberately not updated
        # while scanning, so several entry points belonging to one newly seen
        # package each contribute their hints.
        cached_before_scan = set(df_dict["package"])

        all_modules = importlib_metadata.packages_distributions()
        for plugin in bia_bob_plugins:
            module_name = plugin.value.split(".")[0]
            package_name = all_modules[module_name][0]
            package_version = get_package_version(package_name)

            package_name_and_version = package_name + "==" + package_version

            # check if these functions are in the hint-vector store already. If not: add them
            if package_name_and_version not in cached_before_scan:
                print("BiA-Bob is scanning and caching", package_name_and_version)
                func = plugin.load()
                hints = func()
                parse_hints_to_dict(hints, db_dict, package_name_and_version)

        # convert to DataFrame and save to disk (same encoding as used for reading)
        df_dict = pd.DataFrame(db_dict)
        with open(store_filename, 'w', encoding="utf-8") as cache_file:
            yaml.dump({'vectorstore': df_dict.to_dict(orient='records')}, cache_file, default_flow_style=False)

        # get all packages+versions in cache
        unique_packages = set(df_dict['package'].unique().tolist())
        # Only distributions whose name occurs in the cache can match below,
        # so versions are looked up for those alone (and once per name).
        cached_names = {entry.partition("==")[0] for entry in unique_packages}
        distribution_names = {name for names in all_modules.values() for name in names}

        # check if they are installed
        installed_packages = []
        for package_name in distribution_names & cached_names:
            package_version = get_package_version(package_name)
            package_name_and_version = package_name + "==" + package_version
            if package_name_and_version in unique_packages or package_name in unique_packages:
                installed_packages.append(package_name_and_version)

        # only keep cache for installed packages
        df_dict = df_dict[df_dict['package'].isin(installed_packages)]
        # keep a dictionary {vector:hint}
        self._vector_store = df_dict.set_index('vector')['hint'].to_dict()

    def search(self, text, n_best_results=3):
        """Return the hints whose vectors have the largest inner product with `text`.

        Parameters
        ----------
        text : str
            Query text; it is embedded with `embed`.
        n_best_results : int
            Maximum number of hints to return.

        Returns
        -------
        list
            Up to `n_best_results` hints, best match first. Empty if the
            store holds no vectors.
        """
        import numpy as np
        query_vector = np.asarray(embed(text))

        # score every stored vector by its inner product with the query
        scored = [(np.dot(query_vector, np.asarray(vector)), vector) for vector in self._vector_store]

        # largest inner product first
        scored.sort(reverse=True)

        return [self._vector_store[vector] for _, vector in scored[:n_best_results]]

def parse_hints_to_dict(hints, db_dict, package_name_and_version):
    """Split a hint text into chunks, embed each chunk and append it to `db_dict`.

    Leading spaces after every line break are removed, the text is stripped
    and then split at blank lines (two consecutive line breaks). Chunks that
    are empty or contain only whitespace are skipped, so empty hints or hints
    with extra blank lines do not trigger an embedding request for "".

    Parameters
    ----------
    hints : str
        Hint text as returned by a plugin.
    db_dict : dict
        Dictionary with the list-valued keys "package", "hint" and "vector";
        it is modified in place, one entry per chunk.
    package_name_and_version : str
        Value stored under "package" for every chunk.
    """
    instructions = hints
    # Repeat until no line break is followed by a space: this removes
    # indentation of any depth, one space per pass.
    while "\n " in instructions:
        instructions = instructions.replace("\n ", "\n")

    for chunk in instructions.strip().split("\n\n"):
        if not chunk.strip():
            # nothing to embed
            continue

        vector = embed(chunk)

        db_dict["package"].append(package_name_and_version)
        db_dict["hint"].append(chunk)
        # tuples are hashable, so the vector can later serve as a dict key
        db_dict["vector"].append(tuple(vector))

def embed(text, model="text-embedding-3-small"):
    """Return the embedding vector of `text` computed by an OpenAI embedding model.

    Parameters
    ----------
    text : str
        Text to embed.
    model : str
        Name of the embedding model. Vectors produced by different models
        should not be compared with each other, so keep this consistent
        with whatever produced the vectors being searched.

    Returns
    -------
    The `embedding` of the first item in the response data.
    """
    from openai import OpenAI
    # NOTE(review): the client is created without arguments, so credentials
    # presumably come from the environment - confirm against the openai docs.
    client = OpenAI()

    response = client.embeddings.create(
        input=text,
        model=model
    )
    return response.data[0].embedding

def get_package_version(package_name):
    """Look up the installed version string of the distribution `package_name`."""
    import importlib.metadata

    installed_version = importlib.metadata.version(package_name)
    return installed_version
46 changes: 36 additions & 10 deletions src/bia_bob/_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,27 @@ def generate_response_to_user(model, user_prompt: str, image=None, additional_sy
text, plan, code = None, None, None

chat_backup = [c for c in Context.chat]

hints = None

if Context.hint_store is not None:

prompt = f"""
Split the following prompt into sub-tasks separated by two line breaks. Keep the text as it is otherwise:

{user_prompt}
"""

sub_tasks_text = generate_response(chat_history=[], image=None, model=Context.model, system_prompt="", user_prompt=prompt, vision_system_prompt="")

sub_tasks = [s.strip("\n") for s in sub_tasks_text.split("\n\n")]

hints = "\n\n".join(["\n\n".join(Context.hint_store.search(s)) for s in sub_tasks])


for attempt in range(1, max_number_attempts + 1):
if system_prompt is None:
system_prompt = create_system_prompt()
system_prompt = create_system_prompt(hints=hints)
if additional_system_prompt is not None:
system_prompt += "\n" + additional_system_prompt

Expand Down Expand Up @@ -116,10 +133,10 @@ def split_response(text):
return summary, plan, code


def create_system_prompt(reusable_variables_block=None):
def create_system_prompt(reusable_variables_block=None, hints=None):
"""Creates a system prompt that contains instructions of general interest, available functions and variables."""
# determine useful variables and functions in context

# if scikit-image is installed, give hints how to use it
from ._machinery import Context

Expand All @@ -129,11 +146,13 @@ def create_system_prompt(reusable_variables_block=None):
from skimage.io import imread
image = imread(filename)
```

* Expanding labels by a given radius in a label image works like this:
```
from skimage.segmentation import expand_labels
expanded_labels = expand_labels(label_image, distance=10)
```

* Measure properties of labels with respect to an image works like this:
```
from skimage.measure import regionprops
Expand Down Expand Up @@ -197,6 +216,19 @@ def create_system_prompt(reusable_variables_block=None):
else:
additional_snippets = ""

if hints is None:
hints = f"""
## Python specific code snippets

If the user asks for those simple tasks, use these code snippets.
{skimage_snippets}
{aicsimageio_snippets}
{czifile_snippets}
{additional_snippets}
"""

print("Hint lines in System prompt:", len(hints.split("\n")))

system_prompt = f"""
You are a extremely talented bioimage analyst and you use Python to solve your tasks unless stated otherwise.
If the request entails writing code, write concise professional bioimage analysis high-quality code.
Expand All @@ -205,13 +237,7 @@ def create_system_prompt(reusable_variables_block=None):

{reusable_variables_block}

## Python specific code snippets

If the user asks for those simple tasks, use these code snippets.
{skimage_snippets}
{aicsimageio_snippets}
{czifile_snippets}
{additional_snippets}
{hints}

## Todos

Expand Down