Skip to content

Commit

Permalink
memory context added
Browse files Browse the repository at this point in the history
  • Loading branch information
spike-spiegel-21 committed Oct 6, 2024
1 parent 29178a4 commit e1e3e6d
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 27 deletions.
59 changes: 58 additions & 1 deletion mem0/configs/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from mem0.embeddings.configs import EmbedderConfig
from mem0.graphs.configs import GraphStoreConfig
Expand Down Expand Up @@ -72,3 +72,60 @@ class AzureConfig(BaseModel):
azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version: str = Field(description="The version of the Azure API being used.", default=None)


class MemoryContext(BaseModel):
user_id: Optional[str] = None
agent_id: Optional[str] = None
run_id: Optional[str] = None
metadata: Optional[Dict[str, str]] = None
filters: Optional[Dict[str, str]] = None

@model_validator(mode='before')
def check_at_least_one_id(cls, values):
"""
Ensure at least one of 'user_id', 'agent_id', or 'run_id' is provided.
This validator runs before initializing the model.
"""
user_id = values.get('user_id')
agent_id = values.get('agent_id')
run_id = values.get('run_id')

if not any([user_id, agent_id, run_id]):
raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be provided!")

# Ensure metadata and filters are initialized as empty dicts if None
if values.get('metadata') is None:
values['metadata'] = {}
if values.get('filters') is None:
values['filters'] = {}

return values

def prepare_metadata(self):
"""
Prepare the metadata and ensure it includes the user, agent, and run IDs.
"""
metadata = self.metadata or {}
if self.user_id:
metadata["user_id"] = self.user_id
if self.agent_id:
metadata["agent_id"] = self.agent_id
if self.run_id:
metadata["run_id"] = self.run_id

return metadata

def prepare_filters(self):
"""
Prepare the filters and ensure it includes the user, agent, and run IDs.
"""
filters = self.filters or {}
if self.user_id:
filters["user_id"] = self.user_id
if self.agent_id:
filters["agent_id"] = self.agent_id
if self.run_id:
filters["run_id"] = self.run_id

return filters
45 changes: 19 additions & 26 deletions mem0/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytz
from pydantic import ValidationError

from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.base import MemoryConfig, MemoryContext, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
Expand Down Expand Up @@ -83,23 +83,20 @@ def add(
Returns:
dict: A dictionary containing the result of the memory addition operation.
"""
if metadata is None:
metadata = {}

filters = filters or {}
if user_id:
filters["user_id"] = metadata["user_id"] = user_id
if agent_id:
filters["agent_id"] = metadata["agent_id"] = agent_id
if run_id:
filters["run_id"] = metadata["run_id"] = run_id

if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")


if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

memory_context = MemoryContext(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=metadata,
filters=filters
)
metadata = memory_context.prepare_metadata(),
filters = memory_context.prepare_filters()

with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
future2 = executor.submit(self._add_to_graph, messages, filters)
Expand Down Expand Up @@ -360,17 +357,13 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, fil
Returns:
list: List of search results.
"""
filters = filters or {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id

if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")

memory_context = MemoryContext(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
filters=filters
)
filters = memory_context.prepare_filters(filters)
capture_event(
"mem0.search",
self,
Expand Down

0 comments on commit e1e3e6d

Please sign in to comment.