
Commit 47e0896

preparing code

1 parent 2ffb62c commit 47e0896

3 files changed: +313 -0 lines changed

src/prepare_data.py

+40
@@ -0,0 +1,40 @@
import torch
import skimage.io as io
from PIL import Image
import os

# modified from https://github.com/dino-chiio/blip-vqa-finetune/blob/main/finetuning.py
class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset."""

    def __init__(self, dataset, processor, img_path=""):
        self.dataset = dataset
        self.processor = processor
        self.img_path = img_path

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # get image + text
        question = self.dataset[idx]['question']
        answer = self.dataset[idx]['answer']
        image_file = self.dataset[idx]['image']
        image_path = os.path.join(self.img_path, image_file)
        image = Image.open(image_path).convert("RGB")
        text = question

        encoding = self.processor(image, text,
                                  max_length=512, padding="max_length",
                                  truncation=True, return_tensors="pt")
        labels = self.processor.tokenizer.encode(
            answer, max_length=8, padding="max_length", return_tensors='pt'
        )
        encoding["labels"] = labels

        # remove the batch dimension added by return_tensors="pt"
        for k, v in encoding.items():
            encoding[k] = v.squeeze()

        return encoding
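
For context, a minimal usage sketch of VQADataset (not part of this commit). The record layout ({'question', 'answer', 'image'}) matches what __getitem__ reads, but the sample record, the image directory, and the choice of Blip2Processor for BASE_MODEL are assumptions:

from transformers import Blip2Processor
from torch.utils.data import DataLoader

from prepare_data import VQADataset

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")

# hypothetical annotation record in the format __getitem__ expects
records = [{"question": "What is in this can?",
            "answer": "soda",
            "image": "VizWiz_train_00000001.jpg"}]

dataset = VQADataset(dataset=records, processor=processor,
                     img_path="/kaggle/input/vizwiz/train/train")
loader = DataLoader(dataset, batch_size=1, shuffle=False)
batch = next(iter(loader))  # dict with pixel_values, input_ids, attention_mask, labels
print({k: v.shape for k, v in batch.items()})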

src/preprocessing.py

+153
@@ -0,0 +1,153 @@
#!/usr/bin/env python
import os

from datasets import load_dataset
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

import numpy as np
import argparse
import random

import json
from shutil import copyfile

# make the Kaggle dataset helpers importable from the working directory
copyfile(src=os.path.join("/kaggle/input/vizwiz-dataset", 'vqa.py'), dst=os.path.join("../working", 'vqa.py'))
copyfile(src=os.path.join("/kaggle/input/vizwiz-dataset", 'prepare_data.py'), dst=os.path.join("../working", 'prepare_data.py'))

from vqa import *
from prepare_data import *

BASE_MODEL = "Salesforce/blip2-opt-2.7b"


def split_dataset(dataset, train_ratio=0.7, valid_ratio=0.1, test_ratio=0.2):
    # Shuffle the dataset in place
    random.shuffle(dataset)

    # Calculate split indices
    total_size = len(dataset)
    train_size = int(total_size * train_ratio)
    valid_size = int(total_size * valid_ratio)
    test_size = total_size - train_size - valid_size

    # Split the dataset
    train_set = dataset[:train_size]
    valid_set = dataset[train_size:train_size + valid_size]
    test_set = dataset[train_size + valid_size:]

    return train_set, valid_set, test_set


def load_dataset_vizwiz(data_path="/kaggle/input/vizwiz"):
    INPUT_PATH = data_path
    IMG_PATH = INPUT_PATH
    ANNOTATIONS = INPUT_PATH + '/Annotations/Annotations'
    TRAIN_PATH = INPUT_PATH + '/train/train'
    VALIDATION_PATH = INPUT_PATH + '/val/val'
    TEST_PATH = INPUT_PATH + '/test/test'
    ANNOTATIONS_TRAIN_PATH = ANNOTATIONS + '/train.json'
    ANNOTATIONS_VAL_PATH = ANNOTATIONS + '/val.json'
    ANNOTATIONS_TEST_PATH = ANNOTATIONS + '/test.json'

    data_VQA = {d_type: None for d_type in ['train', 'valid', 'test']}
    for d_type, img_dir, ann_file in zip(['train', 'valid', 'test'],
                                         [TRAIN_PATH, VALIDATION_PATH, TEST_PATH],
                                         [ANNOTATIONS_TRAIN_PATH, ANNOTATIONS_VAL_PATH, ANNOTATIONS_TEST_PATH]):
        # initialize VQA api for QA annotations
        vqa = VQA(ann_file)

        # load QA annotations for the given answer types;
        # ansTypes can be: yes/no, number, other, unanswerable
        anns = vqa.getAnns(ansTypes=['other', 'yes/no', 'number'])
        anns = vqa.getBestAnns(ansTypes=['other', 'yes/no', 'number'])

        data_VQA[d_type] = anns

    train_n, valid_n = len(data_VQA['train']), len(data_VQA['valid'])

    data_VQA['train'] = data_VQA['train'][:10000]
    data_VQA['valid'] = data_VQA['valid'][:1000]
    print("Training set: {}->{} - Validation set: {}->{}".format(train_n, len(data_VQA['train']), valid_n, len(data_VQA['valid'])))

    return data_VQA, TRAIN_PATH, VALIDATION_PATH
    # train_dataset = VQADataset(dataset=data_VQA['train'],
    #                            processor=processor,
    #                            img_path=TRAIN_PATH)
    # valid_dataset = VQADataset(dataset=data_VQA['valid'],
    #                            processor=processor,
    #                            img_path=VALIDATION_PATH)

    # batch_size = 1
    # train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    # valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # return train_dataloader, valid_dataloader


def load_dataset_kvqa(data_path: str = "/kaggle/input/vqa-blind-ko"):
    INPUT_PATH = data_path
    TRAIN_PATH = INPUT_PATH + '/VQA_train/images'
    ANNOTATIONS = INPUT_PATH
    TEST_PATH = INPUT_PATH + '/VQA_test/task07_images'
    ANNOTATIONS_TRAIN_PATH = ANNOTATIONS + '/train_en.json'
    ANNOTATIONS_TEST_PATH = TEST_PATH + '/test.json'

    annFile = ANNOTATIONS_TRAIN_PATH

    # initialize VQA api for QA annotations
    vqa = VQA(annFile)

    # load QA annotations (all answer types);
    # ansTypes can be: yes/no, number, other, unanswerable
    anns = vqa.getAnns()
    anns = vqa.getBestAnns()

    # Split the dataset into train, validation, and test sets
    train_set, valid_set, test_set = split_dataset(anns)
    train_n, valid_n = len(train_set), len(valid_set)
    train_set = train_set[:20000]
    valid_set = valid_set[:2000]
    data_VQA = {
        'train': train_set,
        'valid': valid_set,
        'test': test_set
    }
    print("Training set: {}->{} - Validation set: {}->{}".format(train_n, len(data_VQA['train']), valid_n, len(data_VQA['valid'])))

    return data_VQA, TRAIN_PATH, TRAIN_PATH


# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--kaggle", type=bool, required=True, default=True)
#     parser.add_argument("--vizwiz_path", type=str, required=True, default="/kaggle/input/vizwiz")
#     parser.add_argument("--kvqa_path", type=str, required=True, default="/kaggle/input/vqa-blind-ko")
#     parser.add_argument("--lib_path", type=str, required=True, default="/kaggle/input/vizwiz-dataset")
#     args = parser.parse_args()

#     # load_dataset()
#     if args.kaggle:
#         from shutil import copyfile
#         copyfile(src=os.path.join(args.lib_path, 'vqa.py'), dst=os.path.join("../working", 'vqa.py'))

#     from vqa import *
#     load_dataset()
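
The commented-out DataLoader calls above reference a collate_fn that is not included in this commit. A minimal sketch of what it could look like, assuming every encoding returned by VQADataset has the same keys and fixed-size tensors (max_length padding), so examples can simply be stacked:

import torch

def collate_fn(batch):
    # batch is a list of dicts produced by VQADataset.__getitem__;
    # stack each field along a new batch dimension
    return {key: torch.stack([example[key] for example in batch])
            for key in batch[0].keys()}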

src/vqa.py

+120
@@ -0,0 +1,120 @@
__author__ = 'QingLi'
__version__ = '1.0'
## modified by sooh-J
# Interface for accessing the VQA dataset.

# This code is based on the code written by Qing Li for the VizWiz Python API, available at:
# (https://github.com/xxx)

# The following functions are defined:
#  VQA         - VQA class that loads the VQA annotation file and prepares data structures.
#  getImgs     - Get the image names present in the annotation file.
#  getAnns     - Get annotations that satisfy given filter conditions.
#  getBestAnns - Keep only confident answers and select a single best answer per question.
#  showQA      - Display the specified questions and answers.

# Help on each function can be accessed by: "help(VQA.function)"

import json
import datetime
import copy
import random
import skimage.io as io
import matplotlib.pyplot as plt
import os
from collections import Counter


def export_max_value(ans):
    # return the most frequent answer; ties are broken by the longest string
    ans_cnt = Counter(ans)

    max_cnt = max(ans_cnt.values())
    max_ans = max([k for k, v in ans_cnt.items() if v == max_cnt], key=len)

    return max_ans


class VQA:
    def __init__(self, annotation_file=None):
        """
        Constructor of VQA helper class for reading and visualizing questions and answers.
        :param annotation_file (str): location of VQA annotation file
        :return:
        """
        # load dataset
        self.dataset = {}
        self.imgToQA = {}
        if annotation_file is not None:
            print('loading dataset into memory...')
            time_t = datetime.datetime.utcnow()
            dataset = json.load(open(annotation_file, 'r'))
            print(datetime.datetime.utcnow() - time_t)
            self.dataset = dataset
            self.imgToQA = {x['image']: x for x in dataset}
        self.anns = {}

    def getImgs(self):
        return list(self.imgToQA.keys())

    def getAnns(self, imgs=[], ansTypes=[]):
        """
        Get annotations that satisfy given filter conditions. An empty filter is skipped.
        :param imgs (str array): get annotations for given image names
               ansTypes (str array): get annotations for given answer types
        :return: annotations (dict array): dict array of annotations
        """
        anns = self.dataset

        imgs = imgs if type(imgs) == list else [imgs]
        if len(imgs) != 0:
            anns = [self.imgToQA[img] for img in imgs]

        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
        if len(ansTypes) != 0:
            anns = [ann for ann in anns if ann.get('answer_type', "None") in ansTypes]

        self.anns = anns
        return anns

    def getBestAnns(self, imgs=[], ansTypes=[]):
        """
        code by SOOH-J
        Filter to the best answer (a single answer given with confidence).
        :param imgs (str array): get annotations for given image names
               ansTypes (str array): get annotations for given answer types
        :return: annotations (dict array): dict array of annotations
        """
        try:
            anns = self.anns
        except AttributeError:
            anns = self.getAnns(imgs, ansTypes)

        # include only answers given with confidence
        confidence_anns = []
        for ann in anns:
            confidence_ann = [an for an in ann['answers'] if an.get('answer_confidence') == 'yes']
            if confidence_ann:
                ann_copy = ann.copy()
                ann_copy['answers'] = confidence_ann
                try:
                    confidence_answers = [con_ans['answer'] for con_ans in confidence_ann
                                          if con_ans['answer'] not in ['unanswerable', 'unsuitable image', 'unsuitable']]
                    ann_copy['answer'] = export_max_value(confidence_answers)
                except ValueError:
                    # every confident answer was unanswerable/unsuitable; drop this annotation
                    continue
                confidence_anns.append(ann_copy)
        return confidence_anns

    def showQA(self, anns):
        """
        Display the specified annotations.
        :param anns (array of object): annotations to display
        :return: None
        """
        if len(anns) == 0:
            return 0
        for ann in anns:
            print("Question: %s" % ann['question'])
            print("Answer: ")
            print('\n'.join([f"\tanswer:{x['answer']}, confidence:{x['answer_confidence']}" for x in ann['answers']]))
            print("SOOOO the best answer:")
            print(f"\t{ann['answer']}")
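
To illustrate how getAnns, getBestAnns, and export_max_value reduce the crowdsourced answers to a single label, here is a small hypothetical example (not part of this commit); the sample record and field names follow the VizWiz annotation format that vqa.py assumes:

from vqa import VQA

vqa = VQA()
vqa.dataset = [{
    "image": "VizWiz_train_00000001.jpg",
    "question": "What is this?",
    "answer_type": "other",
    "answers": [
        {"answer": "soda can", "answer_confidence": "yes"},
        {"answer": "can", "answer_confidence": "yes"},
        {"answer": "soda can", "answer_confidence": "maybe"},
        {"answer": "unanswerable", "answer_confidence": "yes"},
    ],
}]

anns = vqa.getAnns(ansTypes=["other"])      # populates vqa.anns, as preprocessing.py does
best = vqa.getBestAnns(ansTypes=["other"])
print(best[0]["answer"])  # "soda can": only confident, answerable answers count;
                          # ties on frequency are broken by the longer string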
