Skip to content

Commit

Permalink
Add a new chat_snowflake() provider.
Browse files Browse the repository at this point in the history
This commit adds `chat_snowflake()` for chatting with models hosted
through Snowflake's [Cortex LLM REST API][0]:

    chat <- chat_snowflake()
    chat$chat("Tell me a joke in the form of a SQL query.")

On the backend it looks fairly similar to OpenAI, though it has only the
basic textual functionality, and so many advanced ellmer features are
not available. I also reused quite a bit of the credential support and
utilities from `chat_cortex()`, so this commit also includes some minor
refactoring of that provider.

Right now the default model for `chat_snowflake()` is Llama 3.1 70B, but
we should change it to Claude 3.5 Sonnet when that gets rolled out more
widely; it's only available to customers in the `us-west-1` Snowflake
region right now.

Unit tests are included.

Part of #255.

[0]: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api

Signed-off-by: Aaron Jacobs <[email protected]>
  • Loading branch information
atheriel committed Jan 21, 2025
1 parent bf973c9 commit 3b1c73d
Show file tree
Hide file tree
Showing 10 changed files with 340 additions and 19 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Collate:
'provider-groq.R'
'provider-ollama.R'
'provider-perplexity.R'
'provider-snowflake.R'
'provider-vllm.R'
'shiny.R'
'tokens.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export(chat_groq)
export(chat_ollama)
export(chat_openai)
export(chat_perplexity)
export(chat_snowflake)
export(chat_vllm)
export(content_image_file)
export(content_image_plot)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
* `chat_databricks()` now respects the `SPARK_CONNECT_USER_AGENT` environment
variable when making requests (#254, @atheriel).

* A new `chat_snowflake()` allows chatting with models hosted through
Snowflake's [Cortex LLM REST
API](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api)
(#258, @atheriel).

# ellmer 0.1.0

* New `chat_vllm()` to chat with models served by vLLM (#140).
Expand Down
21 changes: 5 additions & 16 deletions R/provider-cortex.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ NULL
#' `account`.
#'
#' @param account A Snowflake [account identifier](https://docs.snowflake.com/en/user-guide/admin-account-identifier),
#' e.g. `"testorg-test_account"`.
#' e.g. `"testorg-test_account"`. Defaults to the value of the
#' `SNOWFLAKE_ACCOUNT` environment variable.
#' @param credentials A list of authentication headers to pass into
#' [`httr2::req_headers()`], a function that returns them when passed
#' `account` as a parameter, or `NULL` to use ambient credentials.
Expand All @@ -50,7 +51,7 @@ NULL
#' )
#' chat$chat("What questions can I ask?")
#' @export
chat_cortex <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT"),
chat_cortex <- function(account = snowflake_account(),
credentials = NULL,
model_spec = NULL,
model_file = NULL,
Expand Down Expand Up @@ -84,7 +85,7 @@ ProviderCortex <- new_class(
parent = Provider,
constructor = function(account, credentials, model_spec = NULL,
model_file = NULL, extra_args = list()) {
base_url <- paste0("https://", account, ".snowflakecomputing.com")
base_url <- snowflake_url(account)
extra_args <- compact(list2(
semantic_model = model_spec,
semantic_model_file = model_file,
Expand Down Expand Up @@ -124,6 +125,7 @@ method(chat_request, ProviderCortex) <- function(provider,
req <- httr2::req_headers(req, !!!creds, .redact = "Authorization")
req <- req_retry(req, max_tries = 2)
req <- req_timeout(req, 60)
req <- req_user_agent(req, snowflake_user_agent())

# Snowflake doesn't document the error response format for Cortex Analyst at
# this time, but empirically errors look like the following:
Expand All @@ -136,15 +138,6 @@ method(chat_request, ProviderCortex) <- function(provider,
# }
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)

# Snowflake uses the User Agent header to identify "parter applications",
# so identify requests as coming from "r_ellmer" (unless an explicit
# partner application is set via the ambient SF_PARTNER environment
# variable).
req <- req_user_agent(req, ellmer_user_agent())
if (nchar(Sys.getenv("SF_PARTNER")) != 0) {
req <- req_user_agent(req, Sys.getenv("SF_PARTNER"))
}

# Cortex does not yet support multi-turn chats.
turns <- tail(turns, n = 1)
messages <- as_json(provider, turns)
Expand All @@ -156,10 +149,6 @@ method(chat_request, ProviderCortex) <- function(provider,
req
}

ellmer_user_agent <- function() {
paste0("r_ellmer/", utils::packageVersion("ellmer"))
}

# Cortex -> ellmer --------------------------------------------------------------

method(stream_parse, ProviderCortex) <- function(provider, event) {
Expand Down
184 changes: 184 additions & 0 deletions R/provider-snowflake.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#' @include provider-openai.R
#' @include content.R
NULL

#' Chat with a model hosted on Snowflake
#'
#' @description
#' The Snowflake provider allows you to interact with LLM models available
#' through the [Cortex LLM REST API](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api).
#'
#' Note that Snowflake-hosted models do not support images, tool calling, or
#' structured outputs.
#'
#' ## Authentication
#'
#' `chat_snowflake()` picks up the following ambient Snowflake credentials:
#'
#' - A static OAuth token defined via the `SNOWFLAKE_TOKEN` environment
#' variable.
#' - Key-pair authentication credentials defined via the `SNOWFLAKE_USER` and
#' `SNOWFLAKE_PRIVATE_KEY` (which can be a PEM-encoded private key or a path
#' to one) environment variables.
#' - Posit Workbench-managed Snowflake credentials for the corresponding
#' `account`.
#'
#' @inheritParams chat_openai
#' @inheritParams chat_cortex
#' @inherit chat_openai return
#' @examplesIf has_credentials("cortex")
#' chat <- chat_snowflake()
#' chat$chat("Tell me a joke in the form of a SQL query.")
#' @export
chat_snowflake <- function(system_prompt = NULL,
turns = NULL,
account = snowflake_account(),
credentials = NULL,
model = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
turns <- normalize_turns(turns, system_prompt)
check_string(account, allow_empty = FALSE)
model <- set_default(model, "llama3.1-70b")
echo <- check_echo(echo)

if (is_list(credentials)) {
static_credentials <- force(credentials)
credentials <- function(account) static_credentials
}
check_function(credentials, allow_null = TRUE)

provider <- ProviderSnowflake(
base_url = snowflake_url(account),
account = account,
credentials = credentials,
model = model,
extra_args = api_args,
# We need an empty api_key for S7 validation.
api_key = ""
)

Chat$new(provider = provider, turns = turns, echo = echo)
}

ProviderSnowflake <- new_class(
"ProviderSnowflake",
parent = ProviderOpenAI,
properties = list(
account = prop_string(),
credentials = class_function | NULL
)
)

# See: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
method(chat_request, ProviderSnowflake) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {
if (length(tools) != 0) {
cli::cli_abort(
"Tool calling is not supported for {.code chat_snowflake()}."
)
}
if (!is.null(type) != 0) {
cli::cli_abort(
"Structured data extraction is not supported for {.code chat_snowflake()}."
)
}
if (!stream) {
cli::cli_abort(
"Non-streaming responses are not supported for {.code chat_snowflake()}."
)
}

req <- request(provider@base_url)
req <- req_url_path_append(req, "/api/v2/cortex/inference:complete")
creds <- cortex_credentials(provider@account, provider@credentials)
req <- req_headers(req, !!!creds, .redact = "Authorization")
req <- req_retry(req, max_tries = 2)
req <- req_timeout(req, 60)
req <- req_user_agent(req, snowflake_user_agent())

# Snowflake-specific error response handling:
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)

messages <- as_json(provider, turns)
extra_args <- utils::modifyList(provider@extra_args, extra_args)

data <- compact(list2(
messages = messages,
model = provider@model,
stream = stream,
!!!extra_args
))
req <- req_body_json(req, data)

req
}

# Snowflake -> ellmer --------------------------------------------------------

method(stream_parse, ProviderSnowflake) <- function(provider, event) {
# Snowflake's SSEs look much like the OpenAI ones, except in their
# handling of EOF.
if (is.null(event)) {
# This seems to be how Snowflake's backend signals that the stream is done.
return(NULL)
}
jsonlite::parse_json(event$data)
}

method(value_turn, ProviderSnowflake) <- function(provider, result, has_type = FALSE) {
deltas <- compact(sapply(result$choices, function(x) x$delta$content))
content <- list(as_content(paste(deltas, collapse = "")))
tokens <- c(
result$usage$prompt_tokens %||% NA_integer_,
result$usage$completion_tokens %||% NA_integer_
)
tokens_log(paste0("Snowflake-", provider@account), tokens)
Turn(
# Snowflake's response format seems to omit the role.
"assistant",
content,
json = result,
tokens = tokens
)
}

# ellmer -> Snowflake --------------------------------------------------------

# Snowflake only supports simple textual messages.

method(as_json, list(ProviderSnowflake, Turn)) <- function(provider, x) {
list(
role = x@role,
content = as_json(provider, x@contents[[1]])
)
}

method(as_json, list(ProviderSnowflake, ContentText)) <- function(provider, x) {
x@text
}

# Utilities ------------------------------------------------------------------

snowflake_account <- function() {
key_get("SNOWFLAKE_ACCOUNT")
}

snowflake_url <- function(account) {
paste0("https://", account, ".snowflakecomputing.com")
}

# Snowflake uses the User Agent header to identify "parter applications", so
# identify requests as coming from "r_ellmer" (unless an explicit partner
# application is set via the ambient SF_PARTNER environment variable).
snowflake_user_agent <- function() {
user_agent <- paste0("r_ellmer/", utils::packageVersion("ellmer"))
if (nchar(Sys.getenv("SF_PARTNER")) != 0) {
user_agent <- Sys.getenv("SF_PARTNER")
}
user_agent
}
5 changes: 3 additions & 2 deletions man/chat_cortex.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

77 changes: 77 additions & 0 deletions man/chat_snowflake.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions tests/testthat/_snaps/provider-snowflake.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# defaults are reported

Code
. <- chat_snowflake(credentials = credentials)
Message
Using model = "llama3.1-70b".

2 changes: 1 addition & 1 deletion tests/testthat/test-provider-cortex.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ test_that("Cortex API requests are generated correctly", {
req <- chat_request(p, FALSE, list(turn))
expect_snapshot(
req,
transform = function(x) gsub(ellmer_user_agent(), "<ellmer_user_agent>", x, fixed = TRUE)
transform = function(x) gsub(snowflake_user_agent(), "<ellmer_user_agent>", x, fixed = TRUE)
)
expect_snapshot(req$body$data)
})
Expand Down
Loading

0 comments on commit 3b1c73d

Please sign in to comment.