diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index eb82de2bc..9f52feca5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -6,12 +6,15 @@ from __future__ import annotations as _annotations +import asyncio +import weakref from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Iterator, MutableMapping from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass, field from datetime import datetime from functools import cache +from types import TracebackType import httpx from typing_extensions import Literal, TypeAliasType @@ -506,15 +509,48 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600, @cache def _cached_async_http_client(provider: str | None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient: return httpx.AsyncClient( - transport=_cached_async_http_transport(), + transport=_get_transport_for_loop(), timeout=httpx.Timeout(timeout=timeout, connect=connect), headers={'User-Agent': get_user_agent()}, ) @cache -def _cached_async_http_transport() -> httpx.AsyncHTTPTransport: - return httpx.AsyncHTTPTransport() +def _get_transport_for_loop() -> _PerLoopTransport: + return _PerLoopTransport() + + +class _PerLoopTransport(httpx.AsyncBaseTransport): + def __init__(self): + self.transports: MutableMapping[asyncio.AbstractEventLoop, httpx.AsyncHTTPTransport] = ( + weakref.WeakKeyDictionary() + ) + + def get_transport(self) -> httpx.AsyncHTTPTransport: + # Clean the dictionary of closed loops + for loop in list(self.transports.keys()): + if loop.is_closed(): + del self.transports[loop] + + return self.transports.setdefault(asyncio.get_running_loop(), httpx.AsyncHTTPTransport()) + + async def __aenter__(self): + await self.get_transport().__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + await self.get_transport().__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + return await self.get_transport().handle_async_request(request) + + async def aclose(self) -> None: + await self.get_transport().aclose() @cache