find_images.py
#!/usr/bin/env python3

"""Find images that match a (hard-coded) text string and report their filenames.

This is intended as a way of checking that everything necessary to run the
app is actually working.
"""
import logging
import os
from pathlib import Path

import clip
import psycopg
import torch
from dotenv import load_dotenv

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(name)s %(levelname)s: %(message)s',
)
logger = logging.getLogger(__name__)

load_dotenv()
SERVICE_URI = os.getenv("PG_SERVICE_URI")

# Load the OpenAI CLIP model.
# If we download it remotely, it will default to being cached in ~/.cache/clip
LOCAL_MODEL = Path('./models/ViT-B-32.pt').absolute()
MODEL_NAME = 'ViT-B/32'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

logger.info(f'Using {DEVICE}')
if LOCAL_MODEL.exists():
    logger.info(f'Importing CLIP model from {LOCAL_MODEL}')
    model, preprocess = clip.load(MODEL_NAME, device=DEVICE, download_root=LOCAL_MODEL.parent)
else:
    logger.info('Importing CLIP model')
    model, preprocess = clip.load(MODEL_NAME, device=DEVICE)

INDEX_NAME = "photos"  # Update with your index name (not used in this script)


def get_single_embedding(text):
    with torch.no_grad():
        # Encode the text to compute the feature vector and normalize it
        text_input = clip.tokenize([text]).to(DEVICE)
        text_features = model.encode_text(text_input)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Return the feature vector
        return text_features.cpu().numpy()[0]


def vector_to_string(embedding):
    """Convert our (ndarray) embedding vector into a string that SQL can use."""
    vector_str = ", ".join(str(x) for x in embedding.tolist())
    return f'[{vector_str}]'
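# For illustration (a hypothetical 3-element vector; real CLIP ViT-B/32
# embeddings have 512 values), vector_to_string would produce:
#
#     vector_to_string(np.array([0.1, 0.2, 0.3]))  ->  '[0.1, 0.2, 0.3]'
#
# pgvector accepts this '[...]' text form wherever a vector value is expected.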


def search_for_matches(text):
    """Search for the "nearest" four images.

    See [Querying](https://github.com/pgvector/pgvector?tab=readme-ov-file#querying)
    in the pgvector documentation. The pgvector distance operators are:

    * <-> - L2 distance
    * <#> - (negative) inner product
    * <=> - cosine distance
    * <+> - L1 distance (added in 0.7.0)
    """
    vector = get_single_embedding(text)
    embedding_string = vector_to_string(vector)

    # Perform the search, ordering the results by L2 distance
    try:
        with psycopg.connect(SERVICE_URI) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT * FROM pictures ORDER BY embedding <-> %s LIMIT 4;",
                    (embedding_string,),
                )
                rows = cur.fetchall()
                # Assume the first column of each row is the image's filename
                return [row[0] for row in rows]
    except Exception as exc:
        logger.error(f'{exc.__class__.__name__}: {exc}')
        return []
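# The query above assumes a table along these lines (a sketch, not
# necessarily this repo's actual schema; CLIP ViT-B/32 embeddings have 512
# dimensions, and the filename is assumed to be the first column):
#
#     CREATE EXTENSION IF NOT EXISTS vector;
#     CREATE TABLE pictures (
#         filename text PRIMARY KEY,
#         embedding vector(512)
#     );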


text_input = "man jumping"  # Provide your text input here

logger.info(f'Searching for {text_input!r}')
matches = search_for_matches(text_input)
for index, filename in enumerate(matches, start=1):
    print(f'{index}: {filename}')
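# Since get_single_embedding normalises its vectors, ordering by cosine
# distance should give the same ranking as L2 distance, assuming the stored
# image embeddings were normalised the same way. That variant of the query
# (an alternative, not what this script runs) would be:
#
#     SELECT * FROM pictures ORDER BY embedding <=> %s LIMIT 4;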