Skip to content

Commit

Permalink
Improve chat_snowflake() performance by caching keypair JWTs (#281)
Browse files Browse the repository at this point in the history
Generating a new JWT on each request is quite wasteful, so this commit
caches them instead (and saves us about 10ms per request). It reuses and
generalises the caching code originally introduced for AWS credentials.

Unit tests are included.

Signed-off-by: Aaron Jacobs <[email protected]>
  • Loading branch information
atheriel authored Jan 27, 2025
1 parent bdc63b1 commit 2bc1ca9
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 21 deletions.
1 change: 1 addition & 0 deletions R/ellmer-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"_PACKAGE"

the <- new_environment()
the$credentials_cache <- new_environment()

silence_r_cmd_check_note <- function() {
later::later()
Expand Down
10 changes: 1 addition & 9 deletions R/provider-bedrock.R
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,6 @@ locate_aws_credentials <- function(profile) {
paws.common::locate_credentials(profile)
}

# In-memory cache for AWS credentials. Analogous to httr2:::cache_mem().
aws_creds_cache <- function(profile) {
key <- hash(profile)
list(
get = function() env_get(the$aws_credentials_cache, key, default = NULL),
set = function(creds) env_poke(the$aws_credentials_cache, key, creds),
clear = function() env_unbind(the$aws_credentials_cache, key)
)
credentials_cache(key = hash(c("aws", profile)))
}

the$aws_credentials_cache <- new_environment()
13 changes: 1 addition & 12 deletions R/provider-cortex.R
Original file line number Diff line number Diff line change
Expand Up @@ -418,18 +418,7 @@ cortex_credentials <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT"),
if (nchar(user) != 0 && nchar(private_key) != 0) {
check_installed("jose", "for key-pair authentication")
key <- openssl::read_key(private_key)
# We can't use openssl::fingerprint() here because it uses a different
# algorithm.
fp <- openssl::base64_encode(
openssl::sha256(openssl::write_der(key$pubkey))
)
sub <- toupper(paste0(account, ".", user))
iss <- paste0(sub, ".SHA256:", fp)
# Note: Snowflake employs a malformed issuer claim, so we have to inject it
# manually after jose's validation phase.
claim <- httr2::jwt_claim("dummy", sub)
claim$iss <- iss
token <- httr2::jwt_encode_sig(claim, key)
token <- snowflake_keypair_token(account, user, key)
return(
list(
Authorization = paste("Bearer", token),
Expand Down
35 changes: 35 additions & 0 deletions R/provider-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,38 @@ snowflake_user_agent <- function() {
}
user_agent
}

snowflake_keypair_token <- function(
account,
user,
key,
cache = snowflake_keypair_cache(account, key),
lifetime = 600L,
reauth = FALSE
) {
# Producing a signed JWT is a fairly expensive operation (in the order of
# ~10ms), but adding a cache speeds this up approximately 500x.
creds <- cache$get()
if (reauth || is.null(creds) || creds$expiry < Sys.time()) {
cache$clear()
expiry <- Sys.time() + lifetime
# We can't use openssl::fingerprint() here because it uses a different
# algorithm.
fp <- openssl::base64_encode(
openssl::sha256(openssl::write_der(key$pubkey))
)
sub <- toupper(paste0(account, ".", user))
iss <- paste0(sub, ".SHA256:", fp)
# Note: Snowflake employs a malformed issuer claim, so we have to inject it
# manually after jose's validation phase.
claim <- jwt_claim("dummy", sub, exp = as.integer(expiry))
claim$iss <- iss
creds <- list(expiry = expiry, token = jwt_encode_sig(claim, key))
cache$set(creds)
}
creds$token
}

snowflake_keypair_cache <- function(account, key) {
credentials_cache(key = hash(c("sf", account, openssl::fingerprint(key))))
}
9 changes: 9 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,12 @@ has_credentials <- function(provider) {
)

}

# In-memory cache for credentials. Analogous to httr2:::cache_mem().
credentials_cache <- function(key) {
list(
get = function() env_get(the$credentials_cache, key, default = NULL),
set = function(creds) env_poke(the$credentials_cache, key, creds),
clear = function() env_unbind(the$credentials_cache, key)
)
}
69 changes: 69 additions & 0 deletions tests/testthat/test-provider-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,72 @@ test_that("can use images", {
# test_images_inline(chat_snowflake)
# test_images_remote(chat_snowflake)
})

# Auth --------------------------------------------------------------------

test_that("Snowflake keypair token caching works as expected", {
skip_if_not_installed("jose")

# Random RSA key for testing.
testkey <- openssl::read_key(
"-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCbxG4OC5HU9QlK
dmtbQCa7r+uoKyDSisxqJQchfkDy64v6V6WsovI8evUGPQpkAbqsmXY3DR3T/Mco
P2oHyzGsfd2t7v6NLHNtGbMiEJYjVJvOw52Yn1m4WH5bEtl5JP/8W2qyTdr6qym+
m47X8hAqb+ToQjnolRq9xlme9n6vhwGi8mlco5POLCcEDhcYiqcPxI/WRqHDcdi8
/nU1eRhGTSe77NUnC0QQOojjRZ3P59NuA7zpgFdMLdE0I5qrfL3e6SQlnmTizdng
qiZUAI5p1ISZffn1FLf9GZlnD/usG0Dbp2MDGbYMhx8a5ii0RGyEINYrATdxkKaW
/AKGeKbdAgMBAAECggEAH+a022+HKGQeyP9DsWaMCDhZPRHIIRaIEt0Ofs+KobWX
72dv6NFeZwCPmf16WUz5XEv5qACpsTa92wJRxtLYk4kbk3m07FjEMv3mb/2Roh67
4jax2gYYq+aDykcr/uGTA639RhMn29qeLAlT0eojYW2VJfQaRAX1ehRbWnEFNRFR
Pyy1pBOCReDG0yyw1OtDv85H09UdiRWyNC7HkxdaMZns//GOQ3MwuZL61st7Aezg
Xz+mGw7v+SEIL0zk92GSIOHA0TXiUAhIWGxIyqSeNqw+Cl0+4r6ZuT+z2lILPR9C
UPVMtXUzUBhBPhtvPpq2RoRqcHzXWsdUcfteKyN/iQKBgQC5JvzZOwOnBcqaTUpn
ykrYwyiAOk0h3uOs4Mrs7A40xWmQ35VOb1gWVnTgvC91SBBfP/jGf02ZdLk5NG1/
oe13aKvQ6mh/jTImPLEPxsMm+469+nklitHwF8b6R3zrSPHoqdF4XUOHOcbK5V8W
MgUIIXDGtLCqxTns41VbIM9/5wKBgQDXXvgG5238F1LtHFG0FNZRilRt3d+cO1CU
HctSPGRXVe8ZGEJZ4F/TV6pWEOrdsuk5bp/IoDGKE2b9FI6K9BKy3Xc8qx5Um9zF
q5ca671UmZkcqu8jh99JSn9sKM7PP9QZInhP1eca7J9r2lhROHk0hsyTWtzuVcWO
JttBO0lamwKBgQCJWFGCNxO7h0FGewUxvs8MwqA9loH3GScc69e8LlNPdA2eKSzR
dSkL0PB8cTxnLKDwdzzsyixfJEXuGGUNo6nKxTuHCwufarcDxEu4H0JOnZbCeJX7
cmHPT2QL7pHM21yPscEwH0bjfcloYwPJLCutX1kQHaNb2lfg0LZVlh42iwKBgQCW
3yp0+66qiFRJUitSMb6pRHQ8us8ojMy31d9W7oOEQujJ9ZqVh37ZeHIU9KjzQZ/r
4bkBPGc3yLu+0qXAZZarwkUDNQR8VOtldfzWmQn6t9bwpDX99/LNTujQhg3KVXZp
XSJXGwtYayaK0VxJGXye9UdeeqqGM4O/Py0dF0EdvQKBgDo82ImF2mKzJUEBK33r
uGtR8Fxbg4cNRAc0W6xME86IVTnLnqLp1yeTZZGCFek6hDqERLCbQhQk8t1Szm0V
OdYSh6YfkxhsBGp6hHefOTWuoto4zHZ98uuu0GD8NkzGmnZApZ7It1MiH+SZPG9w
AK4HbizZMWlkvg87OphvnQhC
-----END PRIVATE KEY-----
"
)

token1 <- snowflake_keypair_token("test1", "user", testkey)
token2 <- snowflake_keypair_token("test2", "user", testkey)

# Verify different tokens were returned
expect_false(identical(token1, token2))

# Verify cached tokens match original ones
expect_identical(token1, snowflake_keypair_token("test1", "user", testkey))
expect_identical(token2, snowflake_keypair_token("test2", "user", testkey))

# Simulate a cache entry that has expired
cache <- snowflake_keypair_cache("test1", testkey)
creds_modified <- cache$get()
creds_modified$expiry <- Sys.time() - 5
cache$set(creds_modified)

# Ensure the new token has been updated
expect_false(
identical(
creds_modified,
snowflake_keypair_token("test1", "user", testkey)
)
)
expect_false(
identical(token1, snowflake_keypair_token("test1", "user", testkey))
)
expect_false(
identical(token2, snowflake_keypair_token("test1", "user", testkey))
)
})

0 comments on commit 2bc1ca9

Please sign in to comment.