
Commit 00db0d1: added ocp demo and wget
1 parent: 3961348
File tree: 7 files changed (+35, -9 lines)


.gitignore (+1, -1)

```diff
@@ -78,7 +78,7 @@ instance/
 .scrapy
 
 # Sphinx documentation
-docs/_build/
+assets/_build/
 
 # PyBuilder
 target/
```

README.md (+12, -8)

```diff
@@ -29,7 +29,7 @@ Our results show that GCL achieves a **94.5%** increase in NDCG@10 for in-domain
 ## 2. Dataset and Benchmarks
 
 ### Dataset Structure
-<img src="docs/ms1.png" alt="multi split visual" width="500"/>
+<img src="assets/ms1.png" alt="multi split visual" width="500"/>
 
 Illustration of multi-dimensional split along both query and document dimensions resulting in 4 splits:
 training split with 80\% of queries and 50\% of documents, novel query split with the other 20\% of queries and the same documents as the training split,
```
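As a reading aid, the four-way split described in the hunk above can be written out as a small function. This is an editor's sketch, not code from the repo: the 80%/50% fractions and the first two split names are quoted from the README, while the last two split names are assumptions, since the diff context is cut off before them.

```python
import random

def multi_split(queries, documents, seed=0):
    """Split queries 80/20 and documents 50/50, then cross them into 4 splits."""
    rng = random.Random(seed)
    q, d = list(queries), list(documents)
    rng.shuffle(q)
    rng.shuffle(d)
    n_q, n_d = int(0.8 * len(q)), int(0.5 * len(d))
    train_q, novel_q = q[:n_q], q[n_q:]
    train_d, novel_d = d[:n_d], d[n_d:]
    return {
        "training":    (train_q, train_d),  # 80% of queries, 50% of documents
        "novel_query": (novel_q, train_d),  # other 20% of queries, same documents
        "novel_doc":   (train_q, novel_d),  # assumed: training queries, held-out documents
        "zero_shot":   (novel_q, novel_d),  # assumed: both held out
    }
```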
````diff
@@ -81,7 +81,7 @@ The Marqo-GS-10M dataset is available for direct download. This dataset is pivot
 ### Dataset Visualization
 Visualization of the collected triplet dataset containing search queries (top row),
 documents and scores, showcasing thumbnails of returned products with scores that decrease linearly according to their positions.
-![Dataset Qualitative](docs/visual_dataset_4.png)
+![Dataset Qualitative](assets/visual_dataset_4.png)
 
 
 ## 3. Instructions to use the GCL Benchmarks
@@ -112,7 +112,7 @@ bash ./scripts/eval-vitb32-ckpt.sh
 
 
 ## 4. GCL Training Framework and Models
-![Main Figure](docs/main_figure1.png)
+![Main Figure](assets/main_figure1.png)
 Overview of our Generalized Contrastive Learning (GCL) approach.
 GCL integrates ranking information alongside multiple input fields for each sample (e.g., title and image)
 across both left-hand-side (LHS) and right-hand-side (RHS).
@@ -144,18 +144,22 @@ Retrieval and ranking performance comparison of GCL versus publicly available co
 
 ## 5. Example Usage of Models
 ### Quick Demo with OpenCLIP
-Here is a quick example to use our model if you have installed open_clip_torch.
-
+Here is a quick example to use our model if you have installed open_clip_torch.
+```bash
+python demos/openclip_demo.py
+```
 ```python
 import torch
 from PIL import Image
 import open_clip
+import wget
 
-model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
-# model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='/path/to/downloaded/VITB32.pt')
+model_url = "https://marqo-gcl-public.s3.us-west-2.amazonaws.com/v1/gcl-vitb32-117-gs-full-states.pt"
+wget.download(model_url, "gcl-vitb32-117-gs-full-states.pt")
+model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='gcl-vitb32-117-gs-full-states.pt')
 tokenizer = open_clip.get_tokenizer('ViT-B-32')
 
-image = preprocess(Image.open('docs/oxford_shoe.png')).unsqueeze(0)
+image = preprocess(Image.open('assets/oxford_shoe.png')).unsqueeze(0)
 text = tokenizer(["a dog", "Vintage Style Women's Oxfords", "a cat"])
 logit_scale = 10
 with torch.no_grad(), torch.cuda.amp.autocast():
````
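The updated README example now pulls the checkpoint with the third-party wget package. For readers who prefer to avoid that dependency, a standard-library equivalent looks like this (editor's sketch, assuming the same checkpoint URL; not part of the commit):

```python
# Editor's sketch (not in this commit): fetch the GCL checkpoint with
# Python's standard library instead of the third-party wget package.
import urllib.request

model_url = "https://marqo-gcl-public.s3.us-west-2.amazonaws.com/v1/gcl-vitb32-117-gs-full-states.pt"
urllib.request.urlretrieve(model_url, "gcl-vitb32-117-gs-full-states.pt")
```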
4 files renamed without changes (likely the image moves from docs/ to assets/ referenced in the README diff).

demos/openclip_demo.py (+22, new file)

```diff
@@ -0,0 +1,22 @@
+import torch
+from PIL import Image
+import open_clip
+import wget
+
+model_url = "https://marqo-gcl-public.s3.us-west-2.amazonaws.com/v1/gcl-vitb32-117-gs-full-states.pt"
+wget.download(model_url, "gcl-vitb32-117-gs-full-states.pt")
+model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='gcl-vitb32-117-gs-full-states.pt')
+tokenizer = open_clip.get_tokenizer('ViT-B-32')
+
+image = preprocess(Image.open('assets/oxford_shoe.png')).unsqueeze(0)
+text = tokenizer(["a dog", "Vintage Style Women's Oxfords", "a cat"])
+logit_scale = 10
+with torch.no_grad(), torch.cuda.amp.autocast():
+    image_features = model.encode_image(image)
+    text_features = model.encode_text(text)
+    image_features /= image_features.norm(dim=-1, keepdim=True)
+    text_features /= text_features.norm(dim=-1, keepdim=True)
+
+    text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)
+
+print("Label probs:", text_probs)
```
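Note that wget.download in the demo as committed fetches the checkpoint on every run. A guard such as the following (editor's sketch, not part of the commit) reuses a previously downloaded file:

```python
# Editor's sketch (not in this commit): only download the checkpoint when
# it is missing, so repeated runs of the demo skip the network fetch.
import os
import wget

model_url = "https://marqo-gcl-public.s3.us-west-2.amazonaws.com/v1/gcl-vitb32-117-gs-full-states.pt"
ckpt_path = "gcl-vitb32-117-gs-full-states.pt"
if not os.path.exists(ckpt_path):
    wget.download(model_url, ckpt_path)
```

With the checkpoint in place, `python demos/openclip_demo.py` should print a 1x3 tensor of softmax probabilities over the three candidate captions.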
