diff --git a/NEWS.md b/NEWS.md index ba7bf65..8ce7328 100644 --- a/NEWS.md +++ b/NEWS.md @@ -20,6 +20,8 @@ API](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api) (#258, @atheriel). +* `chat_bedrock()` now handles temporary IAM credentials better (#261, @atheriel). + # ellmer 0.1.0 * New `chat_vllm()` to chat with models served by vLLM (#140). diff --git a/R/provider-bedrock.R b/R/provider-bedrock.R index 81072b7..313e341 100644 --- a/R/provider-bedrock.R +++ b/R/provider-bedrock.R @@ -33,7 +33,8 @@ chat_bedrock <- function(system_prompt = NULL, echo = NULL) { check_installed("paws.common", "AWS authentication") - credentials <- paws_credentials(profile) + cache <- aws_creds_cache(profile) + credentials <- paws_credentials(profile, cache = cache) turns <- normalize_turns(turns, system_prompt) model <- set_default(model, "anthropic.claude-3-5-sonnet-20240620-v1:0") @@ -43,7 +44,8 @@ chat_bedrock <- function(system_prompt = NULL, base_url = "", model = model, profile = profile, - credentials = credentials + region = credentials$region, + cache = cache ) Chat$new(provider = provider, turns = turns, echo = echo) @@ -55,7 +57,8 @@ ProviderBedrock <- new_class( properties = list( model = prop_string(), profile = prop_string(allow_null = TRUE), - credentials = class_list + region = prop_string(), + cache = class_list ) ) @@ -67,7 +70,7 @@ method(chat_request, ProviderBedrock) <- function(provider, extra_args = list()) { req <- request(paste0( - "https://bedrock-runtime.", provider@credentials$region, ".amazonaws.com" + "https://bedrock-runtime.", provider@region, ".amazonaws.com" )) req <- req_url_path_append( req, @@ -75,11 +78,12 @@ method(chat_request, ProviderBedrock) <- function(provider, provider@model, if (stream) "converse-stream" else "converse" ) + creds <- paws_credentials(provider@profile, provider@cache) req <- req_auth_aws_v4( req, - aws_access_key_id = provider@credentials$access_key_id, - aws_secret_access_key = provider@credentials$secret_access_key, - aws_session_token = provider@credentials$session_token + aws_access_key_id = creds$access_key_id, + aws_secret_access_key = creds$secret_access_key, + aws_session_token = creds$session_token ) req <- req_error(req, body = function(resp) { @@ -295,15 +299,38 @@ method(as_json, list(ProviderBedrock, ToolDef)) <- function(provider, x) { # Helpers ---------------------------------------------------------------- -paws_credentials <- function(profile) { - if (is_testing()) { - tryCatch( - paws.common::locate_credentials(profile), +paws_credentials <- function(profile, cache = aws_creds_cache(profile), + reauth = FALSE) { + creds <- cache$get() + if (reauth || is.null(creds) || creds$expiration < Sys.time()) { + cache$clear() + try_fetch( + creds <- locate_aws_credentials(profile), error = function(cnd) { - testthat::skip("Failed to locate AWS credentails") + if (is_testing()) { + testthat::skip("Failed to locate AWS credentails") + } + cli::cli_abort("No IAM credentials found.", parent = cnd) } ) - } else { - paws.common::locate_credentials(profile) + cache$set(creds) } + creds +} + +# Wrapper for paws.common::locate_credentials() so we can mock it in tests. +locate_aws_credentials <- function(profile) { + paws.common::locate_credentials(profile) +} + +# In-memory cache for AWS credentials. Analogous to httr2:::cache_mem(). +aws_creds_cache <- function(profile) { + key <- hash(profile) + list( + get = function() env_get(the$aws_credentials_cache, key, default = NULL), + set = function(creds) env_poke(the$aws_credentials_cache, key, creds), + clear = function() env_unbind(the$aws_credentials_cache, key) + ) } + +the$aws_credentials_cache <- new_environment() diff --git a/tests/testthat/test-provider-bedrock.R b/tests/testthat/test-provider-bedrock.R index ef68987..9f208da 100644 --- a/tests/testthat/test-provider-bedrock.R +++ b/tests/testthat/test-provider-bedrock.R @@ -45,3 +45,48 @@ test_that("can use images", { test_images_inline(chat_fun) test_images_remote_error(chat_fun) }) + +# Auth -------------------------------------------------------------------- + +test_that("AWS credential caching works as expected", { + # Mock AWS credentials for different profiles. + local_mocked_bindings( + locate_aws_credentials = function(profile) { + if (!is.null(profile) && profile == "test") { + list( + access_key = "key1", + secret_key = "secret1", + expiration = Sys.time() + 3600 + ) + } else { + list( + access_key = "key2", + secret_key = "secret2", + expiration = Sys.time() + 3600 + ) + } + } + ) + + creds1 <- paws_credentials(profile = "test") + creds2 <- paws_credentials(profile = NULL) + + # Verify different credentials were returned. + expect_false(identical(creds1, creds2)) + expect_equal(creds1$access_key, "key1") + expect_equal(creds2$access_key, "key2") + + # Verify cached credentials match original ones. + expect_identical(creds1, paws_credentials(profile = "test")) + expect_identical(creds2, paws_credentials(profile = NULL)) + + # Simulate a cache entry that has expired. + creds_modified <- creds1 + creds_modified$expiration <- Sys.time() - 5 + aws_creds_cache(profile = "test")$set(creds_modified) + + # Ensure the new credentials have been updated. + expect_false(identical(creds_modified, paws_credentials(profile = "test"))) + expect_false(identical(creds1, paws_credentials(profile = "test"))) + expect_false(identical(creds2, paws_credentials(profile = "test"))) +})