-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a new chat_snowflake() provider.
This commit adds `chat_snowflake()` for chatting with models hosted through Snowflake's [Cortex LLM REST API][0]: chat <- chat_snowflake() chat$chat("Tell me a joke in the form of a SQL query.") On the backend it looks fairly similar to OpenAI, though it has only the basic textual functionality, and so many advanced ellmer features are not available. I also reused quite a bit of the credential support and utilities from `chat_cortex()`, so this commit also includes some minor refactoring of that provider. Right now the default model for `chat_snowflake()` is Llama 3.1 70B, but we should change it to Claude 3.5 Sonnet when that gets rolled out more widely; it's only available to customers in the `us-west-1` Snowflake region right now. Unit tests are included. Part of #255. [0]: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api Signed-off-by: Aaron Jacobs <[email protected]>
- Loading branch information
Showing
10 changed files
with
340 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
#' @include provider-openai.R | ||
#' @include content.R | ||
NULL | ||
|
||
#' Chat with a model hosted on Snowflake
#'
#' @description
#' The Snowflake provider lets you talk to the large language models that
#' Snowflake serves through its
#' [Cortex LLM REST API](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api).
#'
#' Snowflake-hosted models are text-only: images, tool calling, and
#' structured outputs are not supported.
#'
#' ## Authentication
#'
#' `chat_snowflake()` picks up the following ambient Snowflake credentials:
#'
#' - A static OAuth token defined via the `SNOWFLAKE_TOKEN` environment
#'   variable.
#' - Key-pair authentication credentials defined via the `SNOWFLAKE_USER` and
#'   `SNOWFLAKE_PRIVATE_KEY` environment variables. The private key can be
#'   either a PEM-encoded key or a path to one.
#' - Posit Workbench-managed Snowflake credentials for the corresponding
#'   `account`.
#'
#' @inheritParams chat_openai
#' @inheritParams chat_cortex
#' @inherit chat_openai return
#' @examplesIf has_credentials("cortex")
#' chat <- chat_snowflake()
#' chat$chat("Tell me a joke in the form of a SQL query.")
#' @export
chat_snowflake <- function(system_prompt = NULL,
                           turns = NULL,
                           account = snowflake_account(),
                           credentials = NULL,
                           model = NULL,
                           api_args = list(),
                           echo = c("none", "text", "all")) {
  # Argument normalisation and validation, in the order callers rely on.
  turns <- normalize_turns(turns, system_prompt)
  check_string(account, allow_empty = FALSE)
  # When no model is given, fall back to "llama3.1-70b".
  model <- set_default(model, "llama3.1-70b")
  echo <- check_echo(echo)

  # Credentials may be passed either as a function of the account or as a
  # plain list. A list is wrapped in a function that always returns it, so the
  # provider only ever has to deal with the function form (or NULL).
  if (is_list(credentials)) {
    fixed_credentials <- credentials
    credentials <- function(account) fixed_credentials
  }
  check_function(credentials, allow_null = TRUE)

  provider <- ProviderSnowflake(
    base_url = snowflake_url(account),
    account = account,
    credentials = credentials,
    model = model,
    extra_args = api_args,
    # We need an empty api_key for S7 validation.
    api_key = ""
  )

  Chat$new(provider = provider, turns = turns, echo = echo)
}
|
||
# Provider class for Snowflake. It extends the OpenAI provider with two extra
# properties:
#   - account:     the Snowflake account identifier (a string).
#   - credentials: either NULL or a function; it is handed to
#                  cortex_credentials() together with the account when a
#                  request is built.
ProviderSnowflake <- new_class(
  name = "ProviderSnowflake",
  parent = ProviderOpenAI,
  properties = list(
    account = prop_string(),
    credentials = class_function | NULL
  )
)
|
||
# Build the HTTP request for a chat completion against Snowflake's Cortex
# inference endpoint.
#
# See: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
#
# Arguments:
#   provider   A ProviderSnowflake object; supplies the base URL, account,
#              credentials, model, and default extra arguments.
#   stream     Must be TRUE; non-streaming responses are rejected.
#   turns      The conversation so far, converted to JSON via as_json().
#   tools      Must be empty; tool calling is rejected.
#   type       Must be NULL; structured data extraction is rejected.
#   extra_args Per-call body fields; these override the provider's
#              `extra_args` on a name-by-name basis.
#
# Returns the request object with its JSON body attached. Aborts (via
# cli::cli_abort()) when an unsupported feature is requested.
method(chat_request, ProviderSnowflake) <- function(provider,
                                                    stream = TRUE,
                                                    turns = list(),
                                                    tools = list(),
                                                    type = NULL,
                                                    extra_args = list()) {
  if (length(tools) != 0) {
    cli::cli_abort(
      "Tool calling is not supported for {.code chat_snowflake()}."
    )
  }
  if (!is.null(type)) {
    cli::cli_abort(
      "Structured data extraction is not supported for {.code chat_snowflake()}."
    )
  }
  if (!stream) {
    cli::cli_abort(
      "Non-streaming responses are not supported for {.code chat_snowflake()}."
    )
  }

  req <- request(provider@base_url)
  req <- req_url_path_append(req, "/api/v2/cortex/inference:complete")
  # The credentials are spliced in as request headers; the Authorization
  # header is marked for redaction.
  creds <- cortex_credentials(provider@account, provider@credentials)
  req <- req_headers(req, !!!creds, .redact = "Authorization")
  req <- req_retry(req, max_tries = 2)
  req <- req_timeout(req, 60)
  req <- req_user_agent(req, snowflake_user_agent())

  # Snowflake-specific error response handling: surface the `message` field of
  # a JSON error body. If the body cannot be parsed as JSON, contribute no
  # extra detail rather than letting a parse error mask the HTTP error itself.
  req <- req_error(req, body = function(resp) {
    tryCatch(
      resp_body_json(resp)$message,
      error = function(cnd) NULL
    )
  })

  messages <- as_json(provider, turns)
  extra_args <- utils::modifyList(provider@extra_args, extra_args)

  data <- compact(list2(
    messages = messages,
    model = provider@model,
    stream = stream,
    !!!extra_args
  ))
  req <- req_body_json(req, data)

  req
}
|
||
# Snowflake -> ellmer -------------------------------------------------------- | ||
|
||
# Parse one server-sent event from the Snowflake stream.
#
# Snowflake's events look much like the OpenAI ones, except for how the end of
# the stream is handled: a NULL event seems to be how Snowflake's backend
# signals that the stream is done, so NULL is passed straight through. Any
# other event has its `data` field parsed as JSON.
method(stream_parse, ProviderSnowflake) <- function(provider, event) {
  if (is.null(event)) {
    NULL
  } else {
    jsonlite::parse_json(event$data)
  }
}
|
||
# Convert an accumulated Snowflake streaming result into an assistant Turn.
#
# Arguments:
#   provider A ProviderSnowflake object; its account labels the token log.
#   result   The parsed response; `choices[[i]]$delta$content` holds the text
#            fragments and `usage` holds the token counts when present.
#   has_type Accepted for interface compatibility; not used here.
#
# Returns a Turn with role "assistant" holding the concatenated text.
method(value_turn, ProviderSnowflake) <- function(provider, result, has_type = FALSE) {
  # Gather the text fragment of every choice. lapply() + unlist() always
  # yields a character vector (or NULL), whereas sapply() would switch between
  # a vector and a list depending on whether any fragment is missing.
  fragments <- lapply(result$choices, function(choice) choice$delta$content)
  deltas <- unlist(fragments, use.names = FALSE)
  content <- list(as_content(paste(deltas, collapse = "")))

  # Token counts are optional in the response; record NA when absent.
  tokens <- c(
    result$usage$prompt_tokens %||% NA_integer_,
    result$usage$completion_tokens %||% NA_integer_
  )
  tokens_log(paste0("Snowflake-", provider@account), tokens)

  Turn(
    # Snowflake's response format seems to omit the role.
    "assistant",
    content,
    json = result,
    tokens = tokens
  )
}
|
||
# ellmer -> Snowflake -------------------------------------------------------- | ||
|
||
# Snowflake only supports simple textual messages. | ||
|
||
# Serialise a Turn as a Snowflake chat message: its role plus a single piece
# of content.
# NOTE(review): only the first element of `x@contents` is sent; any further
# contents on the turn are not included in the message.
method(as_json, list(ProviderSnowflake, Turn)) <- function(provider, x) {
  first_content <- x@contents[[1]]
  list(
    role = x@role,
    content = as_json(provider, first_content)
  )
}
|
||
# Text content is serialised as the bare string itself rather than as a
# wrapping object.
method(as_json, list(ProviderSnowflake, ContentText)) <- function(provider, x) {
  text <- x@text
  text
}
|
||
# Utilities ------------------------------------------------------------------ | ||
|
||
# Look up the ambient Snowflake account identifier under the
# "SNOWFLAKE_ACCOUNT" key. Used as the default for `account` in
# chat_snowflake().
snowflake_account <- function() {
  account <- key_get("SNOWFLAKE_ACCOUNT")
  account
}
|
||
# Build the base URL for a Snowflake account:
# "https://<account>.snowflakecomputing.com".
snowflake_url <- function(account) {
  host <- paste0(account, ".snowflakecomputing.com")
  paste0("https://", host)
}
|
||
# Snowflake uses the User Agent header to identify "partner applications", so
# identify requests as coming from "r_ellmer" (unless an explicit partner
# application is set via the ambient SF_PARTNER environment variable).
#
# Returns the value of SF_PARTNER when it is set to a non-empty string, and
# "r_ellmer/<installed ellmer version>" otherwise.
snowflake_user_agent <- function() {
  # Read the override once; an unset variable comes back as "".
  partner <- Sys.getenv("SF_PARTNER")
  if (nzchar(partner)) {
    return(partner)
  }
  paste0("r_ellmer/", utils::packageVersion("ellmer"))
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# defaults are reported | ||
|
||
Code | ||
. <- chat_snowflake(credentials = credentials) | ||
Message | ||
Using model = "llama3.1-70b". | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.