PyTorch code for BagFormer: Better Cross-Modal Retrieval via bag-wise interaction

popself/BagFormer

 
 

BagFormer: Better Cross-Modal Retrieval via bag-wise interaction

This is the PyTorch code for the BagFormer paper. The code has been tested on Python 3.8 and PyTorch 1.13. To install the dependencies, create a virtual environment and run:

pip install -r requirements.txt

Pre-trained checkpoints:

Number of image-text pairs    BagFormer
108M                          Download

Finetuned checkpoints:

Task                           BagFormer
Image-Text Retrieval (MUGE)    Download

Image-Text Retrieval:

  1. Download the MUGE Multimodal Retrieval dataset from the original website and unzip it into the data directory, or change the dataset path in configs/config_muge.yaml.
  2. To evaluate the finetuned BagFormer model on MUGE, run:
python3 train_muge.py \
--checkpoint path-to-finetuned-checkpoint \
--interaction bagwise \
--output_dir path-to-output \
--evaluate
  3. To finetune the pre-trained checkpoint on MUGE, run:
python3 train_muge.py \
--checkpoint path-to-pretrain-checkpoint \
--interaction bagwise \
--output_dir path-to-output 
  4. To compare bag-wise interaction with cls_token or tokenwise interaction, run the baselines:
# cls_token baseline, which is the BagFormer w/o late interaction model in the paper
python3 train_muge.py \
--checkpoint path-to-pretrain-checkpoint \
--interaction cls_token \
--output_dir path-to-output

# tokenwise baseline, which is the BagFormer w/o bagging layer model in the paper
python3 train_muge.py \
--checkpoint path-to-pretrain-checkpoint \
--interaction tokenwise \
--output_dir path-to-output
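
The three `--interaction` modes differ in how an image-text score is formed from the encoder outputs. The following is a minimal sketch on toy tensors, not the repository's implementation: the function names `cls_similarity` and `late_interaction_similarity` are hypothetical, and the max-then-sum aggregation is an assumption modelled on common late-interaction retrieval methods. Under that assumption, tokenwise interaction would apply the late-interaction score to per-token text features, while bag-wise interaction would apply it to bag features obtained by pooling tokens first.

```python
import torch
import torch.nn.functional as F


def cls_similarity(text_feats, image_feats):
    """Single-vector score: compare only the first position of each sequence."""
    t = F.normalize(text_feats[:, 0], dim=-1)   # (num_texts, dim)
    v = F.normalize(image_feats[:, 0], dim=-1)  # (num_images, dim)
    return v @ t.T                              # (num_images, num_texts)


def late_interaction_similarity(text_feats, image_feats):
    """Max-then-sum late interaction (assumed aggregation, for illustration).

    text_feats:  (num_texts, text_len, dim)   token or bag features
    image_feats: (num_images, num_patches, dim)
    """
    # Pairwise dot products: (num_images, num_texts, num_patches, text_len)
    sim = torch.einsum("ivd,tld->itvl", image_feats, text_feats)
    # Image-to-text: each patch keeps its best text unit, then sum over patches.
    i2t = sim.max(dim=3).values.sum(dim=2)      # (num_images, num_texts)
    # Text-to-image: each text unit keeps its best patch, then sum over units.
    t2i = sim.max(dim=2).values.sum(dim=2).T    # (num_texts, num_images)
    return i2t, t2i


# Toy data: 1 image with 2 patches, 2 texts with 2 units each.
image_feats = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])
text_feats = torch.tensor([
    [[1.0, 0.0], [0.0, 0.0]],  # second unit is zero padding
    [[0.0, 1.0], [0.0, 1.0]],
])

cls_sim = cls_similarity(text_feats, image_feats)
i2t, t2i = late_interaction_similarity(text_feats, image_feats)
print(cls_sim)  # tensor([[1., 0.]])
print(i2t)      # tensor([[1., 1.]])
print(t2i)      # tensor([[1.], [2.]])
```

The two late-interaction directions are not transposes of each other, which is consistent with the separate `sim_i2t` and `sim_t2i` outputs in the similarity example.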

Calculate bag-wise similarity

import torch
import torch.nn.functional as F
from PIL import Image
from ruamel import yaml
from transformers import BertTokenizer

from models.loss import tokenwise_similarity_martix
from models.model_helper import EmbeddingBagHelperAutomaton
from models.model_retrieval_bagwise import BagFormer
from MUGE_helper.dataset import get_test_transform

device = "cuda" if torch.cuda.is_available() else "cpu"
text_encoder = "bert-base-chinese"
max_seq_len = 25
config = yaml.load(open("configs/config_muge.yaml", "r"), Loader=yaml.Loader)
test_transform = get_test_transform(config)

tokenizer = BertTokenizer.from_pretrained(text_encoder)

model = BagFormer(
    config=config,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
)

checkpoint = torch.load(
    "path-to-checkpoint", map_location="cpu"
)
model.load_state_dict(checkpoint["model"], strict=False)
model = model.to(device)

embedding_bag_helper = EmbeddingBagHelperAutomaton(
    tokenizer, config["entity_dict_path"], masked_token=["[CLS]", "[PAD]"]
)

product_image = test_transform(Image.open("rumble_roller.jpeg"))
image = product_image.unsqueeze(0).to(device)

product_title = ["rumble roller", "nike zoomx vista"]
text = tokenizer(product_title, padding="max_length", max_length=max_seq_len)

embed_bag_offset, attn_mask = embedding_bag_helper.process(text, return_mask=True)
embed_bag_offset = torch.LongTensor(embed_bag_offset).to(device)
embed_bag_attn_mask = torch.LongTensor(attn_mask).to(device)
text = text.convert_to_tensors("pt").to(device)

with torch.no_grad():
    # encode image and text
    image_features = model.visual_encoder(image)
    text_features = model.text_encoder(
        text.input_ids, attention_mask=text.attention_mask, mode="text"
    ).last_hidden_state
    # get text bag feature
    batch_size, seq_len, text_width = text_features.shape
    embedding_input = torch.arange(batch_size * seq_len, device=device)
    embedbag_feats = F.embedding_bag(
        embedding_input,
        text_features.view(-1, text_width),
        embed_bag_offset,
        mode="sum",
    ).view(batch_size, -1, text_width)
    embedbag_feats = F.normalize(embedbag_feats, dim=-1)
    # pad to same length
    embedbag_seq_len = embedbag_feats.shape[1]
    embedbag_feats = F.pad(
        embedbag_feats,
        pad=(0, 0, 0, max_seq_len - embedbag_seq_len, 0, 0),
        mode="constant",
        value=0,
    )
    # calc bagwise similarity matrix
    sim_i2t, sim_t2i = tokenwise_similarity_martix(embedbag_feats, image_features)

print("image feature shape:", image_features.shape)  
# prints: torch.Size([1, 257, 768])
print("text feature shape:", embedbag_feats.shape)  
# prints: torch.Size([2, 25, 768])
print("img2text sim:", sim_i2t)  # prints: [[132.4761, 50.0424]]
print("text2img sim:", sim_t2i)  # prints: [[33.4206], [19.6727]]
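
The bagging step above relies on `F.embedding_bag` with offsets to pool consecutive token features into bags. The toy example below is a sketch of that mechanism only: the feature values and the offsets are made up for illustration, whereas in the code above the offsets come from `EmbeddingBagHelperAutomaton` and index into the flattened batch.

```python
import torch
import torch.nn.functional as F

# Six token features of width 2, standing in for flattened encoder outputs.
token_feats = torch.tensor([
    [1.0, 0.0],
    [0.0, 1.0],
    [2.0, 2.0],
    [3.0, 0.0],
    [0.0, 3.0],
    [1.0, 1.0],
])

# One index per token row, so the feature matrix acts as the embedding table.
indices = torch.arange(token_feats.shape[0])

# Each offset marks where a bag starts: bags are rows [0, 1], [2], and [3, 4, 5].
offsets = torch.tensor([0, 2, 3])

bag_feats = F.embedding_bag(indices, token_feats, offsets, mode="sum")
print(bag_feats)  # tensor([[1., 1.], [2., 2.], [4., 4.]])

# Pad the bag dimension with zero rows up to a fixed length (illustrative value).
max_bags = 5
padded = F.pad(bag_feats, pad=(0, 0, 0, max_bags - bag_feats.shape[0]), value=0.0)
print(padded.shape)  # torch.Size([5, 2])
```

Because the pooled result is reshaped with `.view(batch_size, -1, text_width)` in the code above, every sample in a batch appears to need the same number of bags before padding.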

Citation

If you find our work useful, please consider citing BagFormer:

@article{hou2022bagformer,
  title={BagFormer: Better Cross-Modal Retrieval via bag-wise interaction},
  author={Hou, Haowen and Yan, Xiaopeng and Zhang, Yigeng and Lian, Fengzong and Kang, Zhanhui},
  journal={arXiv preprint arXiv:2212.14322},
  year={2022}
}
