Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for token object for ProviderAzure. Fixes #195 #196

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions R/provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ NULL
#' @param api_key The API key to use for authentication. You generally should
#' not supply this directly, but instead set the `AZURE_OPENAI_API_KEY` environment
#' variable.
#' @param token Azure token for authentication. This is typically not required for
#' Azure OpenAI API calls, but can be used if your setup requires it.
#' @param azure_token token object of class AzureToken for authentication. This is typically not required for
#' Azure OpenAI API calls, but can be used if your setup requires it. The azure_token object is retrieved using the AzureAuth package.#' Using the azure_token object ensures a refresh method is available for the token.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @export
Expand All @@ -30,7 +30,7 @@ chat_azure <- function(endpoint = azure_endpoint(),
system_prompt = NULL,
turns = NULL,
api_key = azure_key(),
token = NULL,
azure_token = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
check_string(endpoint)
Expand All @@ -41,12 +41,29 @@ chat_azure <- function(endpoint = azure_endpoint(),

base_url <- paste0(endpoint, "/openai/deployments/", deployment_id)

if(is.null(azure_token)){
access_token = azure_token
} else if(is_azure_token(azure_token)) {
# uses the token object method to validate the azure_token (for example if it expired)
valid = azure_token$validate()
if(!valid ) {
# uses the token object method to refresh the azure_token
azure_token = azure_token$refresh()
}
# retrieves the actual access token from the azure_token object for further use.
access_token = azure_token$credentials$access_token

} else {
cli::cli_abort("azure_token must be of class <AzureToken> or NULL. Please consider using the AzureAuth package to create a token object.")
return()
}

provider <- ProviderAzure(
base_url = base_url,
endpoint = endpoint,
model = deployment_id,
api_version = api_version,
token = token,
access_token = access_token,
extra_args = api_args,
api_key = api_key
)
Expand All @@ -58,7 +75,7 @@ ProviderAzure <- new_class(
parent = ProviderOpenAI,
properties = list(
api_key = prop_string(),
token = prop_string(allow_null = TRUE),
access_token = prop_string(allow_null = TRUE),
endpoint = prop_string(),
api_version = prop_string()
)
Expand All @@ -85,8 +102,8 @@ method(chat_request, ProviderAzure) <- function(provider,
req <- req_url_path_append(req, "/chat/completions")
req <- req_url_query(req, `api-version` = provider@api_version)
req <- req_headers(req, `api-key` = provider@api_key, .redact = "api-key")
if (!is.null(provider@token)) {
req <- req_auth_bearer_token(req, provider@token)
if (!is.null(provider@access_token)) {
req <- req_auth_bearer_token(req, provider@access_token)
}
req <- req_retry(req, max_tries = 2)
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)
Expand Down
1 change: 1 addition & 0 deletions R/utils-S7.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,4 @@ prop_number_whole <- function(default = NULL, min = NULL, max = NULL, allow_null
}
)
}

8 changes: 8 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,11 @@ dots_named <- function(...) {
x[[length(x) + 1]] <- value
x
}

is_azure_token <- function (object)
{
R6::is.R6(object) && inherits(object, "AzureToken")
}



6 changes: 3 additions & 3 deletions man/chat_azure.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading