Skip to content

WIP: Tool calling UI #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg-r/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export(chat_enable_bookmarking)
export(chat_mod_server)
export(chat_mod_ui)
export(chat_ui)
export(contents_shinychat)
export(markdown_stream)
export(output_markdown_stream)
export(update_chat_user_input)
Expand Down
10 changes: 8 additions & 2 deletions pkg-r/R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ chat_deps <- function() {
src = "lib/shiny",
script = list(
list(src = "chat/chat.js", type = "module"),
list(src = "markdown-stream/markdown-stream.js", type = "module")
list(src = "markdown-stream/markdown-stream.js", type = "module"),
list(src = "tools/tool-request.js")
),
stylesheet = c(
"chat/chat.css",
"markdown-stream/markdown-stream.css"
"markdown-stream/markdown-stream.css",
"tools/tool-request.css"
)
)
}
Expand Down Expand Up @@ -474,6 +476,10 @@ rlang::on_load(

res$add(msg)

if (S7::S7_inherits(msg, ellmer::Content)) {
msg <- contents_shinychat(msg)
}

chat_append_message(
id,
list(role = role, content = msg),
Expand Down
36 changes: 23 additions & 13 deletions pkg-r/R/chat_app.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,7 @@ chat_mod_ui <- function(id, ..., client = NULL, messages = NULL) {
if (!is.null(client)) {
check_ellmer_chat(client)

client_msgs <- map(client$get_turns(), function(turn) {
content <- ellmer::contents_markdown(turn)
if (is.null(content) || identical(content, "")) {
return(NULL)
}
list(role = turn@role, content = content)
})
client_msgs <- compact(client_msgs)
client_msgs <- contents_shinychat(client)

if (length(client_msgs)) {
if (!is.null(messages)) {
Expand All @@ -134,7 +127,7 @@ chat_mod_ui <- function(id, ..., client = NULL, messages = NULL) {
}
}

shinychat::chat_ui(
chat_ui(
shiny::NS(id, "chat"),
messages = messages,
...
Expand All @@ -148,12 +141,29 @@ chat_mod_server <- function(id, client) {

append_stream_task <- shiny::ExtendedTask$new(
function(client, ui_id, user_input) {
promises::then(
promises::promise_resolve(client$stream_async(user_input)),
function(stream) {
chat_append(ui_id, stream)
clear_on_tool_result <- client$on_tool_result(function(result) {
session <- shiny::getDefaultReactiveDomain()
if (is.null(session)) {
return()
}
session$sendCustomMessage(
"shinychat-hide-tool-request",
result@request@id
)
})

stream <- client$stream_async(
user_input,
stream = "content"
)

p <- promises::promise_resolve(stream)
p <- promises::then(p, function(stream) {
chat_append(ui_id, stream)
})
promises::finally(p, function() {
clear_on_tool_result()
})
}
)

Expand Down
221 changes: 221 additions & 0 deletions pkg-r/R/contents_shinychat.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
#' Format ellmer content for shinychat
#'
#' @param content An [`ellmer::Content`] object.
#'
#' @return Returns text or HTML formatted for use in `chat_ui()`.
#'
#' @export
contents_shinychat <- S7::new_generic("contents_shinychat", "content")

S7::method(contents_shinychat, ellmer::Content) <- function(content) {
# Fall back to html or markdown
html <- ellmer::contents_html(content)
if (!is.null(html)) shiny::HTML(html) else ellmer::contents_markdown(content)
}

S7::method(contents_shinychat, ellmer::ContentText) <- function(content) {
content@text
}

S7::method(contents_shinychat, ellmer::ContentToolRequest) <- function(
content
) {
call <- format(content, show = "call")
if (length(call) > 1) {
call <- sprintf("%s()", content@name)
}
shiny::HTML(sprintf(
'\n\n<p class="shiny-tool-request" data-tool-call-id="%s">Running <code>%s</code></p>\n\n',
content@id,
call
))
}

S7::method(contents_shinychat, ellmer::ContentToolResult) <- function(
content
) {
deps <- NULL

tool_result_display <- function(content) {
display <- content@extra$display
if (is.null(display)) {
return(pre_code(content@value))
}

html <- NULL
md <- NULL
text <- NULL

if (
is.list(display) &&
!inherits(display, c("shiny.tag.list", "shiny.tag"))
) {
if (
!some(
c("text", "markdown", "html"),
\(x) x %in% names(display)
)
) {
stop(
"ContentToolResult@extra$display must be a list with at least one of the following elements: text, markdown, html."
)
}
html <- display$html
md <- display$markdown
text <- display$text
} else {
if (inherits(display, "html")) {
html <- display
} else {
md <- display
}
}

if (!is.null(html)) {
deps <<- htmltools::findDependencies(html)
return(format(html))
}

if (!is.null(markdown)) {
md <- paste(md, collapse = "\n")
md <- paste0("\n\n", md, "\n\n")
return(md)
}

return(text %||% pre_code(contents$value))
}

if (isFALSE(content@extra$display_tool_request)) {
res <- tool_result_display(content)
if (!is.null(deps)) {
res <- htmltools::attachDependencies(res, deps)
}
return(res)
}

if (!is.null(content@error)) {
class <- "shiny-tool-result failed"
summary_text <- "Failed to call"
tool_result <- sprintf(
"<strong>Error</strong>%s",
pre_code(strip_ansi(content@error))
)
} else {
class <- "shiny-tool-result"
summary_text <- "Result from"
tool_result <- sprintf(
'<strong>Tool Result</strong>%s',
tool_result_display(content)
)
}

tool_name <- "unknown tool"
tool <- content@request@tool
if (!is.null(tool)) {
tool_name <- tool@name
if (!is.null(tool@annotations$title)) {
tool_name <- tool@annotations$title
summary_text <- ""
}
}

intent <- ""
if (!is.null(content@request@arguments$intent)) {
intent <- sprintf(
' | <span class="intent">%s</span>',
content@request@arguments$intent
)
}

details_open <- sprintf(
'<details class="%s" id="%s">',
class,
content@request@id
)

summary <- sprintf(
'<summary>%s <span class="function-name">%s</span>%s</summary>',
summary_text,
tool_name,
intent
)

tool_call <- sprintf(
'<strong>Tool Call</strong>%s',
pre_code(format(content@request, show = "call"))
)

body <- sprintf(
'<p>%s</p><p>%s</p></details>\n\n',
tool_call,
tool_result
)

res <- shiny::HTML(paste0(details_open, summary, body))
if (!is.null(deps)) {
res <- htmltools::attachDependencies(res, deps)
}
return(res)
}

S7::method(contents_shinychat, ellmer::Turn) <- function(content) {
lapply(content@contents, contents_shinychat)
}

S7::method(contents_shinychat, S7::new_S3_class(c("Chat", "R6"))) <- function(
content
) {
# Consolidate tool calls into assistant turns. This currently assumes that
# tool calls are always returned in user turns that have at least one
# proceeding assistant turn.
turns <- map(content$get_turns(), function(turn) {
if (
all(map_lgl(turn@contents, S7::S7_inherits, ellmer::ContentToolResult))
) {
turn@role <- "assistant"
}
is_tool_request <- map_lgl(
turn@contents,
S7::S7_inherits,
ellmer::ContentToolRequest
)
turn@contents <- turn@contents[!is_tool_request]
turn
})
turns <- reduce(turns, .init = list(), function(turns, turn) {
if (length(turns) == 0) {
return(list(turn))
}

# consolidate turns with adjacent roles
last_turn <- turns[[length(turns)]]
if (identical(last_turn@role, turn@role)) {
turns[[length(turns)]]@contents <- c(last_turn@contents, turn@contents)
return(turns)
}

c(turns, list(turn))
})

messages <- map(turns, function(turn) {
content <- compact(contents_shinychat(turn))
if (is.null(content) || identical(content, "")) {
return(NULL)
}
if (every(content, is.character)) {
# TODO: Fix chat_ui() to handle lists of strings
content <- paste(unlist(content), collapse = "\n\n")
}
list(role = turn@role, content = content)
})

compact(messages)
}


pre_code <- function(x) {
x <- gsub("`", "&#96;", x, fixed = TRUE)
x <- gsub("<", "&lt;", x, fixed = TRUE)
x <- gsub(">", "&gt;", x, fixed = TRUE)
sprintf("<pre><code>%s</code></pre>", paste(x, collapse = "\n"))
}
1 change: 1 addition & 0 deletions pkg-r/R/shinychat-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ NULL
ignore_unused_imports <- function() {
jsonlite::fromJSON
fastmap::fastqueue
ellmer::contents_html
}

release_bullets <- function() {
Expand Down
Loading