Skip to content

Commit

Permalink
Improve support for temporary IAM credentials in chat_bedrock().
Browse files Browse the repository at this point in the history
Many AWS IAM credentials expire, but previously we ignored this by
looking up credentials only once, in `chat_bedrock()`. This commit
introduces a caching layer that handles expiry and moves credential
retrievable closer to request time, instead.

The design is almost identical to httr2's OAuth token caching mechanism,
but I had to re-implement various pieces because not all of that API is
exported.

(We could probably introduce a `req_aws_credentials()` function to
`httr2` itself that would handle this, but that might tie us too closely
to the semantics of `paws.common`.)

Unit tests are included.

Closes #261.

Signed-off-by: Aaron Jacobs <[email protected]>
  • Loading branch information
atheriel committed Jan 23, 2025
1 parent 7a4855e commit 6037b07
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 14 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
API](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api)
(#258, @atheriel).

* `chat_bedrock()` now handles temporary IAM credentials better (#261, @atheriel).

# ellmer 0.1.0

* New `chat_vllm()` to chat with models served by vLLM (#140).
Expand Down
55 changes: 41 additions & 14 deletions R/provider-bedrock.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ chat_bedrock <- function(system_prompt = NULL,
echo = NULL) {

check_installed("paws.common", "AWS authentication")
credentials <- paws_credentials(profile)
cache <- aws_creds_cache(profile)
credentials <- paws_credentials(profile, cache = cache)

turns <- normalize_turns(turns, system_prompt)
model <- set_default(model, "anthropic.claude-3-5-sonnet-20240620-v1:0")
Expand All @@ -43,7 +44,8 @@ chat_bedrock <- function(system_prompt = NULL,
base_url = "",
model = model,
profile = profile,
credentials = credentials
region = credentials$region,
cache = cache
)

Chat$new(provider = provider, turns = turns, echo = echo)
Expand All @@ -55,7 +57,8 @@ ProviderBedrock <- new_class(
properties = list(
model = prop_string(),
profile = prop_string(allow_null = TRUE),
credentials = class_list
region = prop_string(),
cache = class_list
)
)

Expand All @@ -67,19 +70,20 @@ method(chat_request, ProviderBedrock) <- function(provider,
extra_args = list()) {

req <- request(paste0(
"https://bedrock-runtime.", provider@credentials$region, ".amazonaws.com"
"https://bedrock-runtime.", provider@region, ".amazonaws.com"
))
req <- req_url_path_append(
req,
"model",
provider@model,
if (stream) "converse-stream" else "converse"
)
creds <- paws_credentials(provider@profile, provider@cache)
req <- req_auth_aws_v4(
req,
aws_access_key_id = provider@credentials$access_key_id,
aws_secret_access_key = provider@credentials$secret_access_key,
aws_session_token = provider@credentials$session_token
aws_access_key_id = creds$access_key_id,
aws_secret_access_key = creds$secret_access_key,
aws_session_token = creds$session_token
)

req <- req_error(req, body = function(resp) {
Expand Down Expand Up @@ -295,15 +299,38 @@ method(as_json, list(ProviderBedrock, ToolDef)) <- function(provider, x) {

# Helpers ----------------------------------------------------------------

paws_credentials <- function(profile) {
if (is_testing()) {
tryCatch(
paws.common::locate_credentials(profile),
paws_credentials <- function(profile, cache = aws_creds_cache(profile),
reauth = FALSE) {
creds <- cache$get()
if (reauth || is.null(creds) || creds$expiration < Sys.time()) {
cache$clear()
try_fetch(
creds <- locate_aws_credentials(profile),
error = function(cnd) {
testthat::skip("Failed to locate AWS credentails")
if (is_testing()) {
testthat::skip("Failed to locate AWS credentails")
}
cli::cli_abort("No IAM credentials found.", parent = cnd)
}
)
} else {
paws.common::locate_credentials(profile)
cache$set(creds)
}
creds
}

# Wrapper for paws.common::locate_credentials() so we can mock it in tests.
locate_aws_credentials <- function(profile) {
paws.common::locate_credentials(profile)
}

# In-memory cache for AWS credentials. Analogous to httr2:::cache_mem().
aws_creds_cache <- function(profile) {
key <- hash(profile)
list(
get = function() env_get(the$aws_credentials_cache, key, default = NULL),
set = function(creds) env_poke(the$aws_credentials_cache, key, creds),
clear = function() env_unbind(the$aws_credentials_cache, key)
)
}

the$aws_credentials_cache <- new_environment()
45 changes: 45 additions & 0 deletions tests/testthat/test-provider-bedrock.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,48 @@ test_that("can use images", {
test_images_inline(chat_fun)
test_images_remote_error(chat_fun)
})

# Auth --------------------------------------------------------------------

test_that("AWS credential caching works as expected", {
# Mock AWS credentials for different profiles.
local_mocked_bindings(
locate_aws_credentials = function(profile) {
if (!is.null(profile) && profile == "test") {
list(
access_key = "key1",
secret_key = "secret1",
expiration = Sys.time() + 3600
)
} else {
list(
access_key = "key2",
secret_key = "secret2",
expiration = Sys.time() + 3600
)
}
}
)

creds1 <- paws_credentials(profile = "test")
creds2 <- paws_credentials(profile = NULL)

# Verify different credentials were returned.
expect_false(identical(creds1, creds2))
expect_equal(creds1$access_key, "key1")
expect_equal(creds2$access_key, "key2")

# Verify cached credentials match original ones.
expect_identical(creds1, paws_credentials(profile = "test"))
expect_identical(creds2, paws_credentials(profile = NULL))

# Simulate a cache entry that has expired.
creds_modified <- creds1
creds_modified$expiration <- Sys.time() - 5
aws_creds_cache(profile = "test")$set(creds_modified)

# Ensure the new credentials have been updated.
expect_false(identical(creds_modified, paws_credentials(profile = "test")))
expect_false(identical(creds1, paws_credentials(profile = "test")))
expect_false(identical(creds2, paws_credentials(profile = "test")))
})

0 comments on commit 6037b07

Please sign in to comment.