Skip to content

Commit

Permalink
Make tool calling generic (#51)
Browse files Browse the repository at this point in the history
hadley authored Sep 26, 2024
1 parent b24a679 commit 72e49ce
Showing 8 changed files with 235 additions and 212 deletions.
46 changes: 46 additions & 0 deletions R/api-openai.R
Original file line number Diff line number Diff line change
@@ -246,3 +246,49 @@ method(value_text, openai_model) <- function(model, event) {
method(value_message, openai_model) <- function(model, result) {
result$choices[[1]]$message
}

method(value_tool_calls, openai_model) <- function(model, message, tools) {
lapply(message$tool_calls, function(call) {
fun <- tools[[call$`function`$name]]
args <- jsonlite::parse_json(call$`function`$arguments)
list(fun = fun, args = args, id = call$id)
})
}

method(call_tools, openai_model) <- function(model, tool_calls) {
lapply(tool_calls, function(call) {
result <- call_tool(call$fun, call$args)

if (promises::is.promise(result)) {
cli::cli_abort(c(
"Can't use async tools with `$chat()` or `$stream()`.",
i = "Async tools are supported, but you must use `$chat_async()` or `$stream_async()`."
))
}

openai_tool_result(result, call$id)
})
}

rlang::on_load(
method(call_tools_async, openai_model) <- coro::async(function(model, tool_calls) {
# We call it this way instead of a more natural for + await_each() because
# we want to run all the async tool calls in parallel
result_promises <- lapply(tool_calls, function(call) {
promises::then(
call_tool_async(call$fun, call$args),
function(result) openai_tool_result(result, id = call$id)
)
})

promises::promise_all(.list = result_promises)
})
)

openai_tool_result <- function(result, id) {
list(
role = "tool",
content = toString(result),
tool_call_id = id
)
}
20 changes: 20 additions & 0 deletions R/api.R
Original file line number Diff line number Diff line change
@@ -127,3 +127,23 @@ value_message <- new_generic("value_message", "model",
S7_dispatch()
}
)

# Tool calling -----------------------------------------------------------------

value_tool_calls <- new_generic("value_tool_calls", "model",
function(model, message, tools) {
S7_dispatch()
}
)

call_tools <- new_generic("call_tools", "model",
function(model, tool_calls) {
S7_dispatch()
}
)

call_tools_async <- new_generic("call_tools_async", "model",
function(model, tool_calls) {
S7_dispatch()
}
)
8 changes: 4 additions & 4 deletions R/chat.R
Original file line number Diff line number Diff line change
@@ -322,8 +322,8 @@ Chat <- R6::R6Class("Chat",
}

last_message <- private$msgs[[length(private$msgs)]]
tool_calls <- openai_value_tool_calls(last_message, private$tool_funs)
tool_messages <- openai_call_tools(tool_calls)
tool_calls <- value_tool_calls(private$model, last_message, private$tool_funs)
tool_messages <- call_tools(private$model, tool_calls)

if (length(tool_messages) > 0) {
private$msgs <- c(private$msgs, tool_messages)
@@ -339,8 +339,8 @@ Chat <- R6::R6Class("Chat",
}

last_message <- private$msgs[[length(private$msgs)]]
tool_calls <- openai_value_tool_calls(last_message, private$tool_funs)
tool_messages <- await(openai_call_tools_async(tool_calls))
tool_calls <- value_tool_calls(private$model, last_message, private$tool_funs)
tool_messages <- await(call_tools_async(private$model, tool_calls))

if (length(tool_messages) > 0) {
private$msgs <- c(private$msgs, tool_messages)
44 changes: 0 additions & 44 deletions R/tools.R
Original file line number Diff line number Diff line change
@@ -1,47 +1,3 @@
openai_value_tool_calls <- function(message, tools) {
lapply(message$tool_calls, function(call) {
fun <- tools[[call$`function`$name]]
args <- jsonlite::parse_json(call$`function`$arguments)
list(fun = fun, args = args, id = call$id)
})
}

openai_call_tools <- function(tool_calls) {
lapply(tool_calls, function(call) {
result <- call_tool(call$fun, call$args)

if (promises::is.promise(result)) {
cli::cli_abort(c(
"Can't use async tools with `$chat()` or `$stream()`.",
i = "Async tools are supported, but you must use `$chat_async()` or `$stream_async()`."
))
}

openai_tool_result(result, call$id)
})
}

rlang::on_load(openai_call_tools_async <- coro::async(function(tool_calls) {
# We call it this way instead of a more natural for + await_each() because
# we want to run all the async tool calls in parallel
result_promises <- lapply(tool_calls, function(call) {
promises::then(
call_tool_async(call$fun, call$args),
function(result) openai_tool_result(result, id = call$id)
)
})

promises::promise_all(.list = result_promises)
}))

openai_tool_result <- function(result, id) {
list(
role = "tool",
content = toString(result),
tool_call_id = id
)
}

# Also need to handle edge caess: https://platform.openai.com/docs/guides/function-calling/edge-cases
call_tool <- function(fun, arguments) {
if (is.null(fun)) {
79 changes: 79 additions & 0 deletions tests/testthat/_snaps/api-openai.md
Original file line number Diff line number Diff line change
@@ -5,3 +5,82 @@
Message
Using model = "gpt-4o-mini".

# repeated tool calls (sync)

Code
chat
Output
<Chat messages=12>
Message
-- system ----------------------------------------------------------------------
Be very terse, not even punctuation.
-- user ------------------------------------------------------------------------
Pick a random number. If it's positive, tell me the current time in New York.
If it's negative, tell me the current time in Seattle. Use ISO-8601, e.g.
'2006-01-02T15:04:05'.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/New_York")
-- tool ------------------------------------------------------------------------
2020-08-01 14:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T14:00:00
-- user ------------------------------------------------------------------------
Great. Do it again.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
-1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/Los_Angeles")
-- tool ------------------------------------------------------------------------
2020-08-01 11:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T11:00:00

# repeated tool calls (async)

Code
chat_async
Output
<Chat messages=12>
Message
-- system ----------------------------------------------------------------------
Be very terse, not even punctuation.
-- user ------------------------------------------------------------------------
Pick a random number. If it's positive, tell me the current time in New York.
If it's negative, tell me the current time in Seattle. Use ISO-8601.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/New_York")
-- tool ------------------------------------------------------------------------
2020-08-01 14:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T14:00:00
-- user ------------------------------------------------------------------------
Great. Do it again.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
-1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/Los_Angeles")
-- tool ------------------------------------------------------------------------
2020-08-01 11:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T11:00:00

79 changes: 0 additions & 79 deletions tests/testthat/_snaps/tools.md
Original file line number Diff line number Diff line change
@@ -5,82 +5,3 @@
Output
[1] "Error calling tool: second argument must be a list"

# repeated tool calls (sync)

Code
chat
Output
<Chat messages=12>
Message
-- system ----------------------------------------------------------------------
Be very terse, not even punctuation.
-- user ------------------------------------------------------------------------
Pick a random number. If it's positive, tell me the current time in New York.
If it's negative, tell me the current time in Seattle. Use ISO-8601, e.g.
'2006-01-02T15:04:05'.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/New_York")
-- tool ------------------------------------------------------------------------
2020-08-01 14:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T14:00:00
-- user ------------------------------------------------------------------------
Great. Do it again.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
-1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/Los_Angeles")
-- tool ------------------------------------------------------------------------
2020-08-01 11:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T11:00:00

# repeated tool calls (async)

Code
chat_async
Output
<Chat messages=12>
Message
-- system ----------------------------------------------------------------------
Be very terse, not even punctuation.
-- user ------------------------------------------------------------------------
Pick a random number. If it's positive, tell me the current time in New York.
If it's negative, tell me the current time in Seattle. Use ISO-8601.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/New_York")
-- tool ------------------------------------------------------------------------
2020-08-01 14:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T14:00:00
-- user ------------------------------------------------------------------------
Great. Do it again.
-- assistant -------------------------------------------------------------------
Tool calls:
rnorm(n = 1L)
-- tool ------------------------------------------------------------------------
-1
-- assistant -------------------------------------------------------------------
Tool calls:
get_time(tz = "America/Los_Angeles")
-- tool ------------------------------------------------------------------------
2020-08-01 11:00:00
-- assistant -------------------------------------------------------------------
2020-08-01T11:00:00

86 changes: 86 additions & 0 deletions tests/testthat/test-api-openai.R
Original file line number Diff line number Diff line change
@@ -42,3 +42,89 @@ test_that("system prompt is applied correctly", {
test_that("default model is reported", {
expect_snapshot(. <- new_chat_openai()$chat("Hi"))
})


test_that("can make a simple tool call", {
get_date <- function() "2024-01-01"
chat <- new_chat_openai(system_prompt = "Be very terse, not even punctuation.")
chat$register_tool(get_date, "get_date", "Gets the current date", list())

result <- chat$chat("What's the current date?")
expect_equal(result, "2024-01-01")

result <- chat$chat("What day of the week is it?")
expect_equal(result, "Tuesday")
})

test_that("repeated tool calls (sync)", {
not_actually_random_number <- 1

chat <- new_chat_openai(system_prompt = "Be very terse, not even punctuation.")
chat$register_tool(
fun = function(tz) format(as.POSIXct(as.POSIXct("2020-08-01 18:00:00", tz="UTC"), tz=tz)),
"get_time",
"Gets the current time",
list("tz" = tool_arg(type = "string", description = "Time zone", required = TRUE)),
strict = TRUE
)
chat$register_tool(
fun = function(n, mean, sd) { not_actually_random_number },
name = "rnorm",
description = "Drawn numbers from a random normal distribution",
arguments = list(
"n" = tool_arg(type = "integer", description = "The number of observations. Must be a positive integer."),
"mean" = tool_arg(type = "number", description = "The mean value of the distribution. Defaults to 0.", required = FALSE),
"sd" = tool_arg(type = "number", description = "The standard deviation of the distribution. Must be a non-negative number. Defaults to 1.", required = FALSE)
),
strict = FALSE
)

result <- coro::collect(chat$stream("Pick a random number. If it's positive, tell me the current time in New York. If it's negative, tell me the current time in Seattle. Use ISO-8601, e.g. '2006-01-02T15:04:05'."))
expect_identical(paste(result, collapse = ""), "2020-08-01T14:00:00\n")

not_actually_random_number <- -1
result <- coro::collect(chat$stream("Great. Do it again."))
expect_identical(paste(result, collapse = ""), "2020-08-01T11:00:00\n")

expect_snapshot(chat)
})

test_that("repeated tool calls (async)", {
not_actually_random_number <- 1

chat_async <- new_chat_openai(system_prompt = "Be very terse, not even punctuation.")
chat_async$register_tool(
fun = function(tz) format(as.POSIXct(as.POSIXct("2020-08-01 18:00:00", tz="UTC"), tz=tz)),
"get_time",
"Gets the current time",
list("tz" = tool_arg(type = "string", description = "Time zone", required = TRUE)),
strict = TRUE
)
# An async tool
chat_async$register_tool(
fun = coro::async(function(n, mean, sd) {
await(coro::async_sleep(0.2))
not_actually_random_number
}),
name = "rnorm",
description = "Drawn numbers from a random normal distribution",
arguments = list(
"n" = tool_arg(type = "integer", description = "The number of observations. Must be a positive integer."),
"mean" = tool_arg(type = "number", description = "The mean value of the distribution. Defaults to 0.", required = FALSE),
"sd" = tool_arg(type = "number", description = "The standard deviation of the distribution. Must be a non-negative number. Defaults to 1.", required = FALSE)
),
strict = FALSE
)

result <- sync(coro::async_collect(chat_async$stream_async("Pick a random number. If it's positive, tell me the current time in New York. If it's negative, tell me the current time in Seattle. Use ISO-8601.")))
expect_identical(paste(result, collapse = ""), "2020-08-01T14:00:00\n")

not_actually_random_number <- -1
result <- sync(coro::async_collect(chat_async$stream_async("Great. Do it again.")))
expect_identical(paste(result, collapse = ""), "2020-08-01T11:00:00\n")

expect_snapshot(chat_async)

# Can't use async tools with sync methods
expect_error(chat_async$chat("Great. Do it again."), "chat_async")
})
Loading

0 comments on commit 72e49ce

Please sign in to comment.