diff --git a/supporting-blog-content/result-diversification/.python-version b/supporting-blog-content/result-diversification/.python-version new file mode 100644 index 00000000..e4fba218 --- /dev/null +++ b/supporting-blog-content/result-diversification/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/supporting-blog-content/result-diversification/diversification.ipynb b/supporting-blog-content/result-diversification/diversification.ipynb new file mode 100644 index 00000000..06c844c2 --- /dev/null +++ b/supporting-blog-content/result-diversification/diversification.ipynb @@ -0,0 +1,999 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Result diversification with Elasticsearch\n", + "This notebook demonstrates:\n", + "1. Loading a fashion dataset\n", + "2. Indexing it in Elasticsearch for image search\n", + "3. Searching items with a broad search term\n", + "4. Applying result diversification with the MMR algorithm to the results.\n", + "\n", + "Check out our [blog post](https://www.elastic.co/search-labs/blog/diversify-results-maximum-marginal-relevance) on this topic to learn more about result diversification." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
def load_config(file_path="elastic_config.env"):
    """Load ``KEY=VALUE`` settings from a dotenv-style text file.

    Parsing rules:
      * blank lines, lines starting with ``#`` and lines without ``=`` are
        ignored;
      * only the first ``=`` splits key from value, so values may contain ``=``;
      * whitespace around the key and the value is trimmed, so
        ``ELASTIC_HOST = https://...`` yields the key ``ELASTIC_HOST`` rather
        than ``"ELASTIC_HOST "``;
      * one pair of matching single or double quotes around a value is removed.

    Args:
        file_path: Path of the configuration file to read.

    Returns:
        dict mapping setting names to their string values. The dict is empty
        when the file does not exist (a message is printed in that case).
    """
    config = {}
    try:
        # Pin the encoding so parsing does not depend on the platform locale.
        with open(file_path, "r", encoding="utf-8") as file:
            for raw_line in file:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue

                key, value = line.split("=", 1)
                key = key.strip()
                value = value.strip()

                # Accept KEY="value" / KEY='value' as written by many .env tools.
                if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                    value = value[1:-1]

                if key:
                    config[key] = value
    except FileNotFoundError:
        print(f"Configuration file not found: {file_path}")
    return config
def load_dataset(folder_path):
    """Collect the ``data`` payload of every ``*.json`` file in a folder.

    Files are visited in sorted filename order. ``os.listdir`` returns entries
    in arbitrary order, so without sorting any later "first N items" slice of
    the result would differ between machines and runs.

    Args:
        folder_path: Directory containing one JSON document per product.

    Returns:
        list with the value stored under the top-level ``"data"`` key of each
        readable JSON object. Files that cannot be read or parsed are reported
        with a printed message and skipped; JSON documents that are not
        objects, or have no ``"data"`` key, are skipped silently.
    """
    products = []

    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith(".json"):
            continue

        file_path = os.path.join(folder_path, filename)
        try:
            # JSON is UTF-8 by specification; do not rely on the locale default.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error reading {filename}: {e}")
            continue

        if isinstance(data, dict) and "data" in data:
            products.append(data["data"])

    return products
def extract_id_and_image_url(products):
    """Reduce raw product records to the fields used for embedding and indexing.

    The image URL is taken from ``styleImages.default.resolutions["360X480"]``
    and falls back to ``styleImages.default.imageURL`` when that is missing or
    empty. Products without a truthy ``id`` or without any image URL are
    dropped.

    Nested containers (``styleImages``, ``default``, ``resolutions``,
    ``articleType``) are treated as empty when they are absent *or* present
    with a null value, so a JSON ``null`` no longer raises ``AttributeError``.

    Args:
        products: Iterable of product dicts.

    Returns:
        list of dicts with the keys ``id``, ``image_url``, ``product_name``,
        ``brand``, ``color`` and ``article_type`` (missing text fields become
        the empty string).
    """
    image_data = []

    for product in products:
        product_id = product.get("id")

        style_images = product.get("styleImages") or {}
        default_image = style_images.get("default") or {}
        resolutions = default_image.get("resolutions") or {}

        # Prefer the small 360x480 rendition; fall back to the full-size image.
        image_url = resolutions.get("360X480", "")
        if not image_url:
            image_url = default_image.get("imageURL", "")

        if not (product_id and image_url):
            continue

        article_type = (product.get("articleType") or {}).get("typeName", "")
        image_data.append(
            {
                "id": product_id,
                "image_url": image_url,
                "product_name": product.get("productDisplayName", ""),
                "brand": product.get("brandName", ""),
                "color": product.get("baseColour", ""),
                "article_type": article_type,
            }
        )

    return image_data
- {item['product_name']} ({item['article_type']}, {item['color']})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Create Image Embeddings with JINA API" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Getting embeddings...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Getting embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:25<00:00, 3.08it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Retrieved 1000 embeddings\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "def get_single_image_embedding(item, jina_api_key):\n", + " \"\"\"Get embedding for a single image\"\"\"\n", + " url = \"https://api.jina.ai/v1/embeddings\"\n", + " headers = {\n", + " \"Content-Type\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {jina_api_key}\",\n", + " }\n", + "\n", + " product_data = {\n", + " \"product_name\": item[\"product_name\"],\n", + " \"brand\": item[\"brand\"],\n", + " \"color\": item[\"color\"],\n", + " \"article_type\": item[\"article_type\"],\n", + " }\n", + "\n", + " data = {\n", + " \"model\": \"jina-embeddings-v4\",\n", + " \"dimensions\": 1024,\n", + " \"normalized\": True,\n", + " \"task\": \"retrieval.passage\",\n", + " \"embedding_type\": \"float\",\n", + " \"input\": [{\"text\": f\"{product_data}\"}, {\"image\": item[\"image_url\"]}],\n", + " }\n", + "\n", + " try:\n", + " response = requests.post(url, headers=headers, json=data, timeout=200)\n", + " response.raise_for_status()\n", + "\n", + " result = response.json()\n", + " if \"data\" in result and len(result[\"data\"]) > 0:\n", + " return {\n", + " \"id\": item[\"id\"],\n", + " 
\"image_url\": item[\"image_url\"],\n", + " \"product_name\": item[\"product_name\"],\n", + " \"brand\": item[\"brand\"],\n", + " \"color\": item[\"color\"],\n", + " \"article_type\": item[\"article_type\"],\n", + " \"image_vector\": to_avg_vector(\n", + " [result[\"data\"][0][\"embedding\"], result[\"data\"][1][\"embedding\"]]\n", + " ),\n", + " }\n", + " return None\n", + " except Exception as e:\n", + " print(f\"Error processing {item}: {e}\")\n", + " return None\n", + "\n", + "\n", + "# encode image and product information in one vector\n", + "def to_avg_vector(vectors):\n", + " vectors_array = np.array(vectors)\n", + "\n", + " avg_vector = np.mean(vectors_array, axis=0)\n", + "\n", + " norm = np.linalg.norm(avg_vector)\n", + " if norm > 0:\n", + " normalized_avg_vector = avg_vector / norm\n", + " else:\n", + " normalized_avg_vector = avg_vector\n", + "\n", + " return normalized_avg_vector.tolist()\n", + "\n", + "\n", + "print(\"Getting embeddings...\")\n", + "\n", + "with ThreadPoolExecutor(max_workers=10) as executor:\n", + " products_with_vectors = list(\n", + " tqdm(\n", + " executor.map(\n", + " get_single_image_embedding, demo_image_data, repeat(jina_api_key)\n", + " ),\n", + " total=len(demo_image_data),\n", + " desc=\"Getting embeddings\",\n", + " )\n", + " )\n", + "\n", + "print(f\"Retrieved {len(products_with_vectors)} embeddings\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original products: 1000\n", + "After filtering similar items: 758\n", + "Removed 242 similar items\n" + ] + } + ], + "source": [ + "def _cosine_similarity(X, Y):\n", + " \"\"\"Compute cosine similarity between two sets of vectors.\"\"\"\n", + " X = np.array(X)\n", + " Y = np.array(Y)\n", + "\n", + " if X.ndim == 1:\n", + " X = X.reshape(1, -1)\n", + " if Y.ndim == 1:\n", + " Y = Y.reshape(1, -1)\n", + "\n", + " # Normalize the vectors\n", + " X_norm = X / 
np.linalg.norm(X, axis=1, keepdims=True)\n", + " Y_norm = Y / np.linalg.norm(Y, axis=1, keepdims=True)\n", + "\n", + " return np.dot(X_norm, Y_norm.T)\n", + "\n", + "\n", + "def filter_out_similar_items(items, threshold=0.98):\n", + " \"\"\"Filter out items that have very high similarity to previously seen items\"\"\"\n", + " filtered_items = []\n", + "\n", + " for i, item1 in enumerate(items):\n", + " is_similar_to_existing = False\n", + "\n", + " for existing_item in filtered_items:\n", + " similarity = _cosine_similarity(\n", + " [item1[\"image_vector\"]], [existing_item[\"image_vector\"]]\n", + " )[0][0]\n", + "\n", + " if similarity >= threshold:\n", + " is_similar_to_existing = True\n", + " break\n", + "\n", + " if not is_similar_to_existing:\n", + " filtered_items.append(item1)\n", + "\n", + " return filtered_items\n", + "\n", + "\n", + "# Filter out items with similarity >= 0.98\n", + "filtered_products = filter_out_similar_items(products_with_vectors, threshold=0.98)\n", + "\n", + "print(f\"Original products: {len(products_with_vectors)}\")\n", + "print(f\"After filtering similar items: {len(filtered_products)}\")\n", + "print(f\"Removed {len(products_with_vectors) - len(filtered_products)} similar items\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
# Connect to the cluster using the host and API key read from the config file.
es = Elasticsearch(elastic_host, api_key=elastic_api_key)

# Name of the index that holds the product documents.
index_name = "fashion_images"

# All descriptive fields are exact-match keywords; only the vector is searched.
_keyword_fields = (
    "id",
    "image_url",
    "product_name",
    "brand",
    "color",
    "article_type",
)
_properties = {field: {"type": "keyword"} for field in _keyword_fields}
_properties["image_vector"] = {
    "type": "dense_vector",
    "dims": 1024,  # matches the "dimensions" requested from the embedding API
    "index": True,
    "similarity": "cosine",
    "index_options": {"type": "flat"},
}

mapping = {"mappings": {"properties": _properties}}

# Start from a clean index so re-running the notebook does not mix old data in.
if es.indices.exists(index=index_name):
    es.indices.delete(index=index_name)
    print(f"Deleted existing index '{index_name}'")

es.indices.create(index=index_name, body=mapping)
print(f"Created index '{index_name}'")
def index_single_image(item):
    """Write one product document to the index, using its ``id`` as document id.

    Relies on the module-level ``es`` client and ``index_name``.

    Args:
        item: Product dict; the whole dict is stored as the document.

    Returns:
        1 when the document was indexed, 0 when indexing raised (the error is
        printed). The int result lets callers ``sum()`` the outcomes of a
        parallel map to count successes.
    """
    try:
        es.index(index=index_name, id=item["id"], document=item)
    except Exception as e:
        print(f"Error indexing document {item['id']}: {e}")
        outcome = 0
    else:
        outcome = 1
    return outcome
def get_text_embedding(text, jina_api_key):
    """Embed a search phrase with the Jina embeddings API.

    Args:
        text: Query text to embed.
        jina_api_key: Bearer token for the Jina API.

    Returns:
        The embedding of the first (only) input as returned by the API, or
        ``None`` when the request fails or the response carries no data.
        Errors are printed rather than raised.
    """
    endpoint = "https://api.jina.ai/v1/embeddings"
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {jina_api_key}",
    }
    payload = {
        "model": "jina-embeddings-v4",
        "dimensions": 1024,
        "normalized": True,
        "embedding_type": "float",
        "task": "retrieval.query",
        "input": [{"text": text}],
    }

    embedding = None
    try:
        reply = requests.post(
            endpoint, headers=request_headers, json=payload, timeout=30
        )
        reply.raise_for_status()
        parsed = reply.json()

        if "data" in parsed and len(parsed["data"]) > 0:
            embedding = parsed["data"][0]["embedding"]
    except Exception as e:
        print(f"Error getting text embedding: {e}")

    return embedding


def search_similar_images(es, index_name, query_vector, k=20):
    """Run a kNN search on ``image_vector`` and flatten the hits.

    Args:
        es: Client object exposing ``search(index=..., body=...)``.
        index_name: Index to query.
        query_vector: Embedding to search with.
        k: Number of neighbours requested; also used as the result size.

    Returns:
        list of dicts, one per hit in response order, with the stored product
        fields, the stored ``image_vector`` and the hit's ``score``.
    """
    knn_request = {
        "knn": {
            "field": "image_vector",
            "query_vector": query_vector,
            "k": k,
        },
        "size": k,
    }

    def _flatten(hit):
        # Pull the stored document out of the hit and attach the search score.
        stored = hit["_source"]
        return {
            "id": stored["id"],
            "image_url": stored["image_url"],
            "image_vector": stored["image_vector"],
            "score": hit["_score"],
            "product_name": stored["product_name"],
            "brand": stored["brand"],
            "color": stored["color"],
            "article_type": stored["article_type"],
        }

    response = es.search(index=index_name, body=knn_request)
    return [_flatten(hit) for hit in response["hits"]["hits"]]

Original Search Results

\n", + "
\n", + " \n", + "

ID: 9785

\n", + "

Urban Yoga Women Summer B...

\n", + "

Track Pants - Navy Blue

\n", + "

Score: 0.862

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 7128

\n", + "

Urban Yoga Women Bottom B...

\n", + "

Track Pants - Black

\n", + "

Score: 0.861

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 19242

\n", + "

Puma Women Grey Capri Pan...

\n", + "

Capris - Grey

\n", + "

Score: 0.858

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 3921

\n", + "

Urban Yoga Men's Bottom B...

\n", + "

Track Pants - Black

\n", + "

Score: 0.857

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 52529

\n", + "

Pepe Jeans Men Grey 3/4 L...

\n", + "

Shorts - Grey

\n", + "

Score: 0.856

\n", + "
\n", + "
\n", + "
\n", + " \n", + "

ID: 4826

\n", + "

ADIDAS Men's Woven Dark N...

\n", + "

Track Pants - Navy Blue

\n", + "

Score: 0.855

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 44664

\n", + "

Wills Lifestyle Women Cha...

\n", + "

Trousers - Charcoal

\n", + "

Score: 0.854

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 7133

\n", + "

Urban Yoga Men Bottom Gre...

\n", + "

Track Pants - Grey

\n", + "

Score: 0.854

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 43522

\n", + "

French Connection Women N...

\n", + "

Trousers - Navy Blue

\n", + "

Score: 0.854

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 18869

\n", + "

Puma Women Black Core Tra...

\n", + "

Track Pants - Black

\n", + "

Score: 0.854

\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def display_images(images, title=\"Images\", max_per_row=5):\n", + " \"\"\"Display images in a grid layout\"\"\"\n", + " html = f\"

{title}

\"\n", + " html += '
'\n", + " images = images[:10]\n", + "\n", + " for i, img in enumerate(images):\n", + " score = img.get(\"score\", \"N/A\")\n", + " if isinstance(score, (int, float)):\n", + " score_str = f\"{score:.3f}\"\n", + " else:\n", + " score_str = \"N/A\"\n", + "\n", + " product_name = img.get(\"product_name\", \"N/A\")\n", + " if product_name != \"N/A\" and len(product_name) > 25:\n", + " product_name = product_name[:25] + \"...\"\n", + "\n", + " html += f\"\"\"\n", + "
\n", + " \n", + "

ID: {img['id']}

\n", + "

{product_name}

\n", + "

{img.get('article_type', '')} - {img.get('color', '')}

\n", + "

Score: {score_str}

\n", + "
\n", + " \"\"\"\n", + "\n", + " if (i + 1) % max_per_row == 0:\n", + " html += '
'\n", + "\n", + " html += \"
\"\n", + " display(HTML(html))\n", + "\n", + "\n", + "display_images(search_results, \"Original Search Results\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. Reranking with Maximum Marginal Relevance (MMR)\n", + "MMR is a diversity-promoting algorithm that balances:\n", + "\n", + "**Relevance**: How well items match the query \n", + "**Diversity**: How different items are from each other \n", + "The algorithm iteratively selects items that are relevant to the query but different from already selected items." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

Reranked Results (MMR)

\n", + "
\n", + " \n", + "

ID: 9785

\n", + "

Urban Yoga Women Summer B...

\n", + "

Track Pants - Navy Blue

\n", + "

Score: 0.862

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 41163

\n", + "

Allen Solly Woman Khaki T...

\n", + "

Trousers - Khaki

\n", + "

Score: 0.839

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 13255

\n", + "

Palm Tree Kids Boys Check...

\n", + "

Shorts - White

\n", + "

Score: 0.835

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 4774

\n", + "

ADIDAS Women 3S Pink Trac...

\n", + "

Track Pants - Pink

\n", + "

Score: 0.837

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 52529

\n", + "

Pepe Jeans Men Grey 3/4 L...

\n", + "

Shorts - Grey

\n", + "

Score: 0.856

\n", + "
\n", + "
\n", + "
\n", + " \n", + "

ID: 22466

\n", + "

Myntra Women Cream Patial...

\n", + "

Leggings - Cream

\n", + "

Score: 0.836

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 44906

\n", + "

Puma Men White 3/4 Length...

\n", + "

Shorts - White

\n", + "

Score: 0.853

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 32406

\n", + "

Arrow Woman Black Trouser...

\n", + "

Trousers - Black

\n", + "

Score: 0.853

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 57824

\n", + "

United Colors of Benetton...

\n", + "

Trousers - Green

\n", + "

Score: 0.842

\n", + "
\n", + " \n", + "
\n", + " \n", + "

ID: 30919

\n", + "

Fabindia Women Pink Harem...

\n", + "

Trousers - Pink

\n", + "

Score: 0.840

\n", + "
\n", + "
# taken from: https://github.com/elastic/elasticsearch-py/blob/main/elasticsearch/helpers/vectorstore/_utils.py#L39
def maximal_marginal_relevance(
    query_embedding: List[float],
    embedding_list: List[List[float]],
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Pick up to ``k`` indices that balance query relevance against redundancy.

    The first pick is the embedding most similar to the query. Each further
    pick maximises::

        lambda_mult * sim(query, candidate)
            - (1 - lambda_mult) * max(sim(candidate, already_picked))

    so ``lambda_mult=1`` ranks purely by relevance and lower values penalise
    candidates that resemble earlier picks. Similarities come from the
    module-level ``_cosine_similarity`` helper.

    Args:
        query_embedding: Query vector (1-D, or already shaped as one row).
        embedding_list: Candidate vectors.
        lambda_mult: Weight of relevance versus redundancy.
        k: Maximum number of indices to return.

    Returns:
        Indices into ``embedding_list`` in selection order; empty when ``k``
        is not positive or there are no candidates.
    """
    query_arr = np.array(query_embedding)

    n_to_pick = min(k, len(embedding_list))
    if n_to_pick <= 0:
        return []

    if query_arr.ndim == 1:
        query_arr = np.expand_dims(query_arr, axis=0)
    relevance = _cosine_similarity(query_arr, embedding_list)[0]

    # Seed the selection with the single most relevant candidate.
    first_pick = int(np.argmax(relevance))
    picked = [first_pick]
    picked_vectors = np.array([embedding_list[first_pick]])

    while len(picked) < n_to_pick:
        # Row i holds candidate i's similarity to every vector picked so far.
        redundancy = _cosine_similarity(embedding_list, picked_vectors)

        top_score, top_idx = -np.inf, -1
        for candidate, candidate_relevance in enumerate(relevance):
            if candidate in picked:
                continue
            penalty = max(redundancy[candidate])
            mmr_score = (
                lambda_mult * candidate_relevance - (1 - lambda_mult) * penalty
            )
            # Strict ">" keeps the earliest candidate on ties.
            if mmr_score > top_score:
                top_score, top_idx = mmr_score, candidate

        picked.append(top_idx)
        picked_vectors = np.append(
            picked_vectors, [embedding_list[top_idx]], axis=0
        )

    return picked
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/supporting-blog-content/result-diversification/example.env b/supporting-blog-content/result-diversification/example.env new file mode 100644 index 00000000..22fc00a2 --- /dev/null +++ b/supporting-blog-content/result-diversification/example.env @@ -0,0 +1,3 @@ +ELASTIC_API_KEY= +ELASTIC_HOST= +JINA_API_KEY= \ No newline at end of file diff --git a/supporting-blog-content/result-diversification/requirements.txt b/supporting-blog-content/result-diversification/requirements.txt new file mode 100644 index 00000000..eb34cb8e --- /dev/null +++ b/supporting-blog-content/result-diversification/requirements.txt @@ -0,0 +1,103 @@ +anyio==4.9.0 +appnope==0.1.4 +argon2-cffi==25.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==3.0.0 +async-lru==2.0.5 +attrs==25.3.0 +babel==2.17.0 +beautifulsoup4==4.13.4 +bleach==6.2.0 +certifi==2025.6.15 +cffi==1.17.1 +charset-normalizer==3.4.2 +comm==0.2.2 +debugpy==1.8.14 +decorator==5.2.1 +defusedxml==0.7.1 +elastic-transport==8.17.1 +elasticsearch==9.0.2 +executing==2.2.0 +fastjsonschema==2.21.1 +fqdn==1.5.1 +h11==0.16.0 +httpcore==1.0.9 +httpx==0.28.1 +idna==3.10 +ipykernel==6.29.5 +ipython==9.3.0 +ipython_pygments_lexers==1.1.1 +ipywidgets==8.1.7 +isoduration==20.11.0 +jedi==0.19.2 
+Jinja2==3.1.6 +json5==0.12.0 +jsonpointer==3.0.0 +jsonschema==4.24.0 +jsonschema-specifications==2025.4.1 +jupyter-events==0.12.0 +jupyter-lsp==2.2.5 +jupyter_client==8.6.3 +jupyter_core==5.8.1 +jupyter_server==2.16.0 +jupyter_server_terminals==0.5.3 +jupyterlab==4.4.3 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.3 +jupyterlab_widgets==3.0.15 +kagglehub==0.3.12 +MarkupSafe==3.0.2 +matplotlib-inline==0.1.7 +mistune==3.1.3 +nbclient==0.10.2 +nbconvert==7.16.6 +nbformat==5.10.4 +nest-asyncio==1.6.0 +notebook_shim==0.2.4 +numpy==2.3.1 +overrides==7.7.0 +packaging==25.0 +pandas==2.3.0 +pandocfilters==1.5.1 +parso==0.8.4 +pexpect==4.9.0 +platformdirs==4.3.8 +prometheus_client==0.22.1 +prompt_toolkit==3.0.51 +psutil==7.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pycparser==2.22 +Pygments==2.19.2 +python-dateutil==2.9.0.post0 +python-json-logger==3.3.0 +pytz==2025.2 +PyYAML==6.0.2 +pyzmq==27.0.0 +referencing==0.36.2 +requests==2.32.4 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rpds-py==0.25.1 +Send2Trash==1.8.3 +setuptools==80.9.0 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.7 +stack-data==0.6.3 +terminado==0.18.1 +tinycss2==1.4.0 +tornado==6.5.1 +tqdm==4.67.1 +traitlets==5.14.3 +types-python-dateutil==2.9.0.20250516 +typing_extensions==4.14.0 +tzdata==2025.2 +uri-template==1.3.0 +urllib3==2.5.0 +wcwidth==0.2.13 +webcolors==24.11.1 +webencodings==0.5.1 +websocket-client==1.8.0 +widgetsnbextension==4.0.14