Skip to content

Commit

Permalink
Make more usable
Browse files Browse the repository at this point in the history
* Reduce chat interface
* Add suffix for API type
* Add a basic integration test
* Add docs
* Check argument types
  • Loading branch information
hadley committed Sep 4, 2024
1 parent 333b624 commit c071368
Show file tree
Hide file tree
Showing 14 changed files with 1,337 additions and 146 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Imports:
httr2 (>= 1.0.3.9000),
jsonlite,
R6,
rlang
rlang (>= 1.1.0)
Suggests:
testthat (>= 3.0.0),
withr
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Generated by roxygen2: do not edit by hand

export(new_chat_openai)
export(tool_arg)
import(rlang)
importFrom(R6,R6Class)
256 changes: 167 additions & 89 deletions R/chat.R
Original file line number Diff line number Diff line change
@@ -1,105 +1,183 @@
#' @examples
#' chat <- new_chat()
#' chat$chat("What is the difference between a tibble and a data frame? Answer briefly")
#' chat$chat("Please summarise into a very concise bulleted list.", stream = FALSE)
#' chat$chat("Even more concise!!!", stream = FALSE)
#' chat$chat("Even more concise! Use emoji to save characters", stream = FALSE)
#' Create a chatbot that speaks to an OpenAI compatible endpoint
#'
#' chat <- new_chat()
#' chat$add_tool(rnorm)
#' chat$chat("Give me five numbers from a random normal distribution. Briefly explain your work.")
new_chat <- function(system_prompt = NULL,
base_url = "https://api.openai.com/v1",
api_key = open_ai_key(),
model = "gpt-4o-mini") {

#' This function returns an R6 object that takes care of managing the state
#' associated with the chat; i.e. it records the messages that you send to the
#' server, and the messages that you receive back. If you register a tool
#' (aka an R function), it also takes care of the tool loop.
#'
#' @param system_prompt A system prompt to set the behavior of the assistant.
#' @param base_url The base URL to the endpoint; the default uses ChatGPT.
#' @param api_key The API key to use for authentication. You generally should
#' not supply this directly, but instead set the `OPENAI_API_KEY` environment
#' variable.
#' @param model The model to use for the chat; defaults to GPT-4o mini.
#' @param quiet If `TRUE` does not print output as its received.
#' @export
#' @examplesIf elmer:::openai_key_exists()
#' chat <- new_chat_openai()
#' chat$chat("
#' What is the difference between a tibble and a data frame?
#' Answer with a bulleted list
#' ")
#'
#' chat <- new_chat_openai()
#' chat$register_tool(
#' name = "rnorm",
#' description = "Drawn numbers from a random normal distribution",
#' arguments = list(
#' tool_arg(
#' "n",
#' type = "integer",
#' description = "The number of observations. Must be a positive integer."
#' ),
#' tool_arg(
#' "mean",
#' type = "number",
#' description = "The mean value of the distribution."
#' ),
#' tool_arg(
#' "sd",
#' type = "number",
#' description = "The standard deviation of the distribution. Must be a non-negative number."
#' )
#' )
#' )
#' chat$chat("
#' Give me five numbers from a random normal distribution.
#' Briefly explain your work.
#' ")
new_chat_openai <- function(system_prompt = NULL,
base_url = "https://api.openai.com/v1",
api_key = openai_key(),
model = "gpt-4o-mini",
quiet = FALSE) {
check_string(system_prompt, allow_null = TRUE)
check_string(base_url)
check_string(api_key)
check_string(model)
check_bool(quiet)

system_prompt <- system_prompt %||%
"You are a helpful assistant from New Zealand who is an experienced R programmer"

chat <- Chat$new(
Chat$new(
base_url = base_url,
model = model,
api_key = api_key
api_key = api_key,
system_prompt = system_prompt,
quiet = quiet
)
chat$add_message(list(
role = "system",
content = system_prompt
))
chat
}

Chat <- R6::R6Class("Chat", public = list(
base_url = NULL,
model = NULL,
api_key = NULL,

messages = NULL,
tools = NULL,

initialize = function(base_url, model, api_key) {
self$base_url <- base_url
self$model <- model
self$api_key <- api_key
},

add_message = function(message) {
self$messages <- c(self$messages, list(message))
invisible(self)
},

add_tool = function(tool) {
self$tools <- c(self$tools, list(tool))
invisible(self)
},

register_tool = function(name, description, arguments, strict = TRUE) {
tool <- tool_def(
name = name,
description = description,
arguments = arguments,
strict = strict
)
self$add_tool(tool)
},

chat = function(text, stream = TRUE) {
self$add_message(list(role = "user", content = text))
self$submit_messages(stream = stream)
self$tool_loop()
invisible(self)
},

submit_messages = function(stream = TRUE) {
result <- open_ai_chat(
messages = self$messages,
tools = self$tools,
base_url = self$base_url,
model = self$model,
stream = stream,
api_key = self$api_key
)
if (stream) {
self$add_message(result$choices[[1]]$delta)
} else {
self$add_message(result$choices[[1]]$message)
}
#' @rdname new_chat_openai
Chat <- R6::R6Class("Chat",
public = list(
initialize = function(base_url, model, api_key, system_prompt, quiet = TRUE) {
private$base_url <- base_url
private$model <- model
private$api_key <- api_key
private$quiet <- quiet

private$add_message(list(
role = "system",
content = system_prompt
))
},

invisible(self)
},
#' @description Submit text to the chatbot.
#' @param text The text to send to the chatbot
#' @param stream Whether to stream the response or not.
chat = function(text, stream = TRUE) {
check_string(text)
check_bool(stream)

tool_loop = function() {
if (is.null(self$tools)) {
return()
private$add_message(list(role = "user", content = text))
private$submit_messages(stream = stream)
private$tool_loop()
invisible(self)
},

#' @description Register a tool (an R function) that the chatbot can use.
#' If the chatbot decides to use the function, elmer will automatically
#' call it and submit the results back.
#' @param name The name of the function.
#' @param description A detailed description of what the function does.
#' Generally, the more information that you can provide here, the better.
#' @param arguments A list of arguments that the function accepts.
#' Should be a list of objects created by [tool_arg()].
#' @param strict Should the argument definition be strictly enforced?
register_tool = function(name, description, arguments, strict = TRUE) {
check_string(name)
check_string(description)
check_bool(strict)

tool <- tool_def(
name = name,
description = description,
arguments = arguments,
strict = strict
)
private$add_tool(tool)
invisible(self)
}
),
private = list(
base_url = NULL,
model = NULL,
api_key = NULL,

messages = NULL,
tools = NULL,
quiet = NULL,

add_message = function(message) {
private$messages <- c(private$messages, list(message))
invisible(self)
},

last_message <- self$messages[[length(self$messages)]]
tool_message <- call_tools(last_message)
add_tool = function(tool) {
private$tools <- c(private$tools, list(tool))
invisible(self)
},

if (is.null(tool_message)) {
return()
submit_messages = function(stream = TRUE) {
result <- openai_chat(
messages = private$messages,
tools = private$tools,
base_url = private$base_url,
model = private$model,
stream = stream,
api_key = private$api_key,
quiet = private$quiet
)
if (stream) {
private$add_message(result$choices[[1]]$delta)
} else {
private$add_message(result$choices[[1]]$message)
}

invisible(self)
},

tool_loop = function() {
if (is.null(private$tools)) {
return()
}

last_message <- private$messages[[length(private$messages)]]
tool_message <- call_tools(last_message)

if (is.null(tool_message)) {
return()
}
private$messages <- c(private$messages, tool_message)
private$submit_messages(stream = FALSE)
}
self$messages <- c(self$messages, tool_message)
self$submit_messages(stream = FALSE)
}
))
)
)


last_message <- function(chat) {
messages <- chat$.__enclos_env__$private$messages
messages[[length(messages)]]
}
1 change: 1 addition & 0 deletions R/elmer-package.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#' @keywords internal
#' @importFrom R6 R6Class
"_PACKAGE"

## usethis namespace: start
Expand Down
Loading

0 comments on commit c071368

Please sign in to comment.