Skip to content

Commit

Permalink
Capture and pass random seed (#28)
Browse files Browse the repository at this point in the history
Defaults to set seed for testing to minimise spurious differences. Also fleshes out a bunch of other API parameters, which I _think_ will be useful for thinking about the API more broadly.

Fixes #27
  • Loading branch information
hadley authored Sep 23, 2024
1 parent 726b433 commit 55d7b93
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 43 deletions.
41 changes: 23 additions & 18 deletions R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ NULL
#' @param model The model to use for the chat; set to `NULL` (the default) to
#' use a reasonable model, currently `gpt-4o-mini`. We strongly recommend
#' explicitly choosing a model for all but the most casual use.
#' @param seed Optional integer seed that ChatGPT uses to try and make output
#' more reproducible.
#' @param api_args Named list of arbitrary extra arguments passed to every
#' chat API call.
#' @param echo If `TRUE`, the `chat()` method streams the response to stdout by
#' default. (Note that this has no effect on the `stream()`, `chat_async()`,
#' and `stream_async()` methods.)
Expand Down Expand Up @@ -64,25 +68,32 @@ new_chat_openai <- function(system_prompt = NULL,
base_url = "https://api.openai.com/v1",
api_key = openai_key(),
model = NULL,
seed = NULL,
api_args = list(),
echo = FALSE) {
check_string(system_prompt, allow_null = TRUE)
check_openai_conversation(messages, allow_null = TRUE)
check_string(base_url)
check_string(base_url)
check_string(api_key)
check_string(model, allow_null = TRUE, allow_na = TRUE)
check_number_decimal(seed, allow_null = TRUE)
check_bool(echo)

model <- model %||% "gpt-4o-mini"
if (is_testing() && is.null(seed)) {
seed <- 1014
}

messages <- apply_system_prompt_openai(system_prompt, messages)

ChatOpenAI$new(
model <- openai_model(
base_url = base_url,
model = model,
messages = messages,
api_key = api_key,
echo = echo
seed = seed,
extra_args = api_args,
api_key = api_key
)

messages <- apply_system_prompt_openai(system_prompt, messages)
ChatOpenAI$new(model = model, messages = messages, echo = echo)
}

apply_system_prompt_openai <- function(system_prompt, messages) {
Expand Down Expand Up @@ -138,23 +149,17 @@ check_openai_conversation <- function(messages, allow_null = FALSE) {
#' @rdname new_chat_openai
ChatOpenAI <- R6::R6Class("ChatOpenAI",
public = list(
#' @param base_url The base URL to the endpoint; the default uses ChatGPT.
#' @param api_key The API key to use for authentication. You generally should
#' not supply this directly, but instead set the `OPENAI_API_KEY` environment
#' variable.
#' @param model The model to use for the chat; defaults to GPT-4o mini.
#' @param model Model object.
#' @param messages An unnamed list of messages to start the chat with (i.e.,
#' continuing a previous conversation). If `NULL` or zero-length list, the
#' conversation begins from scratch.
#' @param seed Optional integer seed that ChatGPT uses to try and make output
#' more reproducible.
#' @param echo If `TRUE`, the `chat()` method streams the response to stdout
#' (while also returning the final response). Note that this has no effect
#' on the `stream()`, `chat_async()`, and `stream_async()` methods.
initialize = function(base_url, model, messages, api_key, echo = FALSE) {
private$model <- openai_model(
base_url = base_url,
model = model,
api_key = api_key
)
initialize = function(model, messages, seed = NULL, echo = FALSE) {
private$model <- model
private$msgs <- messages %||% list()
private$echo <- echo
},
Expand Down
5 changes: 5 additions & 0 deletions R/http.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
chat_stream <- function(is_done, parse_data) {
# silence R CMD check note
yield <- NULL

force(is_done)
force(parse_data)

Expand Down Expand Up @@ -26,6 +29,8 @@ chat_stream <- function(is_done, parse_data) {
}

chat_stream_async <- function(is_done, parse_data, polling_interval_secs = 0.1) {
# silence R CMD check note
yield <- await <- NULL
force(is_done)
force(parse_data)
force(polling_interval_secs)
Expand Down
50 changes: 32 additions & 18 deletions R/open-ai.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
openai_model <- function(base_url = "https://api.openai.com/v1",
model = "gpt-4o-mini",
api_key = openai_key()) {
model = "gpt-4o-mini",
seed = NULL,
extra_args = list(),
api_key = openai_key()) {
structure(
list(
base_url = base_url,
model = model,
api_key = api_key
seed = seed,
api_key = api_key,
extra_args = list()
),
class = "elmer::openai_model"
)
Expand All @@ -23,8 +27,11 @@ openai_chat <- function(mode = c("value", "stream", "async-stream", "async-value
req <- openai_chat_req(
messages = messages,
tools = tools,
model = model,
stream = stream
model = model$model,
seed = model$seed,
stream = stream,
base_url = model$base_url,
api_key = model$api_key
)

handle_response <- switch(mode,
Expand Down Expand Up @@ -78,23 +85,30 @@ openai_key <- function() {
key
}

openai_chat_req <- function(model,
messages,
# https://platform.openai.com/docs/api-reference/chat/create
openai_chat_req <- function(messages,
tools = list(),
stream = FALSE) {
if (length(tools) == 0) {
# OpenAI rejects tools=[]
tools <- NULL
}
model = "gpt-4o-mini",
seed = NULL,
stream = TRUE,
extra_args = list(),
base_url = "https://api.openai.com/v1",
api_key = openai_key()) {

data <- list(
model = model$model,
stream = stream,
check_string(model)
check_number_whole(seed, allow_null = TRUE)
check_bool(stream)

data <- compact(list2(
messages = messages,
tools = tools
)
model = model,
seed = seed,
stream = stream,
tools = tools,
!!!extra_args
))

req <- openai_request(base_url = model$base_url, key = model$api_key)
req <- openai_request(base_url = base_url, key = api_key)
req <- req_url_path_append(req, "/chat/completions")
req <- req_body_json(req, data)
req
Expand Down
19 changes: 12 additions & 7 deletions man/new_chat_openai.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 55d7b93

Please sign in to comment.