Commit

Accept messages during construction; add messages and system_prompt accessors

jcheng5 authored Sep 11, 2024
1 parent df48d31 commit 7628a0b
Showing 9 changed files with 283 additions and 54 deletions.
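
In brief: a chat can now be constructed from an existing message list, and the accumulated messages and system prompt can be read back. A rough usage sketch against the new public API (the prompt text and arguments are illustrative, not from this diff):

  # Start from scratch, with a system prompt
  chat <- new_chat_openai(system_prompt = "Be terse.")
  chat$chat("What's 2 + 2?")

  # Save the conversation and resume it in a new chat object
  saved <- chat$messages(include_system_prompt = TRUE)
  chat2 <- new_chat_openai(messages = saved)

  chat2$system_prompt   # new active binding; NULL if no system message
  chat2$messages()      # messages, system prompt excluded by default
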
3 changes: 2 additions & 1 deletion R/cat-wordwrap.R
@@ -22,12 +22,13 @@ cat_word_wrap_impl <- coro::generator(function(con = stdout()) {
     cat(..., file = con, sep = "")
   }
 
-  console_width <- cli::console_width()
 
   pos_cursor <- 1
   buffer <- ""
 
   while (TRUE) {
     input <- coro::yield()
+    console_width <- as.integer(Sys.getenv("COLUMNS", getOption("width", 80))) * 0.9
+
     input <- paste0(buffer, input)
     buffer <- ""
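
One behavioral note on this change: the width is now computed on every yield, inside the loop, so the wrap point can track terminal resizes mid-stream (assuming the shell keeps COLUMNS exported and up to date). For example:

  # With COLUMNS=100: as.integer("100") * 0.9 == 90 columns before wrapping;
  # with COLUMNS unset it falls back to getOption("width", 80) * 0.9 == 72.
  as.integer(Sys.getenv("COLUMNS", getOption("width", 80))) * 0.9
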
171 changes: 140 additions & 31 deletions R/chat.R
@@ -9,11 +9,21 @@ NULL
 #' (aka an R function), it also takes care of the tool loop.
 #'
 #' @param system_prompt A system prompt to set the behavior of the assistant.
-#' @param base_url The base URL to the endpoint; the default uses ChatGPT.
+#' @param messages A list of messages to start the chat with (i.e., continuing a
+#'   previous conversation). If not provided, the conversation begins from
+#'   scratch. Do not provide non-`NULL` values for both `messages` and
+#'   `system_prompt`.
+#'
+#'   Each message in the list should be a named list with at least `role`
+#'   (usually `system`, `user`, or `assistant`, but `tool` is also possible).
+#'   Normally there is also a `content` field, which is a string.
+#' @param base_url The base URL to the endpoint; the default uses OpenAI.
 #' @param api_key The API key to use for authentication. You generally should
 #'   not supply this directly, but instead set the `OPENAI_API_KEY` environment
 #'   variable.
-#' @param model The model to use for the chat; defaults to GPT-4o mini.
+#' @param model The model to use for the chat; set to `NULL` (the default) to
+#'   use a reasonable model, currently `gpt-4o-mini`. We strongly recommend
+#'   explicitly choosing a model for all but the most casual use.
 #' @param echo If `TRUE`, the `chat()` method streams the response to stdout by
 #'   default. (Note that this has no effect on the `stream()`, `chat_async()`,
 #'   and `stream_async()` methods.)
@@ -49,50 +59,117 @@ NULL
 #'   Briefly explain your work.
 #' ")
 new_chat_openai <- function(system_prompt = NULL,
+                            messages = NULL,
                             base_url = "https://api.openai.com/v1",
                             api_key = openai_key(),
-                            model = "gpt-4o-mini",
+                            model = NULL,
                             echo = FALSE) {
   check_string(system_prompt, allow_null = TRUE)
-  check_string(base_url)
+  check_openai_conversation(messages, allow_null = TRUE)
+  check_string(base_url)
   check_string(api_key)
-  check_string(model)
+  check_string(model, allow_null = TRUE, allow_na = TRUE)
   check_bool(echo)
 
-  system_prompt <- system_prompt %||%
-    "You are a helpful assistant from New Zealand who is an experienced R programmer"
+  model <- model %||% "gpt-4o-mini"
+
+  messages <- apply_system_prompt_openai(system_prompt, messages)
 
   ChatOpenAI$new(
     base_url = base_url,
     model = model,
+    messages = messages,
     api_key = api_key,
-    system_prompt = system_prompt,
     echo = echo
   )
 }
 
+apply_system_prompt_openai <- function(system_prompt, messages) {
+  if (is.null(system_prompt)) {
+    return(messages)
+  }
+
+  system_prompt_message <- list(
+    role = "system",
+    content = system_prompt
+  )
+
+  # No messages; start with just the system prompt
+  if (length(messages) == 0) {
+    return(list(system_prompt_message))
+  }
+
+  # No existing system prompt message; prepend the new one
+  if (messages[[1]][["role"]] != "system") {
+    return(c(list(system_prompt_message), messages))
+  }
+
+  # Duplicate system prompt; return as-is
+  if (messages[[1]][["content"]] == system_prompt) {
+    return(messages)
+  }
+
+  stop("`system_prompt` and `messages[[1]]` contained conflicting system prompts")
+}
+
+check_openai_conversation <- function(messages, allow_null = FALSE) {
+  if (is.null(messages) && isTRUE(allow_null)) {
+    return()
+  }
+
+  if (!is.list(messages) ||
+      !(is.null(names(messages)) || all(names(messages) == ""))) {
+    stop_input_type(
+      messages,
+      "an unnamed list of messages",
+      allow_null = FALSE
+    )
+  }
+
+  for (message in messages) {
+    if (!is.list(message) ||
+        !is.character(message$role)) {
+      stop("Each message must be a named list with at least a `role` field.")
+    }
+  }
+}
+
 #' @rdname new_chat_openai
 ChatOpenAI <- R6::R6Class("ChatOpenAI",
   public = list(
-    #' @param system_prompt A system prompt to set the behavior of the assistant.
     #' @param base_url The base URL to the endpoint; the default uses ChatGPT.
     #' @param api_key The API key to use for authentication. You generally should
     #'   not supply this directly, but instead set the `OPENAI_API_KEY` environment
     #'   variable.
     #' @param model The model to use for the chat; defaults to GPT-4o mini.
+    #' @param messages An unnamed list of messages to start the chat with (i.e.,
+    #'   continuing a previous conversation). If `NULL` or zero-length list, the
+    #'   conversation begins from scratch.
     #' @param echo If `TRUE`, the `chat()` method streams the response to stdout
     #'   (while also returning the final response). Note that this has no effect
     #'   on the `stream()`, `chat_async()`, and `stream_async()` methods.
-    initialize = function(base_url, model, api_key, system_prompt, echo = FALSE) {
+    initialize = function(base_url, model, messages, api_key, echo = FALSE) {
       private$base_url <- base_url
       private$model <- model
+      private$msgs <- messages %||% list()
       private$api_key <- api_key
       private$echo <- echo
-      private$add_message(list(
-        role = "system",
-        content = system_prompt
-      ))
     },
 
+    #' @description The messages that have been sent and received so far
+    #'   (optionally starting with the system prompt, if any).
+    #' @param include_system_prompt Whether to include the system prompt in the
+    #'   messages (if any exists).
+    messages = function(include_system_prompt = FALSE) {
+      if (length(private$msgs) == 0) {
+        return(private$msgs)
+      }
+
+      if (!include_system_prompt && private$msgs[[1]][["role"]] == "system") {
+        private$msgs[-1]
+      } else {
+        private$msgs
+      }
+    },
+
     #' @description Submit text to the chatbot, and return the response as a
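
An aside on apply_system_prompt_openai(), the internal helper added above: it reconciles `system_prompt` with an incoming `messages` list. A sketch of its outcomes (hypothetical values, results abbreviated in comments):

  m <- list(list(role = "user", content = "hi"))
  s <- list(list(role = "system", content = "Be terse."))
  apply_system_prompt_openai(NULL, m)             # m, unchanged
  apply_system_prompt_openai("Be terse.", NULL)   # just the system message
  apply_system_prompt_openai("Be terse.", m)      # system message prepended
  apply_system_prompt_openai("Be terse.", s)      # s, unchanged (duplicate prompt)
  apply_system_prompt_openai("Other.", s)         # error: conflicting prompts
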
@@ -109,7 +186,7 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",
       # Returns a single message (the final response from the assistant), even if
       # multiple rounds of back and forth happened.
       coro::collect(private$chat_impl(text, stream = echo, echo = echo))
-      last_message <- private$messages[[length(private$messages)]]
+      last_message <- private$msgs[[length(private$msgs)]]
       stopifnot(identical(last_message[["role"]], "assistant"))
 
       if (echo) {
@@ -132,7 +209,7 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",
         private$chat_impl_async(text, stream = FALSE, echo = FALSE)
       )
       promises::then(done, function(dummy) {
-        last_message <- private$messages[[length(private$messages)]]
+        last_message <- private$msgs[[length(private$msgs)]]
         stopifnot(identical(last_message[["role"]], "assistant"))
         last_message$content
       })
@@ -225,12 +302,23 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",
       invisible(self)
     }
   ),
+  active = list(
+    system_prompt = function() {
+      if (length(private$msgs) == 0) {
+        return(NULL)
+      }
+      if (private$msgs[[1]][["role"]] != "system") {
+        return(NULL)
+      }
+      private$msgs[[1]][["content"]]
+    }
+  ),
   private = list(
     base_url = NULL,
     model = NULL,
     api_key = NULL,
 
-    messages = NULL,
+    msgs = NULL,
     echo = NULL,
 
     # OpenAI-compliant tool metadata
@@ -239,7 +327,7 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",
     tool_funs = NULL,
 
     add_message = function(message) {
-      private$messages <- c(private$messages, list(message))
+      private$msgs <- c(private$msgs, list(message))
       invisible(self)
     },
 
@@ -297,7 +385,7 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",
     # complete assistant messages.
     submit_messages = generator_method(function(self, private, stream, echo) {
       response <- openai_chat(
-        messages = private$messages,
+        messages = private$msgs,
         tools = private$tool_infos,
         base_url = private$base_url,
         model = private$model,
@@ -347,7 +435,7 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",
     # complete assistant messages.
     submit_messages_async = async_generator_method(function(self, private, stream, echo) {
       response <- openai_chat_async(
-        messages = private$messages,
+        messages = private$msgs,
         tools = private$tool_infos,
         base_url = private$base_url,
         model = private$model,
@@ -396,11 +484,11 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",
 
     invoke_tools = function() {
       if (length(private$tool_infos) > 0) {
-        last_message <- private$messages[[length(private$messages)]]
+        last_message <- private$msgs[[length(private$msgs)]]
         tool_message <- call_tools(private$tool_funs, last_message)
 
         if (!is.null(tool_message)) {
-          private$messages <- c(private$messages, tool_message)
+          private$msgs <- c(private$msgs, tool_message)
           return(TRUE)
         }
       }
@@ -411,23 +499,44 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",
 
 #' @export
 print.ChatOpenAI <- function(x, ...) {
-  cat("<ChatOpenAI>\n")
-  messages <- x$.__enclos_env__$private$messages
-  for (message in messages) {
+  msgs <- x$messages(include_system_prompt = TRUE)
+  msgs_without_system_prompt <- x$messages(include_system_prompt = FALSE)
+  cat(paste0("<ChatOpenAI messages=", length(msgs_without_system_prompt), ">\n"))
+  for (message in msgs) {
     color <- switch(message$role,
-      user = "blue",
-      system = ,
-      assistant = "green"
+      user = cli::col_blue,
+      assistant = cli::col_green,
+      identity
     )
-    cat(cli::rule(message$role, col = color), "\n", sep = "")
-    cat(message$content, "\n", sep = "")
+    cli::cli_rule("{color(message$role)}")
+    if (!is.null(message$content)) {
+      # Using cli_text for word wrapping. Passing `"{message$content}"` instead
+      # of `message$content` to avoid evaluation of the (potentially malicious)
+      # content.
+      cli::cli_text("{message$content}")
+    }
+    if (!is.null(message$tool_calls)) {
+      cli::cli_text("Tool calls:")
+      for (tool_call in message$tool_calls) {
+        funcname <- tool_call$`function`$name
+        args <- tool_call$`function`$arguments
+        tryCatch({
+          args_parsed <- jsonlite::parse_json(tool_call$`function`$arguments)
+          args <- rlang::call2(funcname, !!!args_parsed)
+          cli::cli_text(format(args))
+        }, error = function(e) {
+          # In case parsing the JSON fails
+          cli::cli_text("{funcname}({args})")
+        })
+      }
+    }
   }
 
   invisible(x)
 }
 
 
 last_message <- function(chat) {
-  messages <- chat$.__enclos_env__$private$messages
+  messages <- chat$messages()
   messages[[length(messages)]]
 }
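
The revised print method renders tool calls by round-tripping the JSON arguments into an R call object. A standalone sketch of that idea, with a hypothetical tool name and arguments (output shown approximately):

  args_json <- '{"location": "Wellington", "unit": "celsius"}'
  args_parsed <- jsonlite::parse_json(args_json)
  # Splice the named arguments into a call, then deparse it for display
  call_obj <- rlang::call2("get_weather", !!!args_parsed)
  format(call_obj)
  #> get_weather(location = "Wellington", unit = "celsius")
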
5 changes: 3 additions & 2 deletions man/elmer-package.Rd
