diff --git a/pkg-py/src/shinychat/_chat.py b/pkg-py/src/shinychat/_chat.py index 9443993..57aa626 100644 --- a/pkg-py/src/shinychat/_chat.py +++ b/pkg-py/src/shinychat/_chat.py @@ -29,7 +29,6 @@ css, ) from shiny import reactive -from shiny._deprecated import warn_deprecated from shiny.bookmark import BookmarkState, RestoreState from shiny.bookmark._types import BookmarkStore from shiny.module import ResolvedId, resolve_id @@ -39,7 +38,7 @@ require_active_session, session_context, ) -from shiny.types import MISSING, MISSING_TYPE, Jsonifiable, NotifyException +from shiny.types import Jsonifiable, NotifyException from shiny.ui.css import CssUnit, as_css_unit from shiny.ui.fill import as_fill_item, as_fillable_container @@ -53,27 +52,7 @@ set_chatlas_state, ) from ._chat_normalize import normalize_message, normalize_message_chunk -from ._chat_provider_types import ( - AnthropicMessage, - GoogleMessage, - LangChainMessage, - OllamaMessage, - OpenAIMessage, - ProviderMessage, - ProviderMessageFormat, - as_provider_message, -) -from ._chat_tokenizer import ( - TokenEncoding, - TokenizersEncoding, - get_default_tokenizer, -) -from ._chat_types import ( - ChatMessage, - ChatMessageDict, - ClientMessage, - TransformedMessage, -) +from ._chat_types import ChatMessage, ChatMessageDict, ClientMessage from ._html_deps_py_shiny import chat_deps if TYPE_CHECKING: @@ -92,24 +71,6 @@ # TODO: UserInput might need to be a list of dicts if we want to support multiple # user input content types -TransformUserInput = Callable[[str], Union[str, None]] -TransformUserInputAsync = Callable[[str], Awaitable[Union[str, None]]] -TransformAssistantResponse = Callable[[str], Union[str, HTML, None]] -TransformAssistantResponseAsync = Callable[ - [str], Awaitable[Union[str, HTML, None]] -] -TransformAssistantResponseChunk = Callable[ - [str, str, bool], Union[str, HTML, None] -] -TransformAssistantResponseChunkAsync = Callable[ - [str, str, bool], Awaitable[Union[str, HTML, None]] -] -TransformAssistantResponseFunction = Union[ - TransformAssistantResponse, - TransformAssistantResponseAsync, - TransformAssistantResponseChunk, - TransformAssistantResponseChunkAsync, -] UserSubmitFunction0 = Union[ Callable[[], None], Callable[[], Awaitable[None]], @@ -181,22 +142,6 @@ async def handle_user_input(user_input: str): id A unique identifier for the chat session. In Shiny Core, make sure this id matches a corresponding :func:`~shiny.ui.chat_ui` call in the UI. - messages - A sequence of messages to display in the chat. A given message can be one of the - following: - - * A string, which is interpreted as markdown and rendered to HTML on the client. - * To prevent interpreting as markdown, mark the string as - :class:`~shiny.ui.HTML`. - * A UI element (specifically, a :class:`~shiny.ui.TagChild`). - * This includes :class:`~shiny.ui.TagList`, which take UI elements - (including strings) as children. In this case, strings are still - interpreted as markdown as long as they're not inside HTML. - * A dictionary with `content` and `role` keys. The `content` key can contain a - content as described above, and the `role` key can be "assistant" or "user". - - **NOTE:** content may include specially formatted **input suggestion** links - (see `.append_message()` for more information). on_error How to handle errors that occur in response to user input. When `"unhandled"`, the app will stop running when an error occurs. 
Otherwise, a notification @@ -207,32 +152,19 @@ async def handle_user_input(user_input: str): * `"actual"`: Display the actual error message to the user. * `"sanitize"`: Sanitize the error message before displaying it to the user. * `"unhandled"`: Do not display any error message to the user. - tokenizer - The tokenizer to use for calculating token counts, which is required to impose - `token_limits` in `.messages()`. If not provided, a default generic tokenizer - is attempted to be loaded from the tokenizers library. A specific tokenizer - may also be provided by following the `TokenEncoding` (tiktoken or tozenizers) - protocol (e.g., `tiktoken.encoding_for_model("gpt-4o")`). """ def __init__( self, id: str, *, - messages: Sequence[Any] = (), on_error: Literal["auto", "actual", "sanitize", "unhandled"] = "auto", - tokenizer: TokenEncoding | None = None, ): if not isinstance(id, str): raise TypeError("`id` must be a string.") self.id = resolve_id(id) self.user_input_id = ResolvedId(f"{self.id}_user_input") - self._transform_user: TransformUserInputAsync | None = None - self._transform_assistant: ( - TransformAssistantResponseChunkAsync | None - ) = None - self._tokenizer = tokenizer # TODO: remove the `None` when this PR lands: # https://github.com/posit-dev/py-shiny/pull/793/files @@ -255,10 +187,6 @@ def __init__( # For tracking message stream state when entering/exiting nested streams self._message_stream_checkpoint: str = "" - # If a user input message is transformed into a response, we need to cancel - # the next user input submit handling - self._suspend_input_handler: bool = False - # Keep track of effects so we can destroy them when the chat is destroyed self._effects: list[Effect_] = [] self._cancel_bookmarking_callbacks: CancelCallback | None = None @@ -266,13 +194,13 @@ def __init__( # Initialize chat state and user input effect with session_context(self._session): # Initialize message state - self._messages: reactive.Value[tuple[TransformedMessage, ...]] = ( + self._messages: reactive.Value[tuple[ChatMessage, ...]] = ( reactive.Value(()) ) - self._latest_user_input: reactive.Value[ - TransformedMessage | None - ] = reactive.Value(None) + self._latest_user_input: reactive.Value[ChatMessage | None] = ( + reactive.Value(None) + ) @reactive.extended_task async def _mock_task() -> str: @@ -282,19 +210,6 @@ async def _mock_task() -> str: reactive.ExtendedTask[[], str] ] = reactive.Value(_mock_task) - # TODO: deprecate messages once we start promoting managing LLM message - # state through other means - async def _append_init_messages(): - for msg in messages: - await self.append_message(msg) - - @reactive.effect - async def _init_chat(): - await _append_init_messages() - - self._append_init_messages = _append_init_messages - self._init_chat = _init_chat - # When user input is submitted, transform, and store it in the chat state # (and make sure this runs before other effects since when the user # calls `.messages()`, they should get the latest user input) @@ -302,21 +217,9 @@ async def _init_chat(): @reactive.event(self._user_input) async def _on_user_input(): msg = ChatMessage(content=self._user_input(), role="user") - # It's possible that during the transform, a message is appended, so get - # the length now, so we can insert the new message at the right index - n_pre = len(self._messages()) - msg_post = await self._transform_message(msg) - if msg_post is not None: - self._store_message(msg_post) - self._suspend_input_handler = False - else: - # A transformed value of None is a 
special signal to suspend input - # handling (i.e., don't generate a response) - self._store_message(msg, index=n_pre) - await self._remove_loading_message() - self._suspend_input_handler = True - - self._effects.append(_init_chat) + self._store_message(msg) + await self._remove_loading_message() + self._effects.append(_on_user_input) # Prevent repeated calls to Chat() with the same id from accumulating effects @@ -371,23 +274,14 @@ def create_effect(fn: UserSubmitFunction): @reactive.effect @reactive.event(self._user_input) async def handle_user_input(): - if self._suspend_input_handler: - from shiny import req - - req(False) try: if len(fn_params) > 1: raise ValueError( "A on_user_submit function should not take more than 1 argument" ) elif len(fn_params) == 1: - input = self.user_input(transform=True) - # The line immediately below handles the possibility of input - # being transformed to None. Technically, input should never be - # None at this point (since the handler should be suspended). - input = "" if input is None else input afunc = _utils.wrap_async(cast(UserSubmitFunction1, fn)) - await afunc(input) + await afunc(self.user_input()) else: afunc = _utils.wrap_async(cast(UserSubmitFunction0, fn)) await afunc() @@ -415,160 +309,27 @@ async def _raise_exception( msg = f"Error in Chat('{self.id}'): {str(e)}" raise NotifyException(msg, sanitize=sanitize) from e - @overload - def messages( - self, - *, - format: Literal["anthropic"], - token_limits: tuple[int, int] | None = None, - transform_user: Literal["all", "last", "none"] = "all", - transform_assistant: bool = False, - ) -> tuple[AnthropicMessage, ...]: ... - - @overload - def messages( - self, - *, - format: Literal["google"], - token_limits: tuple[int, int] | None = None, - transform_user: Literal["all", "last", "none"] = "all", - transform_assistant: bool = False, - ) -> tuple[GoogleMessage, ...]: ... - - @overload - def messages( - self, - *, - format: Literal["langchain"], - token_limits: tuple[int, int] | None = None, - transform_user: Literal["all", "last", "none"] = "all", - transform_assistant: bool = False, - ) -> tuple[LangChainMessage, ...]: ... - - @overload - def messages( - self, - *, - format: Literal["openai"], - token_limits: tuple[int, int] | None = None, - transform_user: Literal["all", "last", "none"] = "all", - transform_assistant: bool = False, - ) -> tuple[OpenAIMessage, ...]: ... - - @overload - def messages( - self, - *, - format: Literal["ollama"], - token_limits: tuple[int, int] | None = None, - transform_user: Literal["all", "last", "none"] = "all", - transform_assistant: bool = False, - ) -> tuple[OllamaMessage, ...]: ... - - @overload - def messages( - self, - *, - format: MISSING_TYPE = MISSING, - token_limits: tuple[int, int] | None = None, - transform_user: Literal["all", "last", "none"] = "all", - transform_assistant: bool = False, - ) -> tuple[ChatMessageDict, ...]: ... - - def messages( - self, - *, - format: MISSING_TYPE | ProviderMessageFormat = MISSING, - token_limits: tuple[int, int] | None = None, - transform_user: Literal["all", "last", "none"] = "all", - transform_assistant: bool = False, - ) -> tuple[ChatMessageDict | ProviderMessage, ...]: + def messages(self) -> tuple[ChatMessageDict, ...]: """ Reactively read chat messages - Obtain chat messages within a reactive context. The default behavior is - intended for passing messages along to a model for response generation where - you typically want to: - - 1. Cap the number of tokens sent in a single request (i.e., `token_limits`). - 2. 
Apply user input transformations (i.e., `transform_user`), if any.
-        3. Not apply assistant response transformations (i.e., `transform_assistant`)
-           since these are predominantly for display purposes (i.e., the model shouldn't
-           concern itself with how the responses are displayed).
-
-        Parameters
-        ----------
-        format
-            The message format to return. The default value of `MISSING` means
-            chat messages are returned as :class:`ChatMessage` objects (a dictionary
-            with `content` and `role` keys). Other supported formats include:
-
-            * `"anthropic"`: Anthropic message format.
-            * `"google"`: Google message (aka content) format.
-            * `"langchain"`: LangChain message format.
-            * `"openai"`: OpenAI message format.
-            * `"ollama"`: Ollama message format.
-        token_limits
-            Limit the conversation history based on token limits. If specified, only
-            the most recent messages that fit within the token limits are returned. This
-            is useful for avoiding "exceeded token limit" errors when sending messages
-            to the relevant model, while still providing the most recent context available.
-            A specified value must be a tuple of two integers. The first integer is the
-            maximum number of tokens that can be sent to the model in a single request.
-            The second integer is the amount of tokens to reserve for the model's response.
-            Note that token counts based on the `tokenizer` provided to the `Chat`
-            constructor.
-        transform_user
-            Whether to return user input messages with transformation applied. This only
-            matters if a `transform_user_input` was provided to the chat constructor.
-            The default value of `"all"` means all user input messages are transformed.
-            The value of `"last"` means only the last user input message is transformed.
-            The value of `"none"` means no user input messages are transformed.
-        transform_assistant
-            Whether to return assistant messages with transformation applied. This only
-            matters if an `transform_assistant_response` was provided to the chat
-            constructor.
+        Obtain chat messages that have been appended after initial load.

         Note
         ----
-        Messages are listed in the order they were added. As a result, when this method
-        is called in a `.on_user_submit()` callback (as it most often is), the last
-        message will be the most recent one submitted by the user.
+        Startup messages (i.e., those passed to the `.ui()` method) are not included in the
+        return value. Also, this method must be called in a reactive context, and will
+        invalidate when new messages are appended.

         Returns
         -------
-        tuple[ChatMessage, ...]
+        tuple[ChatMessageDict, ...]
             A tuple of chat messages.
""" - messages = self._messages() - - # Anthropic requires a user message first and no system messages - if format == "anthropic": - messages = self._trim_anthropic_messages(messages) - - if token_limits is not None: - messages = self._trim_messages(messages, token_limits, format) - - res: list[ChatMessageDict | ProviderMessage] = [] - for i, m in enumerate(messages): - transform = False - if m.role == "assistant": - transform = transform_assistant - elif m.role == "user": - transform = transform_user == "all" or ( - transform_user == "last" and i == len(messages) - 1 - ) - content_key = getattr( - m, "transform_key" if transform else "pre_transform_key" - ) - content = getattr(m, content_key) - chat_msg = ChatMessageDict(content=str(content), role=m.role) - if not isinstance(format, MISSING_TYPE): - chat_msg = as_provider_message(chat_msg, format) - res.append(chat_msg) - - return tuple(res) + return tuple( + ChatMessageDict(content=m.content, role=m.role) + for m in self._messages() + ) async def append_message( self, @@ -637,9 +398,6 @@ async def append_message( return msg = normalize_message(message) - msg = await self._transform_message(msg) - if msg is None: - return self._store_message(msg) await self._send_append_message( message=msg, @@ -764,21 +522,7 @@ async def _append_message_chunk( self._current_stream_message += msg.content try: - if self._needs_transform(msg): - # Transforming may change the meaning of msg.content to be a *replace* - # not *append*. So, update msg.content and the operation accordingly. - chunk_content = msg.content - msg.content = self._current_stream_message - operation = "replace" - msg = await self._transform_message( - msg, chunk=chunk, chunk_content=chunk_content - ) - # Act like nothing happened if transformed to None - if msg is None: - return - if chunk == "end": - self._store_message(msg) - elif chunk == "end": + if chunk == "end": # When `operation="append"`, msg.content is just a chunk, but we must # store the full message self._store_message( @@ -958,14 +702,11 @@ async def _flush_pending_messages(self): # Send a message to the UI async def _send_append_message( self, - message: TransformedMessage | ChatMessage, + message: ChatMessage, chunk: ChunkOption = False, operation: Literal["append", "replace"] = "append", icon: HTML | Tag | TagList | None = None, ): - if not isinstance(message, TransformedMessage): - message = TransformedMessage.from_chat_message(message) - if message.role == "system": # System messages are not displayed in the UI return @@ -981,7 +722,7 @@ async def _send_append_message( elif chunk == "end": chunk_type = "message_end" - content = message.content_client + content = message.content content_type = "html" if isinstance(content, HTML) else "markdown" # TODO: pass along dependencies for both content and icon (if any) @@ -1006,174 +747,12 @@ async def _send_append_message( # TODO: Joe said it's a good idea to yield here, but I'm not sure why? # await asyncio.sleep(0) - @overload - def transform_user_input( - self, fn: TransformUserInput | TransformUserInputAsync - ) -> None: ... - - @overload - def transform_user_input( - self, - ) -> Callable[[TransformUserInput | TransformUserInputAsync], None]: ... - - def transform_user_input( - self, fn: TransformUserInput | TransformUserInputAsync | None = None - ) -> None | Callable[[TransformUserInput | TransformUserInputAsync], None]: - """ - Transform user input. 
- - Use this method as a decorator on a function (`fn`) that transforms user input - before storing it in the chat messages returned by `.messages()`. This is - useful for implementing RAG workflows, like taking a URL and scraping it for - text before sending it to the model. - - Parameters - ---------- - fn - A function to transform user input before storing it in the chat - `.messages()`. If `fn` returns `None`, the user input is effectively - ignored, and `.on_user_submit()` callbacks are suspended until more input is - submitted. This behavior is often useful to catch and handle errors that - occur during transformation. In this case, the transform function should - append an error message to the chat (via `.append_message()`) to inform the - user of the error. - """ - - def _set_transform(fn: TransformUserInput | TransformUserInputAsync): - self._transform_user = _utils.wrap_async(fn) - - if fn is None: - return _set_transform - else: - return _set_transform(fn) - - @overload - def transform_assistant_response( - self, fn: TransformAssistantResponseFunction - ) -> None: ... - - @overload - def transform_assistant_response( - self, - ) -> Callable[[TransformAssistantResponseFunction], None]: ... - - def transform_assistant_response( - self, - fn: TransformAssistantResponseFunction | None = None, - ) -> None | Callable[[TransformAssistantResponseFunction], None]: - """ - Transform assistant responses. - - Use this method as a decorator on a function (`fn`) that transforms assistant - responses before displaying them in the chat. This is useful for post-processing - model responses before displaying them to the user. - - Parameters - ---------- - fn - A function that takes a string and returns either a string, - :class:`shiny.ui.HTML`, or `None`. If `fn` returns a string, it gets - interpreted and parsed as a markdown on the client (and the resulting HTML - is then sanitized). If `fn` returns :class:`shiny.ui.HTML`, it will be - displayed as-is. If `fn` returns `None`, the response is effectively ignored. - - Note - ---- - When doing an `.append_message_stream()`, `fn` gets called on every chunk of the - response (thus, it should be performant), and can optionally access more - information (i.e., arguments) about the stream. The 1st argument (required) - contains the accumulated content, the 2nd argument (optional) contains the - current chunk, and the 3rd argument (optional) is a boolean indicating whether - this chunk is the last one in the stream. 
- """ - - def _set_transform( - fn: TransformAssistantResponseFunction, - ): - nparams = len(inspect.signature(fn).parameters) - if nparams == 1: - fn = cast( - Union[ - TransformAssistantResponse, - TransformAssistantResponseAsync, - ], - fn, - ) - fn = _utils.wrap_async(fn) - - async def _transform_wrapper( - content: str, chunk: str, done: bool - ): - return await fn(content) - - self._transform_assistant = _transform_wrapper - - elif nparams == 3: - fn = cast( - Union[ - TransformAssistantResponseChunk, - TransformAssistantResponseChunkAsync, - ], - fn, - ) - self._transform_assistant = _utils.wrap_async(fn) - else: - raise Exception( - "A @transform_assistant_response function must take 1 or 3 arguments" - ) - - if fn is None: - return _set_transform - else: - return _set_transform(fn) - - async def _transform_message( - self, - message: ChatMessage, - chunk: ChunkOption = False, - chunk_content: str = "", - ) -> TransformedMessage | None: - res = TransformedMessage.from_chat_message(message) - - if message.role == "user" and self._transform_user is not None: - content = await self._transform_user(message.content) - elif ( - message.role == "assistant" - and self._transform_assistant is not None - ): - content = await self._transform_assistant( - message.content, - chunk_content, - chunk == "end" or chunk is False, - ) - else: - return res - - if content is None: - return None - - setattr(res, res.transform_key, content) - return res - - def _needs_transform(self, message: ChatMessage) -> bool: - if message.role == "user" and self._transform_user is not None: - return True - elif ( - message.role == "assistant" - and self._transform_assistant is not None - ): - return True - return False - # Just before storing, handle chunk msg type and calculate tokens def _store_message( self, - message: TransformedMessage | ChatMessage, + message: ChatMessage, index: int | None = None, ) -> None: - if not isinstance(message, TransformedMessage): - message = TransformedMessage.from_chat_message(message) - with reactive.isolate(): messages = self._messages() @@ -1184,115 +763,17 @@ def _store_message( messages.insert(index, message) self._messages.set(tuple(messages)) - if message.role == "user": - self._latest_user_input.set(message) return None - def _trim_messages( - self, - messages: tuple[TransformedMessage, ...], - token_limits: tuple[int, int], - format: MISSING_TYPE | ProviderMessageFormat, - ) -> tuple[TransformedMessage, ...]: - n_total, n_reserve = token_limits - if n_total <= n_reserve: - raise ValueError( - f"Invalid token limits: {token_limits}. The 1st value must be greater " - "than the 2nd value." - ) - - # Since don't trim system messages, 1st obtain their total token count - # (so we can determine how many non-system messages can fit) - n_system_tokens: int = 0 - n_system_messages: int = 0 - n_other_messages: int = 0 - token_counts: list[int] = [] - for m in messages: - count = self._get_token_count(m.content_server) - token_counts.append(count) - if m.role == "system": - n_system_tokens += count - n_system_messages += 1 - else: - n_other_messages += 1 - - remaining_non_system_tokens = n_total - n_reserve - n_system_tokens - - if remaining_non_system_tokens <= 0: - raise ValueError( - f"System messages exceed `.messages(token_limits={token_limits})`. " - "Consider increasing the 1st value of `token_limit` or setting it to " - "`token_limit=None` to disable token limits." 
- ) - - # Now, iterate through the messages in reverse order and appending - # until we run out of tokens - messages2: list[TransformedMessage] = [] - n_other_messages2: int = 0 - token_counts.reverse() - for i, m in enumerate(reversed(messages)): - if m.role == "system": - messages2.append(m) - continue - remaining_non_system_tokens -= token_counts[i] - if remaining_non_system_tokens >= 0: - messages2.append(m) - n_other_messages2 += 1 - - messages2.reverse() - - if len(messages2) == n_system_messages and n_other_messages2 > 0: - raise ValueError( - f"Only system messages fit within `.messages(token_limits={token_limits})`. " - "Consider increasing the 1st value of `token_limit` or setting it to " - "`token_limit=None` to disable token limits." - ) - - return tuple(messages2) - - def _trim_anthropic_messages( - self, - messages: tuple[TransformedMessage, ...], - ) -> tuple[TransformedMessage, ...]: - if any(m.role == "system" for m in messages): - raise ValueError( - "Anthropic requires a system prompt to be specified in it's `.create()` method " - "(not in the chat messages with `role: system`)." - ) - for i, m in enumerate(messages): - if m.role == "user": - return messages[i:] - - return () - - def _get_token_count( - self, - content: str, - ) -> int: - if self._tokenizer is None: - self._tokenizer = get_default_tokenizer() - - encoded = self._tokenizer.encode(content) - if isinstance(encoded, TokenizersEncoding): - return len(encoded.ids) - else: - return len(encoded) - - def user_input(self, transform: bool = False) -> str | None: + def user_input(self) -> str: """ Reactively read the user's message. - Parameters - ---------- - transform - Whether to apply the user input transformation function (if one was - provided). - Returns ------- - str | None - The user input message (before any transformation). + str + The user input message. Note ---- @@ -1303,12 +784,7 @@ def user_input(self, transform: bool = False) -> str | None: 2. Maintaining message state separately from `.messages()`. """ - msg = self._latest_user_input() - if msg is None: - return None - key = "content_server" if transform else "content_client" - val = getattr(msg, key) - return str(val) + return self._user_input() def _user_input(self) -> str: id = self.user_input_id @@ -1359,17 +835,6 @@ def update_user_input( self._session._send_message_sync({"custom": {"shinyChatMessage": msg}}) - def set_user_message(self, value: str): - """ - Deprecated. Use `update_user_input(value=value)` instead. - """ - - warn_deprecated( - "set_user_message() is deprecated. Use update_user_input(value=value) instead." - ) - - self.update_user_input(value=value) - async def clear_messages(self): """ Clear all chat messages. @@ -1506,11 +971,9 @@ async def _(url: str): if bookmark_on == "response": @reactive.effect - @reactive.event( - lambda: self.messages(format=MISSING), ignore_init=True - ) + @reactive.event(lambda: self.messages(), ignore_init=True) async def _(): - messages = self.messages(format=MISSING) + messages = self.messages() if len(messages) == 0: return @@ -1556,12 +1019,7 @@ def _on_bookmark_ui(state: BookmarkState): # This does NOT contain the `chat.ui(messages=)` values. 
# When restoring, the `chat.ui(messages=)` values will need to be kept # and the `ui.Chat(messages=)` values will need to be reset - state.values[resolved_bookmark_id_msgs_str] = self.messages( - format=MISSING - ) - - # Attempt to stop the initialization of the `ui.Chat(messages=)` messages - self._init_chat.destroy() + state.values[resolved_bookmark_id_msgs_str] = self.messages() @root_session.bookmark.on_restore async def _on_restore_ui(state: RestoreState): @@ -1574,8 +1032,6 @@ async def _on_restore_ui(state: RestoreState): # calling `self._init_chat.destroy()` above if resolved_bookmark_id_msgs_str not in state.values: - # If no messages to restore, display the `__init__(messages=)` messages - await self._append_init_messages() return msgs: list[Any] = state.values[resolved_bookmark_id_msgs_str] @@ -1603,7 +1059,7 @@ class ChatExpress(Chat): def ui( self, *, - messages: Optional[Sequence[str | ChatMessageDict]] = None, + messages: Optional[Sequence[TagChild | ChatMessageDict]] = None, placeholder: str = "Enter a message...", width: CssUnit = "min(680px, 100%)", height: CssUnit = "auto", diff --git a/pkg-py/src/shinychat/_chat_provider_types.py b/pkg-py/src/shinychat/_chat_provider_types.py deleted file mode 100644 index cb79ec4..0000000 --- a/pkg-py/src/shinychat/_chat_provider_types.py +++ /dev/null @@ -1,127 +0,0 @@ -import sys -from typing import TYPE_CHECKING, Literal, Union - -from ._chat_types import ChatMessageDict - -if TYPE_CHECKING: - from anthropic.types import MessageParam as AnthropicMessage - from langchain_core.messages import AIMessage, HumanMessage, SystemMessage - from ollama import Message as OllamaMessage - from openai.types.chat import ( - ChatCompletionAssistantMessageParam, - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, - ) - - if sys.version_info >= (3, 9): - import google.generativeai.types as gtypes # pyright: ignore[reportMissingTypeStubs] - - GoogleMessage = gtypes.ContentDict - else: - GoogleMessage = object - - LangChainMessage = Union[AIMessage, HumanMessage, SystemMessage] - OpenAIMessage = Union[ - ChatCompletionAssistantMessageParam, - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, - ] - - ProviderMessage = Union[ - AnthropicMessage, GoogleMessage, LangChainMessage, OpenAIMessage, OllamaMessage - ] -else: - AnthropicMessage = GoogleMessage = LangChainMessage = OpenAIMessage = ( - OllamaMessage - ) = ProviderMessage = object - -ProviderMessageFormat = Literal[ - "anthropic", - "google", - "langchain", - "openai", - "ollama", -] - - -# TODO: use a strategy pattern to allow others to register -# their own message formats -def as_provider_message( - message: ChatMessageDict, format: ProviderMessageFormat -) -> "ProviderMessage": - if format == "anthropic": - return as_anthropic_message(message) - if format == "google": - return as_google_message(message) - if format == "langchain": - return as_langchain_message(message) - if format == "openai": - return as_openai_message(message) - if format == "ollama": - return as_ollama_message(message) - raise ValueError(f"Unknown format: {format}") - - -def as_anthropic_message(message: ChatMessageDict) -> "AnthropicMessage": - from anthropic.types import MessageParam as AnthropicMessage - - if message["role"] == "system": - raise ValueError( - "Anthropic requires a system prompt to be specified in the `.create()` method" - ) - return AnthropicMessage(content=message["content"], role=message["role"]) - - -def as_google_message(message: ChatMessageDict) -> "GoogleMessage": 
- if sys.version_info < (3, 9): - raise ValueError("Google requires Python 3.9") - - import google.generativeai.types as gtypes # pyright: ignore[reportMissingTypeStubs] - - role = message["role"] - - if role == "system": - raise ValueError( - "Google requires a system prompt to be specified in the `GenerativeModel()` constructor." - ) - elif role == "assistant": - role = "model" - return gtypes.ContentDict(parts=[message["content"]], role=role) - - -def as_langchain_message(message: ChatMessageDict) -> "LangChainMessage": - from langchain_core.messages import AIMessage, HumanMessage, SystemMessage - - content = message["content"] - role = message["role"] - if role == "system": - return SystemMessage(content=content) - if role == "assistant": - return AIMessage(content=content) - if role == "user": - return HumanMessage(content=content) - raise ValueError(f"Unknown role: {message['role']}") - - -def as_openai_message(message: ChatMessageDict) -> "OpenAIMessage": - from openai.types.chat import ( - ChatCompletionAssistantMessageParam, - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, - ) - - content = message["content"] - role = message["role"] - if role == "system": - return ChatCompletionSystemMessageParam(content=content, role=role) - if role == "assistant": - return ChatCompletionAssistantMessageParam(content=content, role=role) - if role == "user": - return ChatCompletionUserMessageParam(content=content, role=role) - raise ValueError(f"Unknown role: {role}") - - -def as_ollama_message(message: ChatMessageDict) -> "OllamaMessage": - from ollama import Message as OllamaMessage - - return OllamaMessage(content=message["content"], role=message["role"]) diff --git a/pkg-py/src/shinychat/_chat_tokenizer.py b/pkg-py/src/shinychat/_chat_tokenizer.py deleted file mode 100644 index 3e0fc6f..0000000 --- a/pkg-py/src/shinychat/_chat_tokenizer.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from typing import ( - AbstractSet, - Any, - Collection, - Literal, - Protocol, - Union, - runtime_checkable, -) - - -# A duck type for tiktoken.Encoding -class TiktokenEncoding(Protocol): - name: str - - def encode( - self, - text: str, - *, - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - ) -> list[int]: ... - - -# A duck type for tokenizers.Encoding -@runtime_checkable -class TokenizersEncoding(Protocol): - @property - def ids(self) -> list[int]: ... - - -# A duck type for tokenizers.Tokenizer -class TokenizersTokenizer(Protocol): - def encode( - self, - sequence: Any, - pair: Any = None, - is_pretokenized: bool = False, - add_special_tokens: bool = True, - ) -> TokenizersEncoding: ... - - -TokenEncoding = Union[TiktokenEncoding, TokenizersTokenizer] - - -def get_default_tokenizer() -> TokenizersTokenizer: - try: - from tokenizers import Tokenizer - - return Tokenizer.from_pretrained("bert-base-cased") # type: ignore - except ImportError: - raise ImportError( - "Failed to download a default tokenizer. " - "A tokenizer is required to impose `token_limits` on `chat.messages()`. " - "To get a generic default tokenizer, install the `tokenizers` " - "package (`pip install tokenizers`). " - ) - except Exception as e: - raise RuntimeError( - "Failed to download a default tokenizer. " - "A tokenizer is required to impose `token_limits` on `chat.messages()`. 
" - "Try manually downloading a tokenizer using " - "`tokenizers.Tokenizer.from_pretrained()` and passing it to `ui.Chat()`." - f"Error: {e}" - ) from e diff --git a/pkg-py/src/shinychat/_chat_types.py b/pkg-py/src/shinychat/_chat_types.py index ea7dda4..b2d8c7c 100644 --- a/pkg-py/src/shinychat/_chat_types.py +++ b/pkg-py/src/shinychat/_chat_types.py @@ -1,9 +1,8 @@ from __future__ import annotations -from dataclasses import dataclass from typing import Literal, TypedDict -from htmltools import HTML, TagChild +from htmltools import TagChild from shiny.session import require_active_session from ._typing_extensions import NotRequired @@ -39,35 +38,6 @@ def __init__( self.html_deps = deps -# A message once transformed have been applied -@dataclass -class TransformedMessage: - content_client: str | HTML - content_server: str - role: Role - transform_key: Literal["content_client", "content_server"] - pre_transform_key: Literal["content_client", "content_server"] - html_deps: list[dict[str, str]] | None = None - - @classmethod - def from_chat_message(cls, message: ChatMessage) -> "TransformedMessage": - if message.role == "user": - transform_key = "content_server" - pre_transform_key = "content_client" - else: - transform_key = "content_client" - pre_transform_key = "content_server" - - return TransformedMessage( - content_client=message.content, - content_server=message.content, - role=message.role, - transform_key=transform_key, - pre_transform_key=pre_transform_key, - html_deps=message.html_deps, - ) - - # A message that can be sent to the client class ClientMessage(TypedDict): content: str diff --git a/pkg-py/tests/playwright/chat/basic/app.py b/pkg-py/tests/playwright/chat/basic/app.py index 2fd7caa..2c6e24d 100644 --- a/pkg-py/tests/playwright/chat/basic/app.py +++ b/pkg-py/tests/playwright/chat/basic/app.py @@ -5,15 +5,10 @@ ui.page_opts(title="Hello Chat") # Create a chat instance, with an initial message -chat = Chat( - id="chat", - messages=[ - {"content": "Hello! How can I help you today?", "role": "assistant"}, - ], -) +chat = Chat(id="chat") # Display the chat -chat.ui() +chat.ui(messages=["Hello! How can I help you today?"]) # Define a callback to run when the user submits a message diff --git a/pkg-py/tests/playwright/chat/basic/test_chat_basic.py b/pkg-py/tests/playwright/chat/basic/test_chat_basic.py index ae19478..d811097 100644 --- a/pkg-py/tests/playwright/chat/basic/test_chat_basic.py +++ b/pkg-py/tests/playwright/chat/basic/test_chat_basic.py @@ -42,7 +42,6 @@ def test_validate_chat_basic(page: Page, local_app: ShinyAppProc) -> None: message_state = controller.OutputCode(page, "message_state") message_state_expected = tuple( [ - {"content": initial_message, "role": "assistant"}, {"content": f"\n{user_message}", "role": "user"}, {"content": f"You said: \n{user_message}", "role": "assistant"}, {"content": f"{user_message2}", "role": "user"}, diff --git a/pkg-py/tests/playwright/chat/icon/app.py b/pkg-py/tests/playwright/chat/icon/app.py index fd2adfa..927d3c1 100644 --- a/pkg-py/tests/playwright/chat/icon/app.py +++ b/pkg-py/tests/playwright/chat/icon/app.py @@ -11,19 +11,14 @@ with ui.layout_columns(): # Default Bot --------------------------------------------------------------------- - chat_default = Chat( - id="chat_default", - messages=[ - { - "content": "Hello! I'm Default Bot. 
How can I help you today?", - "role": "assistant", - }, - ], - ) + chat_default = Chat(id="chat_default") with ui.div(): ui.h2("Default Bot") - chat_default.ui(icon_assistant=None) + chat_default.ui( + messages=["Hello! I'm Default Bot. How can I help you today?"], + icon_assistant=None, + ) @chat_default.on_user_submit async def handle_user_input_default(user_input: str): diff --git a/pkg-py/tests/playwright/chat/input-suggestion/app.py b/pkg-py/tests/playwright/chat/input-suggestion/app.py index c9207cb..254f3c1 100644 --- a/pkg-py/tests/playwright/chat/input-suggestion/app.py +++ b/pkg-py/tests/playwright/chat/input-suggestion/app.py @@ -12,9 +12,9 @@ And this suggestion will also auto-submit.
""" -chat = Chat("chat", messages=[suggestion2]) +chat = Chat("chat") -chat.ui(messages=[suggestions1]) +chat.ui(messages=[suggestions1, suggestion2]) @chat.on_user_submit diff --git a/pkg-py/tests/playwright/chat/shiny_input/app.py b/pkg-py/tests/playwright/chat/shiny_input/app.py index e7d3c0b..0dfed9a 100644 --- a/pkg-py/tests/playwright/chat/shiny_input/app.py +++ b/pkg-py/tests/playwright/chat/shiny_input/app.py @@ -19,11 +19,11 @@ ), ) -chat = Chat( - id="chat", +chat = Chat(id="chat") +chat.ui( + class_="mb-5", messages=[welcome], ) -chat.ui(class_="mb-5") @reactive.effect diff --git a/pkg-py/tests/playwright/chat/shiny_output/app.py b/pkg-py/tests/playwright/chat/shiny_output/app.py index 1d9853e..0f7f6d6 100644 --- a/pkg-py/tests/playwright/chat/shiny_output/app.py +++ b/pkg-py/tests/playwright/chat/shiny_output/app.py @@ -18,12 +18,9 @@ def map(): return ipyl.Map(center=(52, 10), zoom=8) -chat = ui.Chat( - id="chat", - messages=[map_ui], -) +chat = ui.Chat(id="chat") -chat.ui() +chat.ui(messages=[map_ui]) with ui.hold() as df_1: diff --git a/pkg-py/tests/playwright/chat/transform/app.py b/pkg-py/tests/playwright/chat/transform/app.py deleted file mode 100644 index 50a695b..0000000 --- a/pkg-py/tests/playwright/chat/transform/app.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Union - -from shiny.express import render, ui -from shinychat.express import Chat - -# Set some Shiny page options -ui.page_opts(title="Hello Chat") - -# Create a chat instance, with an initial message -chat = Chat(id="chat") - -# Display the chat -chat.ui() - - -@chat.transform_user_input -async def capitalize(input: str) -> Union[str, None]: - if input == "return None": - return None - elif input == "return custom message": - await chat.append_message("Custom message") - return None - else: - return input.upper() - - -@chat.on_user_submit -async def _(): - user = chat.user_input(transform=True) - await chat.append_message(f"Transformed input: {user}") - - -"chat.messages():" - - -@render.code -def message_state(): - return str(chat.messages()) - - -"chat.messages(transform_user='none'):" - - -@render.code -def message_state2(): - return str(chat.messages(transform_user="none")) diff --git a/pkg-py/tests/playwright/chat/transform/test_chat_transform.py b/pkg-py/tests/playwright/chat/transform/test_chat_transform.py deleted file mode 100644 index f8c5b56..0000000 --- a/pkg-py/tests/playwright/chat/transform/test_chat_transform.py +++ /dev/null @@ -1,64 +0,0 @@ -from playwright.sync_api import Page, expect -from shiny.playwright import controller -from shiny.run import ShinyAppProc -from shinychat.playwright import ChatController - - -def test_validate_chat_transform(page: Page, local_app: ShinyAppProc) -> None: - page.goto(local_app.url) - - chat = ChatController(page, "chat") - message_state = controller.OutputCode(page, "message_state") - message_state2 = controller.OutputCode(page, "message_state2") - - # Wait for app to load - message_state.expect_value("()", timeout=30 * 1000) - - expect(chat.loc).to_be_visible(timeout=30 * 1000) - expect(chat.loc_input_button).to_be_disabled() - - user_msg = "hello" - chat.set_user_input(user_msg) - chat.send_user_input() - chat.expect_latest_message( - f"Transformed input: {user_msg.upper()}", - timeout=30 * 1000, - ) - - user_msg2 = "return None" - chat.set_user_input(user_msg2) - chat.send_user_input() - chat.expect_latest_message("return None") - - user_msg3 = "return custom message" - chat.set_user_input(user_msg3) - chat.send_user_input() - 
chat.expect_latest_message("Custom message") - - message_state_expected = tuple( - [ - {"content": user_msg.upper(), "role": "user"}, - { - "content": f"Transformed input: {user_msg.upper()}", - "role": "assistant", - }, - {"content": "return None", "role": "user"}, - {"content": "return custom message", "role": "user"}, - {"content": "Custom message", "role": "assistant"}, - ] - ) - message_state.expect_value(str(message_state_expected)) - - message_state_expected2 = tuple( - [ - {"content": user_msg, "role": "user"}, - { - "content": f"Transformed input: {user_msg.upper()}", - "role": "assistant", - }, - {"content": "return None", "role": "user"}, - {"content": "return custom message", "role": "user"}, - {"content": "Custom message", "role": "assistant"}, - ] - ) - message_state2.expect_value(str(message_state_expected2)) diff --git a/pkg-py/tests/playwright/chat/transform_assistant/app.py b/pkg-py/tests/playwright/chat/transform_assistant/app.py deleted file mode 100644 index 8f78b86..0000000 --- a/pkg-py/tests/playwright/chat/transform_assistant/app.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Union - -from shiny.express import render, ui -from shinychat.express import Chat - -# Set some Shiny page options -ui.page_opts(title="Hello Chat") - -# Create a chat instance, with an initial message -chat = Chat(id="chat") - -# Display the chat -chat.ui() - - -# TODO: test with append_message_stream() as well -@chat.transform_assistant_response -def transform(content: str) -> Union[str, ui.HTML]: - if content == "return HTML": - return ui.HTML(f"Transformed response: {content}") - else: - return f"Transformed response: `{content}`" - - -@chat.on_user_submit -async def _(): - user = chat.user_input() - await chat.append_message(user) - - -"chat.messages():" - - -@render.code -def message_state(): - return str(chat.messages()) - - -"chat.messages(transform_assistant=True):" - - -@render.code -def message_state2(): - return str(chat.messages(transform_assistant=True)) diff --git a/pkg-py/tests/playwright/chat/transform_assistant/test_chat_transform_assistant.py b/pkg-py/tests/playwright/chat/transform_assistant/test_chat_transform_assistant.py deleted file mode 100644 index e80e993..0000000 --- a/pkg-py/tests/playwright/chat/transform_assistant/test_chat_transform_assistant.py +++ /dev/null @@ -1,55 +0,0 @@ -from playwright.sync_api import Page, expect -from shiny.playwright import controller -from shiny.run import ShinyAppProc -from shinychat.playwright import ChatController - - -def test_validate_chat_transform_assistant( - page: Page, local_app: ShinyAppProc -) -> None: - page.goto(local_app.url) - - chat = ChatController(page, "chat") - message_state = controller.OutputCode(page, "message_state") - message_state2 = controller.OutputCode(page, "message_state2") - - # Wait for app to load - message_state.expect_value("()", timeout=30 * 1000) - - expect(chat.loc).to_be_visible(timeout=30 * 1000) - expect(chat.loc_input_button).to_be_disabled() - - user_msg = "hello" - chat.set_user_input(user_msg) - chat.send_user_input() - code = chat.loc_latest_message.locator("code") - expect(code).to_have_text("hello", timeout=30 * 1000) - - user_msg2 = "return HTML" - chat.set_user_input(user_msg2) - chat.send_user_input() - bold = chat.loc_latest_message.locator("b") - expect(bold).to_have_text("Transformed response") - - message_state_expected = tuple( - [ - {"content": "hello", "role": "user"}, - {"content": "hello", "role": "assistant"}, - {"content": "return HTML", "role": "user"}, - 
{"content": "return HTML", "role": "assistant"}, - ] - ) - message_state.expect_value(str(message_state_expected)) - - message_state_expected2 = tuple( - [ - {"content": "hello", "role": "user"}, - {"content": "Transformed response: `hello`", "role": "assistant"}, - {"content": "return HTML", "role": "user"}, - { - "content": "Transformed response: return HTML", - "role": "assistant", - }, - ] - ) - message_state2.expect_value(str(message_state_expected2)) diff --git a/pkg-py/tests/playwright/chat/transform_assistant_stream/app.py b/pkg-py/tests/playwright/chat/transform_assistant_stream/app.py deleted file mode 100644 index d24e36c..0000000 --- a/pkg-py/tests/playwright/chat/transform_assistant_stream/app.py +++ /dev/null @@ -1,27 +0,0 @@ -import shiny.express # noqa: F401 -from shiny import render -from shinychat.express import Chat - -chat = Chat(id="chat") -chat.ui() - - -@chat.transform_assistant_response -def transform(content: str, chunk: str, done: bool): - if done: - return content + "...DONE!" - else: - return content - - -@chat.on_user_submit -async def _(): - await chat.append_message_stream(("Simple ", "response")) - - -"Message state:" - - -@render.code -def message_state(): - return str(chat.messages()) diff --git a/pkg-py/tests/playwright/chat/transform_assistant_stream/test_chat_transform_assistant_stream.py b/pkg-py/tests/playwright/chat/transform_assistant_stream/test_chat_transform_assistant_stream.py deleted file mode 100644 index c5a3c82..0000000 --- a/pkg-py/tests/playwright/chat/transform_assistant_stream/test_chat_transform_assistant_stream.py +++ /dev/null @@ -1,31 +0,0 @@ -from playwright.sync_api import Page, expect -from shiny.playwright import controller -from shiny.run import ShinyAppProc -from shinychat.playwright import ChatController - - -def test_validate_chat_transform_assistant( - page: Page, local_app: ShinyAppProc -) -> None: - page.goto(local_app.url) - - chat = ChatController(page, "chat") - message_state = controller.OutputCode(page, "message_state") - - # Wait for app to load - message_state.expect_value("()", timeout=30 * 1000) - - expect(chat.loc).to_be_visible(timeout=30 * 1000) - expect(chat.loc_input_button).to_be_disabled() - - chat.set_user_input("foo") - chat.send_user_input() - chat.expect_latest_message("Simple response...DONE!", timeout=30 * 1000) - - message_state_expected = tuple( - [ - {"content": "foo", "role": "user"}, - {"content": "Simple response", "role": "assistant"}, - ] - ) - message_state.expect_value(str(message_state_expected))