Skip to content

Commit

Permalink
Merge pull request #175 from l3vels/feat/stop-execution
Browse files Browse the repository at this point in the history
Feat: add stop generating feature
  • Loading branch information
Chkhikvadze authored Sep 27, 2023
2 parents 46d09ae + 2496047 commit ff5d33b
Show file tree
Hide file tree
Showing 19 changed files with 317 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from utils.agent import convert_model_to_response
from config import Config
from memory.zep.zep_memory import ZepMemory
from models.config import ConfigModel
from typings.chat import ChatStatus

class AuthoritarianSpeaker(BaseAgent):
def __init__(
Expand Down Expand Up @@ -163,7 +165,18 @@ def run(self,
simulator.inject("Audience member", specified_topic)

while True:
status_config = ConfigModel.get_config_by_session_id(db, self.session_id, self.account)

if status_config.value == ChatStatus.STOPPED.value:
break

agent_id, agent_name, message = simulator.step()

db.session.refresh(status_config)

if status_config.value == ChatStatus.STOPPED.value:
break

ai_message = history.create_ai_message(message, None, agent_id)

if team.is_memory:
Expand Down
13 changes: 13 additions & 0 deletions apps/server/agents/agent_simulations/debates/agent_debates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from agents.handle_agent_errors import handle_agent_error
from config import Config
from memory.zep.zep_memory import ZepMemory
from models.config import ConfigModel
from typings.chat import ChatStatus

class AgentDebates(BaseAgent):
def __init__(
Expand Down Expand Up @@ -129,7 +131,18 @@ def run(self,
simulator.inject("Moderator", specified_topic)

while n < max_iters:
status_config = ConfigModel.get_config_by_session_id(db, self.session_id, self.account)

if status_config.value == ChatStatus.STOPPED.value:
break

agent_id, agent_name, message = simulator.step()

db.session.refresh(status_config)

if status_config.value == ChatStatus.STOPPED.value:
break

ai_message = history.create_ai_message(message, None, agent_id)

if team.is_memory:
Expand Down
48 changes: 39 additions & 9 deletions apps/server/controllers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@
from agents.agent_simulations.authoritarian.authoritarian_speaker import AuthoritarianSpeaker
from agents.agent_simulations.debates.agent_debates import AgentDebates
from postgres import PostgresChatMessageHistory
from typings.chat import ChatMessageInput, NegotiateOutput
from typings.chat import ChatMessageInput, NegotiateOutput, ChatMessageOutput, ChatStopInput
from utils.chat import get_chat_session_id, has_team_member_mention, parse_agent_mention
from tools.get_tools import get_agent_tools
from models.agent import AgentModel
from models.datasource import DatasourceModel
from utils.agent import convert_model_to_response
from tools.datasources.get_datasource_tools import get_datasource_tools
from typings.chat import ChatMessageOutput
from typings.config import ConfigInput, ConfigOutput
from models.team import TeamModel
from models.config import ConfigModel
from agents.team_base import TeamOfAgentsType
from services.pubsub import ChatPubSubService, AzurePubSubService
from memory.zep.zep_memory import ZepMemory
from typings.chat import ChatStatus
from config import Config
from utils.configuration import convert_model_to_response as convert_config_model_to_response

router = APIRouter()

Expand All @@ -45,6 +47,8 @@ def create_chat_message(body: ChatMessageInput, auth: UserAccount = Depends(auth
team: TeamModel = None
team_configs = None
parent: ChatMessageModel = None

team_status_config: Optional[ConfigModel] = None

if body.parent_id:
parent = ChatMessageModel.get_chat_message_by_id(db, body.parent_id, auth.account)
Expand Down Expand Up @@ -149,6 +153,22 @@ def create_chat_message(body: ChatMessageInput, auth: UserAccount = Depends(auth
)

return plan_and_execute.run(settings, chat_pubsub_service, team, prompt, history, human_message_id)

team_status_config = ConfigModel.get_config_by_session_id(db, session_id, auth.account)

if team_status_config:
team_status_config.value = ChatStatus.RUNNING.value
db.session.add(team_status_config)
db.session.commit()
if not team_status_config:
team_status_config = ConfigModel.create_config(
db,
ConfigInput(key="status", value=ChatStatus.RUNNING.value, key_type="string", is_secret=False, is_required=False, session_id=session_id),
auth.user,
auth.account,
)

chat_pubsub_service.send_chat_status(config=convert_config_model_to_response(team_status_config).dict())

if team.team_type == TeamOfAgentsType.AUTHORITARIAN_SPEAKER.value:
topic = prompt
Expand All @@ -166,15 +186,13 @@ def create_chat_message(body: ChatMessageInput, auth: UserAccount = Depends(auth
word_limit=int(word_limit)
)

result = authoritarian_speaker.run(
authoritarian_speaker.run(
topic=topic,
team=team,
agents_with_configs=agents,
history=history,
)

return result

if team.team_type == TeamOfAgentsType.DEBATES.value:
topic = prompt
agents = [convert_model_to_response(item.agent) for item in team.team_agents if item.agent is not None]
Expand All @@ -189,20 +207,32 @@ def create_chat_message(body: ChatMessageInput, auth: UserAccount = Depends(auth
word_limit=int(word_limit)
)

result = agent_debates.run(
agent_debates.run(
topic=topic,
team=team,
agents_with_configs=agents,
history=history,
is_private_chat=body.is_private_chat
)

return result
team_status_config.value = ChatStatus.IDLE.value
db.session.add(team_status_config)
db.session.commit()

chat_pubsub_service.send_chat_status(config=convert_config_model_to_response(team_status_config).dict())

if team.team_type == TeamOfAgentsType.DECENTRALIZED_SPEAKERS.value:
pass
return ""


@router.post("/stop", status_code=201, response_model=ConfigOutput)
def stop_run(body: ChatStopInput, auth: UserAccount = Depends(authenticate)):
session_id = get_chat_session_id(auth.user.id, auth.account.id, body.is_private_chat, body.agent_id, body.team_id)
team_status_config = ConfigModel.get_config_by_session_id(db, session_id, auth.account)
team_status_config.value = ChatStatus.STOPPED.value
db.session.add(team_status_config)
db.session.commit()
return convert_config_model_to_response(team_status_config)


@router.get("", status_code=200, response_model=List[ChatMessageOutput])
def get_chat_messages(is_private_chat: bool, agent_id: Optional[UUID] = None, team_id: Optional[UUID] = None, auth: UserAccount = Depends(authenticate)):
Expand Down
24 changes: 24 additions & 0 deletions apps/server/models/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations
from typing import List, Optional
import uuid
from fastapi_sqlalchemy.middleware import DBSessionMeta

from sqlalchemy import Column, String, Boolean, UUID, ForeignKey, Index
from sqlalchemy.orm import relationship
from sqlalchemy.sql import and_, or_
from models.base_model import BaseModel
from typings.config import ConfigInput, ConfigQueryParams, AccountSettings
from typings.account import AccountOutput
from exceptions import ConfigNotFoundException
from utils.encyption import encrypt_data, decrypt_data, is_encrypted

Expand Down Expand Up @@ -42,6 +44,7 @@ class ConfigModel(BaseModel):
datasource_id = Column(UUID, ForeignKey('datasource.id', ondelete='CASCADE'), nullable=True, index=True)
team_id = Column(UUID, ForeignKey('team.id', ondelete='CASCADE'), nullable=True, index=True)
team_agent_id = Column(UUID, ForeignKey('team_agent.id', ondelete='CASCADE'), nullable=True, index=True)
session_id = Column(String, nullable=True, index=True)
value = Column(String)
key_type = Column(String)
is_secret = Column(Boolean)
Expand Down Expand Up @@ -168,6 +171,27 @@ def get_config_by_id(cls, db, config_id, account):

return config

@classmethod
def get_config_by_session_id(cls, db: DBSessionMeta, session_id: str, account: AccountOutput):
"""
Get Config from session_id
Args:
db: The database session.
session_id(str): Unique identifier of an Config.
account(AccountOutput): Account
Returns:
Config: Config object is returned.
"""
config = (
db.session.query(ConfigModel)
.filter(ConfigModel.session_id == session_id, ConfigModel.account_id == account.id, or_(or_(ConfigModel.is_deleted == False, ConfigModel.is_deleted is None), ConfigModel.is_deleted is None))
.first()
)

return config

@classmethod
def get_account_settings(cls, db, account) -> AccountSettings:
keys = ["open_api_key", "hugging_face_token"]
Expand Down
33 changes: 29 additions & 4 deletions apps/server/services/pubsub.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import sentry_sdk
import json
from typing import Dict, Optional, Any
import sentry_sdk
from azure.messaging.webpubsubservice import WebPubSubServiceClient
from azure.core.exceptions import AzureError
from azure.identity import DefaultAzureCredential
from config import Config
from typings.user import UserOutput

from datetime import datetime
from uuid import UUID

class AzurePubSubService:
def __init__(self):
Expand Down Expand Up @@ -33,7 +35,6 @@ def get_client_access_token(self, user_id):
sentry_sdk.capture_exception(err)



class ChatPubSubService:
def __init__(self, session_id: str, user: UserOutput, is_private_chat: bool, team_id: Optional[str] = None, agent_id: Optional[str] = None):
self.session_id = session_id
Expand All @@ -55,4 +56,28 @@ def send_chat_message(self, chat_message: Dict, local_chat_message_ref_id: Optio
'local_chat_message_ref_id': local_chat_message_ref_id,
'agent_id': self.agent_id,
'team_id': self.team_id,
})
})

def send_chat_status(self, config: Dict):
"""Sends chat status object"""
data = json.loads(json.dumps(config, cls=PubSubJSONEncoder))

self.azure_pubsub_service.send_to_group(self.session_id, message={
'type': 'CHAT_STATUS',
'from': str(self.user.id),
'config': data,
'is_private_chat': self.is_private_chat,
'agent_id': self.agent_id,
'team_id': self.team_id,
})


class PubSubJSONEncoder(json.JSONEncoder):
def default(self, obj: object):
if isinstance(obj, UUID):
# if the obj is uuid, we simply return the value of uuid
return str(obj)
if isinstance(obj, datetime):
# for datetime objects, convert to string in your preferred format
return obj.isoformat()
return super().default(obj)
11 changes: 11 additions & 0 deletions apps/server/typings/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
from uuid import UUID
from pydantic import BaseModel
from datetime import datetime
from enum import Enum

class ChatStatus(Enum):
IDLE = 'Idle'
RUNNING = 'Running'
STOPPED = 'Stopped'

class ChatMessageInput(BaseModel):
prompt: str
Expand Down Expand Up @@ -30,3 +36,8 @@ class ChatMessageOutput(BaseModel):

class NegotiateOutput(BaseModel):
url: str

class ChatStopInput(BaseModel):
is_private_chat: bool
agent_id: Optional[UUID] = None
team_id: Optional[UUID] = None
3 changes: 3 additions & 0 deletions apps/server/typings/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class ConfigInput(BaseModel):
datasource_id: Optional[UUID4]
workspace_id: Optional[UUID4]
team_id: Optional[UUID4]
session_id: Optional[UUID4]


class ConfigOutput(BaseModel):
Expand All @@ -30,6 +31,7 @@ class ConfigOutput(BaseModel):
is_deleted: bool
created_by: Optional[UUID4]
modified_by: Optional[UUID4]
session_id: Optional[UUID4]

class ConfigQueryParams(BaseModel):
id: Optional[str]
Expand All @@ -39,6 +41,7 @@ class ConfigQueryParams(BaseModel):
toolkit_id: Optional[UUID4]
datasource_id: Optional[UUID4]
workspace_id: Optional[UUID4]
session_id: Optional[UUID4]


class AccountSettings(BaseModel):
Expand Down
18 changes: 18 additions & 0 deletions apps/ui/src/gql/ai/config/configFragment.gql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
fragment ConfigFragment on Config {
id
key
value
key_type
is_secret
is_required
agent_id
toolkit_id
datasource_id
team_id
account_id
workspace_id
session_id
is_deleted
created_by
modified_by
}
22 changes: 5 additions & 17 deletions apps/ui/src/gql/ai/config/configs.gql
Original file line number Diff line number Diff line change
@@ -1,19 +1,7 @@
query getConfigs @api(name: "ai") {
getConfigs @rest(type: "Config", path: "/config", method: "GET", endpoint: "ai") {
id
key
value
key_type
is_secret
is_required
agent_id
toolkit_id
datasource_id
team_id
account_id
workspace_id
is_deleted
created_by
modified_by
#import "./configFragment.gql"

query configs @api(name: "ai") {
configs @rest(type: "Config", path: "/config", method: "GET", endpoint: "ai") {
...ConfigFragment
}
}
18 changes: 3 additions & 15 deletions apps/ui/src/gql/ai/config/createConfig.gql
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
#import "./configFragment.gql"

mutation createConfig($input: input!) @api(name: ai) {
createConfig(input: $input)
@rest(type: "Config", path: "/config", method: "POST", bodyKey: "input", endpoint: "ai") {
id
key
value
key_type
is_secret
is_required
agent_id
toolkit_id
datasource_id
team_id
account_id
workspace_id
is_deleted
created_by
modified_by
...ConfigFragment
}
}
Loading

0 comments on commit ff5d33b

Please sign in to comment.