diff --git a/src/guidellm/__init__.py b/src/guidellm/__init__.py
index e5620188..b10b4455 100644
--- a/src/guidellm/__init__.py
+++ b/src/guidellm/__init__.py
@@ -6,6 +6,7 @@
 # flake8: noqa
 
 import os
+
 import transformers  # type: ignore
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Silence warnings for tokenizers
diff --git a/src/guidellm/backend/__init__.py b/src/guidellm/backend/__init__.py
index 875e319e..13910180 100644
--- a/src/guidellm/backend/__init__.py
+++ b/src/guidellm/backend/__init__.py
@@ -1,5 +1,6 @@
 from .base import Backend, BackendEngine, BackendEnginePublic, GenerativeResponse
 from .openai import OpenAIBackend
+from .aiohttp import AiohttpBackend
 
 __all__ = [
     "Backend",
@@ -7,4 +8,5 @@
     "BackendEnginePublic",
     "GenerativeResponse",
     "OpenAIBackend",
+    "AiohttpBackend",
 ]
diff --git a/src/guidellm/backend/aiohttp.py b/src/guidellm/backend/aiohttp.py
new file mode 100644
index 00000000..fbbd9715
--- /dev/null
+++ b/src/guidellm/backend/aiohttp.py
@@ -0,0 +1,180 @@
+import base64
+import io
+import json
+from typing import AsyncGenerator, Dict, List, Optional
+
+import aiohttp
+from loguru import logger
+
+from guidellm.backend.base import Backend, GenerativeResponse
+from guidellm.config import settings
+from guidellm.core import TextGenerationRequest
+
+__all__ = ["AiohttpBackend"]
+
+@Backend.register("aiohttp_server")
+class AiohttpBackend(Backend):
+    """
+    An aiohttp-based backend implementation for LLM requests.
+
+    This class provides an interface to communicate with a server hosting
+    an LLM API using aiohttp for asynchronous requests.
+    """
+
+    def __init__(
+        self,
+        openai_api_key: Optional[str] = None,
+        target: Optional[str] = None,
+        model: Optional[str] = None,
+        timeout: Optional[float] = None,
+        **request_args,
+    ):
+        self._request_args: Dict = request_args
+        self._api_key: str = openai_api_key or settings.aiohttp.api_key
+
+        if not self._api_key:
+            err = ValueError(
+                "`GUIDELLM__AIOHTTP__API_KEY` environment variable or "
+                "--openai-api-key CLI parameter must be specified for the "
+                "aiohttp backend."
+            )
+            logger.error("{}", err)
+            raise err
+
+        base_url = target or settings.aiohttp.base_url
+        if not base_url:
+            err = ValueError(
+                "`GUIDELLM__AIOHTTP__BASE_URL` environment variable or "
+                "target parameter must be specified for the aiohttp backend."
+            )
+            logger.error("{}", err)
+            raise err
+
+        self._api_url = f"{base_url}/chat/completions"
+
+        self._timeout = aiohttp.ClientTimeout(total=timeout or settings.request_timeout)
+        self._model = model
+
+        super().__init__(type_="aiohttp_backend", target=base_url, model=self._model)
+        logger.info("aiohttp {} Backend listening on {}", self._model, base_url)
+
+    async def make_request(
+        self,
+        request: TextGenerationRequest,
+    ) -> AsyncGenerator[GenerativeResponse, None]:
+        """
+        Make a request to the aiohttp backend.
+
+        Sends a prompt to the LLM server and streams the response tokens.
+
+        :param request: The text generation request to submit.
+        :type request: TextGenerationRequest
+        :yield: A stream of GenerativeResponse objects.
+        :rtype: AsyncGenerator[GenerativeResponse, None]
+        """
+
+        async with aiohttp.ClientSession(timeout=self._timeout) as session:
+            logger.debug("Making request to aiohttp backend with prompt: {}", request.prompt)
+
+            request_args = {}
+            if request.output_token_count is not None:
+                request_args.update(
+                    {
+                        "max_completion_tokens": request.output_token_count,
+                        "stop": None,
+                        "ignore_eos": True,
+                    }
+                )
+            elif settings.aiohttp.max_gen_tokens and settings.aiohttp.max_gen_tokens > 0:
+                request_args.update(
+                    {
+                        "max_tokens": settings.aiohttp.max_gen_tokens,
+                    }
+                )
+
+            request_args.update(self._request_args)
+
+            messages = self._build_messages(request)
+
+            payload = {
+                "model": self._model,
+                "messages": messages,
+                "stream": True,
+                **request_args,
+            }
+
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self._api_key}",
+            }
+
+            try:
+                async with session.post(url=self._api_url, json=payload, headers=headers) as response:
+                    if response.status != 200:
+                        error_message = await response.text()
+                        logger.error("Request failed: {} - {}", response.status, error_message)
+                        raise Exception(f"Failed to generate response: {error_message}")
+
+                    token_count = 0
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
+                        if chunk == "[DONE]":
+                            # Final response
+                            yield GenerativeResponse(
+                                type_="final",
+                                prompt=request.prompt,
+                                output_token_count=token_count,
+                                prompt_token_count=request.prompt_token_count,
+                            )
+                        else:
+                            # Intermediate token response
+                            token_count += 1
+                            data = json.loads(chunk)
+                            delta = data["choices"][0]["delta"]
+                            token = delta["content"]
+                            yield GenerativeResponse(
+                                type_="token_iter",
+                                add_token=token,
+                                prompt=request.prompt,
+                                output_token_count=token_count,
+                                prompt_token_count=request.prompt_token_count,
+                            )
+            except Exception as e:
+                logger.error("Error while making request: {}", e)
+                raise
+
+    def available_models(self) -> List[str]:
+        """
+        Retrieve a list of available models from the server.
+        """
+        # This could include an API call to `self._api_url/models` if the server supports it.
+        logger.warning("Fetching available models is not implemented for aiohttp backend.")
+        return []
+
+    def validate_connection(self):
+        """
+        Validate the connection to the backend server.
+        """
+        logger.info("Connection validation is not explicitly implemented for aiohttp backend.")
+
+    def _build_messages(self, request: TextGenerationRequest) -> List[Dict]:
+        if request.number_images == 0:
+            messages = [{"role": "user", "content": request.prompt}]
+        else:
+            content = []
+            for image in request.images:
+                stream = io.BytesIO()
+                im_format = image.image.format or "PNG"
+                image.image.save(stream, format=im_format)
+                im_b64 = base64.b64encode(stream.getvalue()).decode("utf-8")
+                image_url = {"url": f"data:image/{im_format.lower()};base64,{im_b64}"}
+                content.append({"type": "image_url", "image_url": image_url})
+
+            content.append({"type": "text", "text": request.prompt})
+            messages = [{"role": "user", "content": content}]
+
+        return messages
diff --git a/src/guidellm/backend/base.py b/src/guidellm/backend/base.py
index d71c5f66..a1658594 100644
--- a/src/guidellm/backend/base.py
+++ b/src/guidellm/backend/base.py
@@ -15,7 +15,7 @@
 
 __all__ = ["Backend", "BackendEngine", "BackendEnginePublic", "GenerativeResponse"]
 
-BackendEnginePublic = Literal["openai_server"]
+BackendEnginePublic = Literal["openai_server", "aiohttp_server"]
 
 BackendEngine = Union[BackendEnginePublic, Literal["test"]]
 
diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py
index 8c83f914..9843fc1a 100644
--- a/src/guidellm/backend/openai.py
+++ b/src/guidellm/backend/openai.py
@@ -1,3 +1,5 @@
+import base64
+import io
 from typing import AsyncGenerator, Dict, List, Optional
 
 from loguru import logger
@@ -92,6 +94,9 @@ async def make_request(
                 {
                     "max_tokens": request.output_token_count,
                     "stop": None,
+                    "extra_body": {
+                        "ignore_eos": True,
+                    },
                 }
             )
         elif settings.openai.max_gen_tokens and settings.openai.max_gen_tokens > 0:
@@ -103,11 +108,11 @@ async def make_request(
 
         request_args.update(self._request_args)
 
+        messages = self._build_messages(request)
+
         stream = await self._async_client.chat.completions.create(
             model=self.model,
-            messages=[
-                {"role": "user", "content": request.prompt},
-            ],
+            messages=messages,
             stream=True,
             **request_args,
         )
@@ -167,3 +172,21 @@ def validate_connection(self):
         except Exception as error:
             logger.error("Failed to validate OpenAI connection: {}", error)
             raise error
+
+    def _build_messages(self, request: TextGenerationRequest) -> List[Dict]:
+        if request.number_images == 0:
+            messages = [{"role": "user", "content": request.prompt}]
+        else:
+            content = []
+            for image in request.images:
+                stream = io.BytesIO()
+                im_format = image.image.format or "PNG"
+                image.image.save(stream, format=im_format)
+                im_b64 = base64.b64encode(stream.getvalue()).decode("utf-8")
+                image_url = {"url": f"data:image/{im_format.lower()};base64,{im_b64}"}
+                content.append({"type": "image_url", "image_url": image_url})
+
+            content.append({"type": "text", "text": request.prompt})
+            messages = [{"role": "user", "content": content}]
+
+        return messages
diff --git a/src/guidellm/config.py b/src/guidellm/config.py
index c3d950ec..a19a6244 100644
--- a/src/guidellm/config.py
+++ b/src/guidellm/config.py
@@ -90,6 +90,7 @@ class EmulatedDataSettings(BaseModel):
             "force_new_line_punctuation": True,
         }
     )
+    image_source: str = "https://www.gutenberg.org/cache/epub/1342/pg1342-images.html"
 
 
 class OpenAISettings(BaseModel):
@@ -108,6 +109,10 @@ class OpenAISettings(BaseModel):
     max_gen_tokens: int = 4096
 
 
+class AiohttpSettings(OpenAISettings):
+    pass
+
+
 class ReportGenerationSettings(BaseModel):
     """
     Report generation settings for the application
@@ -152,6 +156,7 @@ class Settings(BaseSettings):
 
     # Request settings
     openai: OpenAISettings = OpenAISettings()
+    aiohttp: AiohttpSettings = AiohttpSettings()
 
     # Report settings
     report_generation: ReportGenerationSettings = ReportGenerationSettings()
diff --git a/src/guidellm/core/report.py b/src/guidellm/core/report.py
index b6791e45..c48eed56 100644
--- a/src/guidellm/core/report.py
+++ b/src/guidellm/core/report.py
@@ -147,19 +147,15 @@ def _create_benchmark_report_data_tokens_summary(
     for benchmark in report.benchmarks_sorted:
         table.add_row(
             _benchmark_rate_id(benchmark),
-            f"{benchmark.prompt_token_distribution.mean:.2f}",
+            f"{benchmark.prompt_token:.2f}",
             ", ".join(
                 f"{percentile:.1f}"
-                for percentile in benchmark.prompt_token_distribution.percentiles(
-                    [1, 5, 50, 95, 99]
-                )
+                for percentile in benchmark.prompt_token_percentiles
             ),
-            f"{benchmark.output_token_distribution.mean:.2f}",
+            f"{benchmark.output_token:.2f}",
             ", ".join(
                 f"{percentile:.1f}"
-                for percentile in benchmark.output_token_distribution.percentiles(
-                    [1, 5, 50, 95, 99]
-                )
+                for percentile in benchmark.output_token_percentiles
             ),
         )
     logger.debug("Created data tokens summary table for the report.")
@@ -181,7 +177,7 @@ def _create_benchmark_report_dist_perf_summary(
         "Benchmark",
         "Request Latency [1%, 5%, 10%, 50%, 90%, 95%, 99%] (sec)",
         "Time to First Token [1%, 5%, 10%, 50%, 90%, 95%, 99%] (ms)",
-        "Inter Token Latency [1%, 5%, 10%, 50%, 90% 95%, 99%] (ms)",
+        "Inter Token Latency [1%, 5%, 10%, 50%, 90%, 95%, 99%] (ms)",
         title="[magenta]Performance Stats by Benchmark[/magenta]",
         title_style="bold",
         title_justify="left",
@@ -193,21 +189,15 @@
             _benchmark_rate_id(benchmark),
             ", ".join(
                 f"{percentile:.2f}"
-                for percentile in benchmark.request_latency_distribution.percentiles(
-                    [1, 5, 10, 50, 90, 95, 99]
-                )
+                for percentile in benchmark.request_latency_percentiles
             ),
             ", ".join(
                 f"{percentile * 1000:.1f}"
-                for percentile in benchmark.ttft_distribution.percentiles(
-                    [1, 5, 10, 50, 90, 95, 99]
-                )
+                for percentile in benchmark.time_to_first_token_percentiles
             ),
             ", ".join(
                 f"{percentile * 1000:.1f}"
-                for percentile in benchmark.itl_distribution.percentiles(
-                    [1, 5, 10, 50, 90, 95, 99]
-                )
+                for percentile in benchmark.inter_token_latency_percentiles
             ),
         )
     logger.debug("Created distribution performance summary table for the report.")
diff --git a/src/guidellm/core/request.py b/src/guidellm/core/request.py
index 4f7315c5..06d0f37c 100644
--- a/src/guidellm/core/request.py
+++ b/src/guidellm/core/request.py
@@ -1,9 +1,10 @@
 import uuid
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import Field
 
 from guidellm.core.serializable import Serializable
+from guidellm.utils import ImageDescriptor
 
 
 class TextGenerationRequest(Serializable):
@@ -16,6 +17,10 @@ class TextGenerationRequest(Serializable):
         description="The unique identifier for the request.",
     )
     prompt: str = Field(description="The input prompt for the text generation.")
+    images: Optional[List[ImageDescriptor]] = Field(
+        default=None,
+        description="Input images.",
+    )
     prompt_token_count: Optional[int] = Field(
         default=None,
         description="The number of tokens in the input prompt.",
@@ -29,6 +34,21 @@ class TextGenerationRequest(Serializable):
         description="The parameters for the text generation request.",
     )
 
+    @property
+    def number_images(self) -> int:
+        if self.images is None:
+            return 0
+        else:
+            return len(self.images)
+
+    @property
+    def image_resolution(self) -> Optional[List[Tuple[int, int]]]:
+        if self.images is None:
+            return None
+        else:
+            return [im.size for im in self.images]
+
+
     def __str__(self) -> str:
         prompt_short = (
             self.prompt[:32] + "..."
@@ -41,4 +61,6 @@ def __str__(self) -> str:
             f"prompt={prompt_short}, prompt_token_count={self.prompt_token_count}, "
             f"output_token_count={self.output_token_count}, "
-            f"params={self.params})"
+            f"params={self.params}, "
+            f"image_resolution={self.image_resolution})"
         )
diff --git a/src/guidellm/core/result.py b/src/guidellm/core/result.py
index f218784c..aebd1763 100644
--- a/src/guidellm/core/result.py
+++ b/src/guidellm/core/result.py
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List, Literal, Optional, Union
 
 from loguru import logger
-from pydantic import Field
+from pydantic import Field, computed_field
 
 from guidellm.core.distribution import Distribution
 from guidellm.core.request import TextGenerationRequest
@@ -221,6 +221,7 @@ def __iter__(self):
         """
         return iter(self.results)
 
+    @computed_field  # type: ignore[misc]
     @property
     def request_count(self) -> int:
         """
@@ -231,6 +232,7 @@ def request_count(self) -> int:
         """
         return len(self.results)
 
+    @computed_field  # type: ignore[misc]
     @property
     def error_count(self) -> int:
         """
@@ -241,6 +243,7 @@ def error_count(self) -> int:
         """
         return len(self.errors)
 
+    @computed_field  # type: ignore[misc]
     @property
     def total_count(self) -> int:
         """
@@ -251,6 +254,7 @@ def total_count(self) -> int:
         """
         return self.request_count + self.error_count
 
+    @computed_field  # type: ignore[misc]
     @property
     def start_time(self) -> Optional[float]:
         """
@@ -264,6 +268,7 @@ def start_time(self) -> Optional[float]:
 
         return self.results[0].start_time
 
+    @computed_field  # type: ignore[misc]
     @property
     def end_time(self) -> Optional[float]:
         """
@@ -277,6 +282,7 @@ def end_time(self) -> Optional[float]:
 
         return self.results[-1].end_time
 
+    @computed_field  # type: ignore[misc]
     @property
     def duration(self) -> float:
         """
@@ -290,6 +296,7 @@ def duration(self) -> float:
 
         return self.end_time - self.start_time
 
+    @computed_field  # type: ignore[misc]
     @property
     def completed_request_rate(self) -> float:
         """
@@ -303,6 +310,7 @@ def completed_request_rate(self) -> float:
 
         return len(self.results) / self.duration
 
+    @computed_field  # type: ignore[misc]
     @property
     def request_latency(self) -> float:
         """
@@ -332,6 +340,19 @@ def request_latency_distribution(self) -> Distribution:
             ]
         )
 
+    @computed_field  # type: ignore[misc]
+    @property
+    def request_latency_percentiles(self) -> List[float]:
+        """
+        Get standard percentiles of request latency in seconds.
+
+        :return: List of request latency percentiles, in seconds.
+        :rtype: List[float]
+        """
+        return self.request_latency_distribution.percentiles([1, 5, 10, 50, 90, 95, 99])
+
+
+    @computed_field  # type: ignore[misc]
     @property
     def time_to_first_token(self) -> float:
         """
@@ -361,6 +382,20 @@ def ttft_distribution(self) -> Distribution:
             ]
         )
 
+    @computed_field  # type: ignore[misc]
+    @property
+    def time_to_first_token_percentiles(self) -> List[float]:
+        """
+        Get standard percentiles for the time taken to decode the first token,
+        in milliseconds.
+
+        :return: List of percentiles for the time taken to decode the first
+            token, in milliseconds.
+        :rtype: List[float]
+        """
+        return self.ttft_distribution.percentiles([1, 5, 10, 50, 90, 95, 99])
+
+    @computed_field  # type: ignore[misc]
     @property
     def inter_token_latency(self) -> float:
         """
@@ -388,6 +423,18 @@ def itl_distribution(self) -> Distribution:
             ]
         )
 
+    @computed_field  # type: ignore[misc]
+    @property
+    def inter_token_latency_percentiles(self) -> List[float]:
+        """
+        Get standard percentiles for the time between tokens in milliseconds.
+
+        :return: List of inter-token latency percentiles, in milliseconds.
+        :rtype: List[float]
+        """
+        return self.itl_distribution.percentiles([1, 5, 10, 50, 90, 95, 99])
+
+    @computed_field  # type: ignore[misc]
     @property
     def output_token_throughput(self) -> float:
         """
@@ -403,6 +450,17 @@ def output_token_throughput(self) -> float:
 
         return total_tokens / self.duration
 
+    @computed_field  # type: ignore[misc]
+    @property
+    def prompt_token(self) -> float:
+        """
+        Get the average number of prompt tokens.
+
+        :return: The average number of prompt tokens.
+        :rtype: float
+        """
+        return self.prompt_token_distribution.mean
+
     @property
     def prompt_token_distribution(self) -> Distribution:
         """
@@ -413,6 +471,28 @@ def prompt_token_distribution(self) -> Distribution:
         """
         return Distribution(data=[result.prompt_token_count for result in self.results])
 
+    @computed_field  # type: ignore[misc]
+    @property
+    def prompt_token_percentiles(self) -> List[float]:
+        """
+        Get standard percentiles for the number of prompt tokens.
+
+        :return: List of percentiles of the number of prompt tokens.
+        :rtype: List[float]
+        """
+        return self.prompt_token_distribution.percentiles([1, 5, 50, 95, 99])
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def output_token(self) -> float:
+        """
+        Get the average number of output tokens.
+
+        :return: The average number of output tokens.
+        :rtype: float
+        """
+        return self.output_token_distribution.mean
+
     @property
     def output_token_distribution(self) -> Distribution:
         """
@@ -423,6 +503,18 @@ def output_token_distribution(self) -> Distribution:
         """
         return Distribution(data=[result.output_token_count for result in self.results])
 
+    @computed_field  # type: ignore[misc]
+    @property
+    def output_token_percentiles(self) -> List[float]:
+        """
+        Get standard percentiles for the number of output tokens.
+
+        :return: List of percentiles of the number of output tokens.
+        :rtype: List[float]
+        """
+        return self.output_token_distribution.percentiles([1, 5, 50, 95, 99])
+
+    @computed_field  # type: ignore[misc]
     @property
     def overloaded(self) -> bool:
         if (
diff --git a/src/guidellm/main.py b/src/guidellm/main.py
index 4016ecec..4748b12d 100644
--- a/src/guidellm/main.py
+++ b/src/guidellm/main.py
@@ -186,17 +186,17 @@ def generate_benchmark_report_cli(
 
 def generate_benchmark_report(
     target: str,
-    backend: BackendEnginePublic,
-    model: Optional[str],
     data: Optional[str],
     data_type: Literal["emulated", "file", "transformers"],
-    tokenizer: Optional[str],
-    rate_type: ProfileGenerationMode,
-    rate: Optional[float],
-    max_seconds: Optional[int],
-    max_requests: Union[Literal["dataset"], int, None],
-    output_path: str,
-    cont_refresh_table: bool,
+    backend: BackendEnginePublic = "openai_server",
+    model: Optional[str] = None,
+    tokenizer: Optional[str] = None,
+    rate_type: ProfileGenerationMode = "sweep",
+    rate: Optional[float] = None,
+    max_seconds: Optional[int] = 120,
+    max_requests: Union[Literal["dataset"], int, None] = None,
+    output_path: Optional[str] = None,
+    cont_refresh_table: bool = False,
 ) -> GuidanceReport:
     """
     Generate a benchmark report for a specified backend and dataset.
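With the keyword defaults introduced in main.py above, a report can now be generated programmatically by spelling out only the target and dataset arguments. A minimal sketch follows; the localhost target is an assumption for illustration, and the call simply leans on the defaults shown in the hunk (backend="openai_server", rate_type="sweep", max_seconds=120), with data=None and data_type="emulated" falling back to the default EmulatedConfig from request/emulated.py:

    from guidellm.main import generate_benchmark_report

    # Assumed local OpenAI-compatible endpoint; everything else uses the new
    # defaults: sweep rate type, 120 s per benchmark, and emulated prompts of
    # ~1024 tokens with ~256 generated tokens and no images.
    report = generate_benchmark_report(
        target="http://localhost:8000/v1",
        data=None,
        data_type="emulated",
    )
    print(report)
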
diff --git a/src/guidellm/request/emulated.py b/src/guidellm/request/emulated.py
index 7d481cb7..02f564a1 100644
--- a/src/guidellm/request/emulated.py
+++ b/src/guidellm/request/emulated.py
@@ -11,7 +11,7 @@
 from guidellm.config import settings
 from guidellm.core.request import TextGenerationRequest
 from guidellm.request.base import GenerationMode, RequestGenerator
-from guidellm.utils import clean_text, filter_text, load_text, split_text
+from guidellm.utils import clean_text, filter_text, load_images, load_text, split_text
 
 __all__ = ["EmulatedConfig", "EmulatedRequestGenerator", "EndlessTokens"]
 
@@ -30,6 +30,9 @@ class EmulatedConfig:
         generated_tokens_variance (Optional[int]): Variance for generated tokens.
         generated_tokens_min (Optional[int]): Minimum number of generated tokens.
         generated_tokens_max (Optional[int]): Maximum number of generated tokens.
+        images (Optional[int]): Number of images per request.
+        width (Optional[int]): Width of images, in pixels.
+        height (Optional[int]): Height of images, in pixels.
     """
 
     @staticmethod
@@ -47,7 +50,7 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
         """
         if not config:
             logger.debug("Creating default configuration")
-            return EmulatedConfig(prompt_tokens=1024, generated_tokens=256)
+            return EmulatedConfig(prompt_tokens=1024, generated_tokens=256, images=0)
 
         if isinstance(config, dict):
             logger.debug("Loading configuration from dict: {}", config)
@@ -105,6 +108,10 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
     generated_tokens_min: Optional[int] = None
     generated_tokens_max: Optional[int] = None
 
+    images: int = 0
+    width: Optional[int] = None
+    height: Optional[int] = None
+
     @property
     def prompt_tokens_range(self) -> Tuple[int, int]:
         """
@@ -327,6 +334,11 @@ def __init__(
             settings.emulated_data.filter_start,
             settings.emulated_data.filter_end,
         )
+        if self._config.images > 0:
+            self._images = load_images(
+                settings.emulated_data.image_source,
+                [self._config.width, self._config.height],
+            )
         self._rng = np.random.default_rng(random_seed)
 
         # NOTE: Must be after all the parameters since the queue population
@@ -355,6 +364,7 @@ def create_item(self) -> TextGenerationRequest:
         logger.debug("Creating new text generation request")
         target_prompt_token_count = self._config.sample_prompt_tokens(self._rng)
         prompt = self.sample_prompt(target_prompt_token_count)
+        images = self.sample_images()
         prompt_token_count = len(self.tokenizer.tokenize(prompt))
         output_token_count = self._config.sample_output_tokens(self._rng)
         logger.debug("Generated prompt: {}", prompt)
@@ -363,6 +373,7 @@ def create_item(self) -> TextGenerationRequest:
             prompt=prompt,
             prompt_token_count=prompt_token_count,
             output_token_count=output_token_count,
+            images=images,
         )
 
     def sample_prompt(self, tokens: int) -> str:
@@ -395,3 +406,14 @@ def sample_prompt(self, tokens: int) -> str:
                 right = mid
 
         return self._tokens.create_text(start_line_index, left)
+
+
+    def sample_images(self):
+        if self._config.images <= 0:
+            return None
+
+        image_indices = self._rng.choice(
+            len(self._images), size=self._config.images, replace=False,
+        )
+
+        return [self._images[i] for i in image_indices]
diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py
index 2fdd8ca8..eb4931bd 100644
--- a/src/guidellm/utils/__init__.py
+++ b/src/guidellm/utils/__init__.py
@@ -1,3 +1,4 @@
+from .images import ImageDescriptor, load_images
 from .injector import create_report, inject_data
 from .progress import BenchmarkReportProgress
 from .text import (
@@ -37,4 +38,6 @@
     "resolve_transformers_dataset_split",
     "split_lines_by_punctuation",
     "split_text",
+    "ImageDescriptor",
+    "load_images",
 ]
diff --git a/src/guidellm/utils/images.py b/src/guidellm/utils/images.py
new file mode 100644
index 00000000..fb66d432
--- /dev/null
+++ b/src/guidellm/utils/images.py
@@ -0,0 +1,84 @@
+from io import BytesIO
+from typing import List, Optional, Tuple
+from urllib.parse import urljoin
+
+import requests
+from bs4 import BeautifulSoup
+from loguru import logger
+from PIL import Image
+from pydantic import ConfigDict, Field, computed_field
+
+from guidellm.config import settings
+from guidellm.core.serializable import Serializable
+
+__all__ = ["load_images", "ImageDescriptor"]
+
+
+class ImageDescriptor(Serializable):
+    """
+    A class to represent image data in serializable format.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    url: Optional[str] = Field(description="url address for image.")
+    image: Image.Image = Field(description="PIL image", exclude=True)
+    filename: Optional[str] = Field(
+        default=None,
+        description="Image filename.",
+    )
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def image_resolution(self) -> Optional[Tuple[int, int]]:
+        if self.image is None:
+            return None
+        else:
+            return self.image.size
+
+
+def load_images(data: str, image_resolution: Optional[List[int]]) -> List[ImageDescriptor]:
+    """
+    Load the images referenced by an HTML page at a local path or URL.
+
+    :param data: the path or URL of the HTML page to scrape for images
+    :type data: str
+    :param image_resolution: optional [width, height] to resize each image to
+    :type image_resolution: Optional[List[int]]
+    :return: descriptors containing each image URL and its data as a PIL.Image.Image
+    :rtype: List[ImageDescriptor]
+    """
+
+    images = []
+    if not data:
+        return images
+    if isinstance(data, str) and data.startswith("http"):
+        response = requests.get(data, timeout=settings.request_timeout)
+        response.raise_for_status()
+
+        soup = BeautifulSoup(response.text, "html.parser")
+        for img_tag in soup.find_all("img"):
+            img_url = img_tag.get("src")
+
+            if img_url:
+                # Handle relative URLs
+                img_url = urljoin(data, img_url)
+
+                # Download the image
+                logger.debug("Loading image: {}", img_url)
+                img_response = requests.get(img_url, timeout=settings.request_timeout)
+                img_response.raise_for_status()
+                image = Image.open(BytesIO(img_response.content))
+
+                if image_resolution is not None:
+                    image = image.resize(image_resolution)
+
+                # Load image into Pillow
+                images.append(
+                    ImageDescriptor(
+                        url=img_url,
+                        image=image,
+                    )
+                )
+
+    return images