Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Native Chat Handlers Overridable via Entry Points #1249

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 79 additions & 41 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import partial
from typing import Dict

import jupyter_ydoc # must be imported before YChat
import traitlets
from dask.distributed import Client as DaskClient
from importlib_metadata import entry_points
Expand All @@ -20,14 +21,7 @@
from tornado.web import StaticFileHandler
from traitlets import Integer, List, Unicode

from .chat_handlers import (
AskChatHandler,
BaseChatHandler,
DefaultChatHandler,
GenerateChatHandler,
HelpChatHandler,
LearnChatHandler,
)
from .chat_handlers.base import BaseChatHandler
from .completions.handlers import DefaultInlineCompletionHandler
from .config_manager import ConfigManager
from .constants import BOT
Expand Down Expand Up @@ -460,37 +454,34 @@ def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
assert self.serverapp

eps = entry_points()
chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
all_chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")

# Override native chat handlers if duplicates are present
sorted_eps = sorted(
all_chat_handler_eps, key=lambda ep: ep.dist.name != "jupyter_ai"
)
seen = {}
for ep in sorted_eps:
seen[ep.name] = ep
chat_handler_eps = list(seen.values())

Comment on lines 456 to +467
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is important to avoid very long method definitions for readability. Can this block to moved to a new function inside new module, e.g. get_chat_handler_eps() under jupyter_ai/entry_points/utils.py?

chat_handlers: Dict[str, BaseChatHandler] = {}
llm_chat_memory = YChatHistory(ychat, k=self.default_max_chat_history)

chat_handler_kwargs = {
"log": self.log,
"config_manager": self.settings["jai_config_manager"],
"model_parameters": self.settings["model_parameters"],
"config_manager": self.settings.get("jai_config_manager"),
"model_parameters": self.settings.get("model_parameters"),
"llm_chat_memory": llm_chat_memory,
"root_dir": self.serverapp.root_dir,
"dask_client_future": self.settings["dask_client_future"],
"dask_client_future": self.settings.get("dask_client_future"),
"preferred_dir": self.serverapp.contents_manager.preferred_dir,
"help_message_template": self.help_message_template,
"chat_handlers": chat_handlers,
"context_providers": self.settings["jai_context_providers"],
"message_interrupted": self.settings["jai_message_interrupted"],
"context_providers": self.settings.get("jai_context_providers"),
"message_interrupted": self.settings.get("jai_message_interrupted"),
"ychat": ychat,
}
default_chat_handler = DefaultChatHandler(**chat_handler_kwargs)
generate_chat_handler = GenerateChatHandler(
**chat_handler_kwargs,
log_dir=self.error_logs_dir,
)
learn_chat_handler = LearnChatHandler(**chat_handler_kwargs)
retriever = Retriever(learn_chat_handler=learn_chat_handler)
ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever)

chat_handlers["default"] = default_chat_handler
chat_handlers["/ask"] = ask_chat_handler
chat_handlers["/generate"] = generate_chat_handler
chat_handlers["/learn"] = learn_chat_handler

slash_command_pattern = r"^[a-zA-Z0-9_]+$"
for chat_handler_ep in chat_handler_eps:
Expand All @@ -503,21 +494,34 @@ def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
)
continue

# Skip disabled entrypoints
ep_disabled = getattr(chat_handler, "disabled", False)
if ep_disabled:
self.log.warn(
f"Skipping registration of chat handler `{chat_handler_ep.name}` as it is explicitly disabled."
)
continue

Comment on lines +497 to +504
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the validation checks should also be moved to a new function in a new module, e.g. verify_chat_handler_class() in jupyter_ai/chat_handlers/utils.py.

if chat_handler.routing_type.routing_method == "slash_command":
# Each slash ID must be used only once.
# Slash IDs may contain only alphanumerics and underscores.
slash_id = chat_handler.routing_type.slash_id
# Set default slash_id if it's the default chat handler
slash_id = (
"default"
if chat_handler.id == "default"
else chat_handler.routing_type.slash_id
)

if slash_id is None:
if not slash_id:
self.log.error(
f"Handler `{chat_handler_ep.name}` has an invalid slash command "
+ f"`None`; only the default chat handler may use this"
)
continue

# Validate slash ID (/^[A-Za-z0-9_]+$/)
# Validate the slash command name
if re.match(slash_command_pattern, slash_id):
command_name = f"/{slash_id}"
command_name = (
"default" if slash_id == "default" else f"/{slash_id}"
)
else:
self.log.error(
f"Handler `{chat_handler_ep.name}` has an invalid slash command "
Expand All @@ -527,20 +531,54 @@ def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
continue

if command_name in chat_handlers:
self.log.error(
f"Unable to register chat handler `{chat_handler.id}` because command `{command_name}` already has a handler"
self.log.warn(
f"Overriding existing handler `{command_name}` with `{chat_handler.id}`."
)
Comment on lines 533 to 536
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem necessary anymore because we are filtering out duplicates when getting the list of chat handler entry points (on line 457), correct?

continue

# The entry point is a class; we need to instantiate the class to send messages to it
chat_handlers[command_name] = chat_handler(**chat_handler_kwargs)
# Special handling for native `/ask`
if (
command_name == "/ask"
and chat_handler.__module__ == "jupyter_ai.chat_handlers.ask"
):
try:
learn_ep = next(
(ep for ep in chat_handler_eps if ep.name == "learn"), None
)
if not learn_ep:
self.log.error(
"No entry point found for 'learn' handler; skipping '/ask' registration."
)
continue

LearnChatHandler = learn_ep.load()
learn_handler = LearnChatHandler(**chat_handler_kwargs)

retriever = Retriever(learn_chat_handler=learn_handler)

chat_handlers[command_name] = chat_handler(
**chat_handler_kwargs, retriever=retriever
)
except Exception as e:
self.log.error(f"Failed to load 'learn' handler for '/ask': {e}")
continue

# Special handling for `/generate`
elif (
command_name == "/generate"
and chat_handler.__module__ == "jupyter_ai.chat_handlers.generate"
):
chat_handlers[command_name] = chat_handler(
**chat_handler_kwargs, log_dir=self.error_logs_dir
)

# General case
else:
chat_handlers[command_name] = chat_handler(**chat_handler_kwargs)

self.log.info(
f"Registered chat handler `{chat_handler.id}` with command `{command_name}`."
)

# Make help always appear as the last command
chat_handlers["/help"] = HelpChatHandler(**chat_handler_kwargs)

return chat_handlers

def _init_context_providers(self):
Expand Down
7 changes: 7 additions & 0 deletions packages/jupyter-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ dynamic = ["version", "description", "authors", "urls", "keywords"]
[project.entry-points."jupyter_ai.default_tasks"]
core_default_tasks = "jupyter_ai:tasks"

[project.entry-points."jupyter_ai.chat_handlers"]
default = "jupyter_ai.chat_handlers.default:DefaultChatHandler"
ask = "jupyter_ai.chat_handlers.ask:AskChatHandler"
generate = "jupyter_ai.chat_handlers.generate:GenerateChatHandler"
learn = "jupyter_ai.chat_handlers.learn:LearnChatHandler"
help = "jupyter_ai.chat_handlers.help:HelpChatHandler"

Comment on lines +46 to +52
Copy link
Member

@dlqqq dlqqq Feb 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that providing our own entry points to "jupyter_ai.chat_handlers" may not allow other extensions to override our entry points. Suppose another server extension wants to override Jupyter AI's /help command. Then their pyproject.toml file contains:

[project.entry-points."jupyter_ai.chat_handlers"]
help = "example_extension.chat_handlers:CustomHelpChatHandler"

In the current implementation, the chat handler used for /help is defined by the order the entry points are loaded in, because both are setting the same key in the chat_handlers dictionary.

For example, if Jupyter AI gets loaded first, then example_extension overrides /help with their custom implementation (which is desired). However, if example_extension is loaded first, then the custom /help is overridden by the Jupyter AI's default /help (not desired).

  1. Can you find documentation on what order entry points are loaded in if multiple packages are providing the same entry point?
  2. Can you make sure that other extensions can always override our entry points?

Feel free to add an example in the packages/jupyter-ai-test package, which is a local Python package used to verify that entry points work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point. As per https://packaging.python.org/en/latest/specifications/entry-points/#entry-points

Within a distribution, entry point names should be unique. If different distributions provide the same name, the consumer decides how to handle such conflicts.

My understanding is that we need to loop the output of entry_points() (chat_handler_eps) twice or thrice:

  • sort so that jupyter-ai is first
  • reduce so that duplicates later in the list win
  • only then load the entry points

I think we would ideally have a way to specify an entry point which will skip adding a chat handler (say to disable /ask if downstream has its own RAG which is not compatible with Retriver class). This could already be done by providing one that will error out but this is of course suboptimal.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, this will be useful. Should we let extension authors add a 'disabled: true' field in the entry point classes to skip adding the entry point?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@krassowski @Darshan808 Allowing other server extensions to optionally disable JAI's native slash commands seems reasonable. However, we should add contributor documentation that indicates you can disable a slash command globally by providing a mostly-empty chat handler with disabled = True:

class AskChatHandler:
    disabled = True

@krassowski BTW, we are exploring the idea of performing RAG automatically in JAI, removing the need for /ask. I'm going to open more issues for this to share our ideas with you & the broader public.

[project.optional-dependencies]
test = [
"jupyter-server[test]>=1.6,<3",
Expand Down
Loading