Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polish decorator implementation a litte #21

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",

# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
# complete assistant messages.
chat_impl = generator_method(function(self, private, text, stream, echo) {
chat_impl = R6_decorate(coro::generator, function(self, private, text, stream, echo) {
private$add_message(list(role = "user", content = text))
while (TRUE) {
for (chunk in private$submit_messages(stream = stream, echo = echo)) {
Expand All @@ -364,7 +364,7 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",

# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
# complete assistant messages.
chat_impl_async = async_generator_method(function(self, private, text, stream, echo) {
chat_impl_async = R6_decorate(coro::async_generator, function(self, private, text, stream, echo) {
private$add_message(list(role = "user", content = text))
while (TRUE) {
for (chunk in await_each(private$submit_messages_async(stream = stream, echo = echo))) {
Expand All @@ -383,7 +383,7 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",

# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
# complete assistant messages.
submit_messages = generator_method(function(self, private, stream, echo) {
submit_messages = R6_decorate(coro::generator, function(self, private, stream, echo) {
response <- openai_chat(
messages = private$msgs,
tools = private$tool_infos,
Expand Down Expand Up @@ -433,7 +433,7 @@ ChatOpenAI <- R6::R6Class("ChatOpenAI",

# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
# complete assistant messages.
submit_messages_async = async_generator_method(function(self, private, stream, echo) {
submit_messages_async = R6_decorate(coro::async_generator, function(self, private, stream, echo) {
response <- openai_chat_async(
messages = private$msgs,
tools = private$tool_infos,
Expand Down
84 changes: 30 additions & 54 deletions R/coro-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,72 +10,48 @@
# R6 classes that depend on this will be instantiated at package build time; so
# the coro generator functions will be burned into the package .Rds file.

# So `R CMD check` doesn't get confused about these variables being used from
# methods
utils::globalVariables(c("self", "private", "generator_env", "exits"))
generators <- new_environment()
generators$cur_id <- 1L

generators <- new.env()
new_id <- function() {
generators$cur_id <- generators$cur_id + 1L
as.character(generators$cur_id)
}

# Decorator for anonymous functions; the return value is intended to be used as
# an R6 method. Unlike regular R6 methods, the decorated function must have
# `self` as the first argument (which will be automatically passed in by the
# decorator). If necessary we can also provide access to `private` in the same
# way.
generator_method <- function(func) {
fn <- substitute(func)

stopifnot(
"generator methods must have `self` parameter" = identical(names(formals(func))[1], "self")
)
stopifnot(
"generator methods must have `self` parameter" = identical(names(formals(func))[2], "private")
)

expr <- rlang::inject(
base::quote(coro::generator(!!fn))
)
generator <- eval(expr, parent.frame())

unique_id <- as.character(sample.int(99999999, 1))
generators[[unique_id]] <- generator

rlang::inject(
function(...) {
# Must use elmer::: because the lexical environment of this function is
# about to get wrecked by R6
elmer:::generators[[!!unique_id]](self, private, ...)
}
)
}
R6_decorate <- function(wrapper, func, print = FALSE) {
wrapper <- enexpr(wrapper)
fn <- enexpr(func)

arg_names <- names(formals(func))
if (length(arg_names) < 2) {
cli::cli_abort("Function must have at least two arguments.", .internal = TRUE)
} else if (arg_names[[1]] != "self") {
cli::cli_abort("First argument must be {.arg self}.", .internal = TRUE)
} else if (arg_names[[2]] != "private") {
cli::cli_abort("Second argument must be {.arg private}.", .internal = TRUE)
}

# Same as generator_method, but for async logic
async_generator_method <- function(func, print = FALSE) {
fn <- substitute(func)
args_def <- formals(func)[-(1:2)]
args_call <- lapply(set_names(names(args_def)), as.symbol)

stopifnot(
"generator methods must have `self` parameter" = identical(names(formals(func))[1], "self")
)
stopifnot(
"generator methods must have `self` parameter" = identical(names(formals(func))[2], "private")
)

expr <- rlang::inject(
base::quote(coro::async_generator(!!fn))
)
generator <- eval(expr, parent.frame())
id <- new_id()
generators[[id]] <- inject((!!wrapper)(!!fn), parent.frame())

unique_id <- paste0("a", sample.int(99999999, 1))
generators[[unique_id]] <- generator
# Supress R CMD check note
self <- private <- NULL

gen <- rlang::inject(
function(...) {
# Must use elmer::: because the lexical environment of this function is
# about to get wrecked by R6
elmer:::generators[[!!unique_id]](self, private, ...)
}
# Must use elmer::: because the lexical environment of this function is
# about to get wrecked by R6
gen_method <- new_function(args_def,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably overkill since they're private methods, but this matches the actual argument names rather than relying on ....

expr(elmer:::generators[[!!id]](self, private, !!!args_call))
)
if (print) {
print(gen, internals = TRUE)
print(gen_method, internals = TRUE)
}
gen
gen_method
}
Loading