diff --git a/R/provider-azure.R b/R/provider-azure.R index 77d784b..6099011 100644 --- a/R/provider-azure.R +++ b/R/provider-azure.R @@ -48,8 +48,9 @@ chat_azure <- function(endpoint = azure_endpoint(), check_string(deployment_id) api_version <- set_default(api_version, "2024-06-01") turns <- normalize_turns(turns, system_prompt) - check_exclusive(api_key, token, credentials, .require = FALSE) + check_exclusive(token, credentials, .require = FALSE) check_string(api_key, allow_null = TRUE) + api_key <- api_key %||% Sys.getenv("AZURE_OPENAI_API_KEY") check_string(token, allow_null = TRUE) echo <- check_echo(echo) if (is_list(credentials)) { @@ -63,6 +64,7 @@ chat_azure <- function(endpoint = azure_endpoint(), endpoint = endpoint, deployment_id = deployment_id, api_version = api_version, + api_key = api_key, credentials = credentials, extra_args = api_args ) @@ -72,13 +74,13 @@ chat_azure <- function(endpoint = azure_endpoint(), ProviderAzure <- new_class( "ProviderAzure", parent = ProviderOpenAI, - constructor = function(endpoint, deployment_id, api_version, credentials, - extra_args = list()) { + constructor = function(endpoint, deployment_id, api_version, api_key, + credentials, extra_args = list()) { new_object( ProviderOpenAI( base_url = paste0(endpoint, "/openai/deployments/", deployment_id), model = deployment_id, - api_key = "", + api_key = api_key, extra_args = extra_args ), api_version = api_version, @@ -86,7 +88,7 @@ ProviderAzure <- new_class( ) }, properties = list( - credentials = class_function | NULL, + credentials = class_function, api_version = prop_string() ) ) @@ -109,11 +111,10 @@ method(chat_request, ProviderAzure) <- function(provider, req <- req_url_query(req, `api-version` = provider@api_version) # Note: could use req_headers_redacted() here but it requires a very new # httr2 version. - req <- req_headers( - req, - !!!provider@credentials(), - .redact = c("api-key", "Authorization") - ) + if (nchar(provider@api_key)) { + req <- req_headers(req, `api-key` = provider@api_key, .redact = "api-key") + } + req <- req_headers(req, !!!provider@credentials(), .redact = "Authorization") req <- req_retry(req, max_tries = 2) req <- req_error(req, body = function(resp) resp_body_json(resp)$message) @@ -150,15 +151,15 @@ method(chat_request, ProviderAzure) <- function(provider, } default_azure_credentials <- function(api_key = NULL, token = NULL) { - api_key <- api_key %||% Sys.getenv("AZURE_OPENAI_API_KEY") - if (nchar(api_key)) { - return(function() list(`api-key` = api_key)) - } - if (!is.null(token)) { return(function() list(Authorization = paste("Bearer", token))) } + # If we have an API key, rely on that for credentials. + if (nchar(api_key)) { + return(function() list()) + } + if (is_testing()) { testthat::skip("no Azure credentials available") } diff --git a/tests/testthat/_snaps/provider-azure.md b/tests/testthat/_snaps/provider-azure.md index 7a12aae..902ee56 100644 --- a/tests/testthat/_snaps/provider-azure.md +++ b/tests/testthat/_snaps/provider-azure.md @@ -30,3 +30,20 @@ * retry_on_failure: FALSE * error_body: a function +--- + + Code + req + Message + + POST + https://ai-hwickhamai260967855527.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-06-01 + Headers: + * api-key: '' + * Authorization: '' + Body: json encoded data + Policies: + * retry_max_tries: 2 + * retry_on_failure: FALSE + * error_body: a function + diff --git a/tests/testthat/test-provider-azure.R b/tests/testthat/test-provider-azure.R index ed04b5a..049d56f 100644 --- a/tests/testthat/test-provider-azure.R +++ b/tests/testthat/test-provider-azure.R @@ -26,6 +26,7 @@ test_that("Azure request headers are generated correctly", { endpoint = endpoint, deployment_id = deployment_id, api_version = "2024-06-01", + api_key = "key", credentials = default_azure_credentials("key") ) req <- chat_request(p, FALSE, list(turn)) @@ -36,8 +37,20 @@ test_that("Azure request headers are generated correctly", { endpoint = endpoint, deployment_id = deployment_id, api_version = "2024-06-01", + api_key = "", credentials = default_azure_credentials("", "token") ) req <- chat_request(p, FALSE, list(turn)) expect_snapshot(req) + + # Both. + p <- ProviderAzure( + endpoint = endpoint, + deployment_id = deployment_id, + api_version = "2024-06-01", + api_key = "key", + credentials = default_azure_credentials("key", "token") + ) + req <- chat_request(p, FALSE, list(turn)) + expect_snapshot(req) })