Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make tool calling generic #51

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions R/api-openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,49 @@ method(value_text, openai_model) <- function(model, event) {
method(value_message, openai_model) <- function(model, result) {
result$choices[[1]]$message
}

method(value_tool_calls, openai_model) <- function(model, message, tools) {
lapply(message$tool_calls, function(call) {
fun <- tools[[call$`function`$name]]
args <- jsonlite::parse_json(call$`function`$arguments)
list(fun = fun, args = args, id = call$id)
})
}

method(call_tools, openai_model) <- function(model, tool_calls) {
lapply(tool_calls, function(call) {
result <- call_tool(call$fun, call$args)

if (promises::is.promise(result)) {
cli::cli_abort(c(
"Can't use async tools with `$chat()` or `$stream()`.",
i = "Async tools are supported, but you must use `$chat_async()` or `$stream_async()`."
))
}

openai_tool_result(result, call$id)
})
}

rlang::on_load(
method(call_tools_async, openai_model) <- coro::async(function(model, tool_calls) {
# We call it this way instead of a more natural for + await_each() because
# we want to run all the async tool calls in parallel
result_promises <- lapply(tool_calls, function(call) {
promises::then(
call_tool_async(call$fun, call$args),
function(result) openai_tool_result(result, id = call$id)
)
})

promises::promise_all(.list = result_promises)
})
)

openai_tool_result <- function(result, id) {
list(
role = "tool",
content = toString(result),
tool_call_id = id
)
}
20 changes: 20 additions & 0 deletions R/api.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,23 @@ value_message <- new_generic("value_message", "model",
S7_dispatch()
}
)

# Tool calling -----------------------------------------------------------------

value_tool_calls <- new_generic("value_tool_calls", "model",
function(model, message, tools) {
S7_dispatch()
}
)

call_tools <- new_generic("call_tools", "model",
function(model, tool_calls) {
S7_dispatch()
}
)

call_tools_async <- new_generic("call_tools_async", "model",
function(model, tool_calls) {
S7_dispatch()
}
)
8 changes: 4 additions & 4 deletions R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,8 @@ Chat <- R6::R6Class("Chat",
}

last_message <- private$msgs[[length(private$msgs)]]
tool_calls <- openai_value_tool_calls(last_message, private$tool_funs)
tool_messages <- openai_call_tools(tool_calls)
tool_calls <- value_tool_calls(private$model, last_message, private$tool_funs)
tool_messages <- call_tools(private$model, tool_calls)

if (length(tool_messages) > 0) {
private$msgs <- c(private$msgs, tool_messages)
Expand All @@ -339,8 +339,8 @@ Chat <- R6::R6Class("Chat",
}

last_message <- private$msgs[[length(private$msgs)]]
tool_calls <- openai_value_tool_calls(last_message, private$tool_funs)
tool_messages <- await(openai_call_tools_async(tool_calls))
tool_calls <- value_tool_calls(private$model, last_message, private$tool_funs)
tool_messages <- await(call_tools_async(private$model, tool_calls))

if (length(tool_messages) > 0) {
private$msgs <- c(private$msgs, tool_messages)
Expand Down
44 changes: 0 additions & 44 deletions R/tools.R
Original file line number Diff line number Diff line change
@@ -1,47 +1,3 @@
openai_value_tool_calls <- function(message, tools) {
lapply(message$tool_calls, function(call) {
fun <- tools[[call$`function`$name]]
args <- jsonlite::parse_json(call$`function`$arguments)
list(fun = fun, args = args, id = call$id)
})
}

openai_call_tools <- function(tool_calls) {
lapply(tool_calls, function(call) {
result <- call_tool(call$fun, call$args)

if (promises::is.promise(result)) {
cli::cli_abort(c(
"Can't use async tools with `$chat()` or `$stream()`.",
i = "Async tools are supported, but you must use `$chat_async()` or `$stream_async()`."
))
}

openai_tool_result(result, call$id)
})
}

rlang::on_load(openai_call_tools_async <- coro::async(function(tool_calls) {
# We call it this way instead of a more natural for + await_each() because
# we want to run all the async tool calls in parallel
result_promises <- lapply(tool_calls, function(call) {
promises::then(
call_tool_async(call$fun, call$args),
function(result) openai_tool_result(result, id = call$id)
)
})

promises::promise_all(.list = result_promises)
}))

openai_tool_result <- function(result, id) {
list(
role = "tool",
content = toString(result),
tool_call_id = id
)
}

# Also need to handle edge caess: https://platform.openai.com/docs/guides/function-calling/edge-cases
call_tool <- function(fun, arguments) {
if (is.null(fun)) {
Expand Down
79 changes: 79 additions & 0 deletions tests/testthat/_snaps/api-openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,82 @@
Message
Using model = "gpt-4o-mini".

# repeated tool calls (sync)

Code
chat
Output
<Chat messages=12>
Message
-- system ----------------------------------------------------------------------
Be very terse, not even punctuation.
-- user ------------------------------------------------------------------------
Pick a random number. If it's positive, tell me the current time in New York.
If it's negative, tell me the current time in Seattle. Use ISO-8601, e.g.
'2006-01-02T15:04:05'.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/New_York")
-- tool ------------------------------------------------------------------------
2020-08-01 14:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T14:00:00
-- user ------------------------------------------------------------------------
Great. Do it again.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
-1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/Los_Angeles")
-- tool ------------------------------------------------------------------------
2020-08-01 11:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T11:00:00

# repeated tool calls (async)

Code
chat_async
Output
<Chat messages=12>
Message
-- system ----------------------------------------------------------------------
Be very terse, not even punctuation.
-- user ------------------------------------------------------------------------
Pick a random number. If it's positive, tell me the current time in New York.
If it's negative, tell me the current time in Seattle. Use ISO-8601.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/New_York")
-- tool ------------------------------------------------------------------------
2020-08-01 14:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T14:00:00
-- user ------------------------------------------------------------------------
Great. Do it again.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
-1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/Los_Angeles")
-- tool ------------------------------------------------------------------------
2020-08-01 11:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T11:00:00

79 changes: 0 additions & 79 deletions tests/testthat/_snaps/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,82 +5,3 @@
Output
[1] "Error calling tool: second argument must be a list"

# repeated tool calls (sync)

Code
chat
Output
<Chat messages=12>
Message
-- system ----------------------------------------------------------------------
Be very terse, not even punctuation.
-- user ------------------------------------------------------------------------
Pick a random number. If it's positive, tell me the current time in New York.
If it's negative, tell me the current time in Seattle. Use ISO-8601, e.g.
'2006-01-02T15:04:05'.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/New_York")
-- tool ------------------------------------------------------------------------
2020-08-01 14:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T14:00:00
-- user ------------------------------------------------------------------------
Great. Do it again.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
-1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/Los_Angeles")
-- tool ------------------------------------------------------------------------
2020-08-01 11:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T11:00:00

# repeated tool calls (async)

Code
chat_async
Output
<Chat messages=12>
Message
-- system ----------------------------------------------------------------------
Be very terse, not even punctuation.
-- user ------------------------------------------------------------------------
Pick a random number. If it's positive, tell me the current time in New York.
If it's negative, tell me the current time in Seattle. Use ISO-8601.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/New_York")
-- tool ------------------------------------------------------------------------
2020-08-01 14:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T14:00:00
-- user ------------------------------------------------------------------------
Great. Do it again.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
-1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/Los_Angeles")
-- tool ------------------------------------------------------------------------
2020-08-01 11:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T11:00:00

86 changes: 86 additions & 0 deletions tests/testthat/test-api-openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,89 @@ test_that("system prompt is applied correctly", {
test_that("default model is reported", {
expect_snapshot(. <- new_chat_openai()$chat("Hi"))
})


test_that("can make a simple tool call", {
get_date <- function() "2024-01-01"
chat <- new_chat_openai(system_prompt = "Be very terse, not even punctuation.")
chat$register_tool(get_date, "get_date", "Gets the current date", list())

result <- chat$chat("What's the current date?")
expect_equal(result, "2024-01-01")

result <- chat$chat("What day of the week is it?")
expect_equal(result, "Tuesday")
})

test_that("repeated tool calls (sync)", {
not_actually_random_number <- 1

chat <- new_chat_openai(system_prompt = "Be very terse, not even punctuation.")
chat$register_tool(
fun = function(tz) format(as.POSIXct(as.POSIXct("2020-08-01 18:00:00", tz="UTC"), tz=tz)),
"get_time",
"Gets the current time",
list("tz" = tool_arg(type = "string", description = "Time zone", required = TRUE)),
strict = TRUE
)
chat$register_tool(
fun = function(n, mean, sd) { not_actually_random_number },
name = "rnorm",
description = "Drawn numbers from a random normal distribution",
arguments = list(
"n" = tool_arg(type = "integer", description = "The number of observations. Must be a positive integer."),
"mean" = tool_arg(type = "number", description = "The mean value of the distribution. Defaults to 0.", required = FALSE),
"sd" = tool_arg(type = "number", description = "The standard deviation of the distribution. Must be a non-negative number. Defaults to 1.", required = FALSE)
),
strict = FALSE
)

result <- coro::collect(chat$stream("Pick a random number. If it's positive, tell me the current time in New York. If it's negative, tell me the current time in Seattle. Use ISO-8601, e.g. '2006-01-02T15:04:05'."))
expect_identical(paste(result, collapse = ""), "2020-08-01T14:00:00\n")

not_actually_random_number <- -1
result <- coro::collect(chat$stream("Great. Do it again."))
expect_identical(paste(result, collapse = ""), "2020-08-01T11:00:00\n")

expect_snapshot(chat)
})

test_that("repeated tool calls (async)", {
not_actually_random_number <- 1

chat_async <- new_chat_openai(system_prompt = "Be very terse, not even punctuation.")
chat_async$register_tool(
fun = function(tz) format(as.POSIXct(as.POSIXct("2020-08-01 18:00:00", tz="UTC"), tz=tz)),
"get_time",
"Gets the current time",
list("tz" = tool_arg(type = "string", description = "Time zone", required = TRUE)),
strict = TRUE
)
# An async tool
chat_async$register_tool(
fun = coro::async(function(n, mean, sd) {
await(coro::async_sleep(0.2))
not_actually_random_number
}),
name = "rnorm",
description = "Drawn numbers from a random normal distribution",
arguments = list(
"n" = tool_arg(type = "integer", description = "The number of observations. Must be a positive integer."),
"mean" = tool_arg(type = "number", description = "The mean value of the distribution. Defaults to 0.", required = FALSE),
"sd" = tool_arg(type = "number", description = "The standard deviation of the distribution. Must be a non-negative number. Defaults to 1.", required = FALSE)
),
strict = FALSE
)

result <- sync(coro::async_collect(chat_async$stream_async("Pick a random number. If it's positive, tell me the current time in New York. If it's negative, tell me the current time in Seattle. Use ISO-8601.")))
expect_identical(paste(result, collapse = ""), "2020-08-01T14:00:00\n")

not_actually_random_number <- -1
result <- sync(coro::async_collect(chat_async$stream_async("Great. Do it again.")))
expect_identical(paste(result, collapse = ""), "2020-08-01T11:00:00\n")

expect_snapshot(chat_async)

# Can't use async tools with sync methods
expect_error(chat_async$chat("Great. Do it again."), "chat_async")
})
Loading
Loading