From 7ed6aec4b02eb9b760b07bcfd01980ab1aad4820 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Mon, 27 Jan 2025 11:33:29 -0600 Subject: [PATCH] Bump Azure API version (#272) So it now supports structured data extraction (Fixes #271). And has tests. --- NEWS.md | 2 + R/provider-azure.R | 17 +++++++- tests/testthat/_snaps/provider-azure.md | 16 +++++++ tests/testthat/test-provider-azure.R | 58 ++++++++++++++++++++----- 4 files changed, 81 insertions(+), 12 deletions(-) diff --git a/NEWS.md b/NEWS.md index 2930ffc..685391d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # ellmer (development version) +* `chat_azure()` now defaults `api_version = "2024-10-21"` which includes data for structured data extraction (#271). + * `chat_openrouter()` provides support for OpenRouter models (#212) * `chat_deepseek()` provides support for DeepSeek models (#242) diff --git a/R/provider-azure.R b/R/provider-azure.R index 4ac33df..de6591f 100644 --- a/R/provider-azure.R +++ b/R/provider-azure.R @@ -71,14 +71,15 @@ chat_azure <- function(endpoint = azure_endpoint(), } check_string(endpoint) check_string(deployment_id) - api_version <- set_default(api_version, "2024-06-01") + api_version <- set_default(api_version, "2024-10-21") turns <- normalize_turns(turns, system_prompt) check_exclusive(token, credentials, .require = FALSE) check_string(api_key, allow_null = TRUE) api_key <- api_key %||% Sys.getenv("AZURE_OPENAI_API_KEY") check_string(token, allow_null = TRUE) echo <- check_echo(echo) - if (is_list(credentials)) { + + if (is_list(credentials)) { static_credentials <- force(credentials) credentials <- function() static_credentials } @@ -96,6 +97,18 @@ chat_azure <- function(endpoint = azure_endpoint(), Chat$new(provider = provider, turns = turns, echo = echo) } +chat_azure_test <- function(system_prompt = NULL, ...) { + api_key <- key_get("AZURE_OPENAI_API_KEY") + + chat_azure( + ..., + system_prompt = system_prompt, + api_key = api_key, + endpoint = "https://ai-hwickhamai260967855527.openai.azure.com", + deployment_id = "gpt-4o-mini" + ) +} + ProviderAzure <- new_class( "ProviderAzure", parent = ProviderOpenAI, diff --git a/tests/testthat/_snaps/provider-azure.md b/tests/testthat/_snaps/provider-azure.md index 7bf340c..437ad2f 100644 --- a/tests/testthat/_snaps/provider-azure.md +++ b/tests/testthat/_snaps/provider-azure.md @@ -1,3 +1,19 @@ +# defaults are reported + + Code + . <- chat_azure_test() + Message + Using api_version = "2024-10-21". + +# all tool variations work + + Code + chat$chat("Great. Do it again.") + Condition + Error in `FUN()`: + ! Can't use async tools with `$chat()` or `$stream()`. + i Async tools are supported, but you must use `$chat_async()` or `$stream_async()`. + # Azure request headers are generated correctly Code diff --git a/tests/testthat/test-provider-azure.R b/tests/testthat/test-provider-azure.R index 7df3c84..5ac694b 100644 --- a/tests/testthat/test-provider-azure.R +++ b/tests/testthat/test-provider-azure.R @@ -1,18 +1,56 @@ +# Getting started -------------------------------------------------------- + test_that("can make simple request", { - chat <- chat_azure( - system_prompt = "Be as terse as possible; no punctuation", - endpoint = "https://ai-hwickhamai260967855527.openai.azure.com", - deployment_id = "gpt-4o-mini" - ) - resp <- chat$chat("What is 1 + 1?") + chat <- chat_azure_test("Be as terse as possible; no punctuation") + resp <- chat$chat("What is 1 + 1?", echo = FALSE) expect_match(resp, "2") - expect_equal(chat$last_turn()@tokens, c(27, 1)) + expect_true(all(chat$last_turn()@tokens >= 1)) +}) - resp <- sync(chat$chat_async("What is 1 + 1?")) - expect_match(resp, "2") - expect_equal(chat$last_turn()@tokens, c(44, 1)) +test_that("can make simple streaming request", { + chat <- chat_azure_test("Be as terse as possible; no punctuation") + resp <- coro::collect(chat$stream("What is 1 + 1?")) + expect_match(paste0(unlist(resp), collapse = ""), "2") +}) + +# Common provider interface ----------------------------------------------- + +test_that("defaults are reported", { + expect_snapshot(. <- chat_azure_test()) +}) + +test_that("respects turns interface", { + chat_fun <- chat_azure_test + + test_turns_system(chat_fun) + test_turns_existing(chat_fun) }) +test_that("all tool variations work", { + chat_fun <- chat_azure_test + + test_tools_simple(chat_fun) + test_tools_async(chat_fun) + test_tools_parallel(chat_fun) + test_tools_sequential(chat_fun, total_calls = 6) +}) + +test_that("can extract data", { + chat_fun <- chat_azure_test + + test_data_extraction(chat_fun) +}) + +test_that("can use images", { + skip("Run manually; 24 hour rate limit") + chat_fun <- chat_azure_test + + httr2::with_verbosity(test_images_inline(chat_fun), 2) + test_images_remote(chat_fun) +}) + +# Authentication -------------------------------------------------------------- + test_that("Azure request headers are generated correctly", { turn <- Turn( role = "user",