From 30518d91b7ea4ad3376c5f1f4f1c4f9fa6782672 Mon Sep 17 00:00:00 2001 From: ravargas Date: Thu, 1 May 2025 19:04:10 +0200 Subject: [PATCH] feat: add Cloudflare provider --- .../pydantic_ai/providers/cloudflare.py | 64 +++++++++++++++++++ pydantic_ai_slim/pyproject.toml | 1 + tests/providers/test_cloudflare.py | 41 ++++++++++++ uv.lock | 23 ++++++- 4 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 pydantic_ai_slim/pydantic_ai/providers/cloudflare.py create mode 100644 tests/providers/test_cloudflare.py diff --git a/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py new file mode 100644 index 000000000..bd523dba0 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py @@ -0,0 +1,64 @@ +from __future__ import annotations as _annotations + +import os + +from httpx import AsyncClient as AsyncHTTPClient + +from pydantic_ai.exceptions import UserError +from pydantic_ai.models import cached_async_http_client +from pydantic_ai.providers import Provider + +try: + from cloudflare import AsyncCloudflare +except ImportError as _import_error: + raise ImportError( + 'Please install the `cloudflare` package to use the Cloudflare provider, ' + 'you can use the `cloudflare` optional group — `pip install "pydantic-ai-slim[cloudflare]"`' + ) from _import_error + + +class CloudflareProvider(Provider[AsyncCloudflare]): + """Provider for Cloudflare Workers AI.""" + + @property + def name(self) -> str: + return 'cloudflare' + + @property + def base_url(self) -> str: + return str(self._client.base_url) + + @property + def client(self) -> AsyncCloudflare: + return self._client + + def __init__( + self, + *, + api_key: str | None = None, + cloudflare_client: AsyncCloudflare | None = None, + http_client: AsyncHTTPClient | None = None, + ) -> None: + """Create a new Cloudflare provider. + + Args: + api_key: The API key to use for authentication, if not provided, the `CLOUDFLARE_API_KEY` env var is used. + account_id: Cloudflare account ID, or set via `CLOUDFLARE_ACCOUNT_ID` env var. + cloudflare_client: Pre-existing `AsyncCloudflare` client instance. + http_client: Optional custom `httpx.AsyncClient`. + """ + if cloudflare_client is not None: + assert api_key is None, 'Cannot provide both `cloudflare_client` and `api_key`' + assert http_client is None, 'Cannot provide both `cloudflare_client` and `http_client`' + self._client = cloudflare_client + self.account_id = '' # replace with real extraction if possible + else: + api_key = api_key or os.getenv('CLOUDFLARE_API_KEY') + + if not api_key: + raise UserError( + 'Set the `CLOUDFLARE_API_KEY` environment variable or pass it via `CloudflareProvider(api_key=...)`' + ) + + http_client = http_client or cached_async_http_client(provider='cloudflare') + self._client = AsyncCloudflare(api_key=api_key, http_client=http_client) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index e1db8a398..0eb2fa444 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -63,6 +63,7 @@ anthropic = ["anthropic>=0.49.0"] groq = ["groq>=0.15.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.35.74"] +cloudflare = ["cloudflare>=4.1.0"] # Tools duckduckgo = ["duckduckgo-search>=7.0.0"] tavily = ["tavily-python>=0.5.0"] diff --git a/tests/providers/test_cloudflare.py b/tests/providers/test_cloudflare.py new file mode 100644 index 000000000..b4d6fc056 --- /dev/null +++ b/tests/providers/test_cloudflare.py @@ -0,0 +1,41 @@ +from __future__ import annotations as _annotations + +import httpx +import pytest + +from pydantic_ai.exceptions import UserError + +from ..conftest import TestEnv, try_import + +with try_import() as imports_successful: + from cloudflare import AsyncCloudflare + + from pydantic_ai.providers.cloudflare import CloudflareProvider + + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='cloudflare not installed') + + +def test_cloudflare_provider() -> None: + provider = CloudflareProvider(api_key='api-key') + assert provider.name == 'cloudflare' + assert isinstance(provider.client, AsyncCloudflare) + assert provider.base_url.startswith('https://api.cloudflare.com') + + +def test_cloudflare_provider_need_api_key(env: TestEnv) -> None: + env.remove('CLOUDFLARE_API_KEY') + with pytest.raises(UserError, match='CLOUDFLARE_API_KEY'): + CloudflareProvider() + + +def test_cloudflare_provider_pass_http_client() -> None: + http_client = httpx.AsyncClient() + provider = CloudflareProvider(api_key='api-key', http_client=http_client) + assert isinstance(provider.client, AsyncCloudflare) + + +def test_cloudflare_provider_pass_client() -> None: + cloudflare_client = AsyncCloudflare(api_key='test-api-key') + provider = CloudflareProvider(cloudflare_client=cloudflare_client) + assert provider.client == cloudflare_client diff --git a/uv.lock b/uv.lock index 227c424b5..a0cf8bad9 100644 --- a/uv.lock +++ b/uv.lock @@ -686,6 +686,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, ] +[[package]] +name = "cloudflare" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/14/3c/52a4246ec925bf0133d89a151a0258b40a407a576538b648c2afa210831e/cloudflare-4.1.0.tar.gz", hash = "sha256:6b9fbe994856fe942adc6a4881b7beded208034dc5c53dd683751f632e47a1ac", size = 1826873 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/a4/a8087940e6c4b8c518fde198e5646ba86591c0811b4e7a71750568a0cf08/cloudflare-4.1.0-py3-none-any.whl", hash = "sha256:ec6db9c04a8835440dcb475d9f0ac48771b4db724cc14bb1a5d744b73993f262", size = 4089940 }, +] + [[package]] name = "cohere" version = "5.13.12" @@ -2917,6 +2934,9 @@ cli = [ { name = "prompt-toolkit" }, { name = "rich" }, ] +cloudflare = [ + { name = "cloudflare" }, +] cohere = [ { name = "cohere", marker = "sys_platform != 'emscripten'" }, ] @@ -2971,6 +2991,7 @@ requires-dist = [ { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.35.74" }, + { name = "cloudflare", marker = "extra == 'cloudflare'", specifier = ">=4.1.0" }, { name = "cohere", marker = "sys_platform != 'emscripten' and extra == 'cohere'", specifier = ">=5.13.11" }, { name = "duckduckgo-search", marker = "extra == 'duckduckgo'", specifier = ">=7.0.0" }, { name = "eval-type-backport", specifier = ">=0.2.0" }, @@ -2993,7 +3014,7 @@ requires-dist = [ { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["anthropic", "bedrock", "cli", "cloudflare", "cohere", "duckduckgo", "evals", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] [package.metadata.requires-dev] dev = [