#!/usr/bin/env python
import argparse
import json
import os
import random
from shutil import copyfile

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

# Copy the helper modules shipped with the Kaggle dataset into the working
# directory so that `vqa` and `prepare_data` can be imported as regular modules.
copyfile(src=os.path.join("/kaggle/input/vizwiz-dataset", 'vqa.py'),
         dst=os.path.join("../working", 'vqa.py'))
copyfile(src=os.path.join("/kaggle/input/vizwiz-dataset", 'prepare_data.py'),
         dst=os.path.join("../working", 'prepare_data.py'))

from vqa import *
from prepare_data import *

BASE_MODEL = "Salesforce/blip2-opt-2.7b"
def split_dataset(dataset, train_ratio=0.7, valid_ratio=0.1, test_ratio=0.2):
    """Shuffle `dataset` in place, then split it into train/valid/test lists.

    The test split takes whatever remains after the train and valid splits,
    so `test_ratio` is implied by the other two ratios.
    """
    # Shuffle the dataset (in place).
    random.shuffle(dataset)

    # Calculate split indices.
    total_size = len(dataset)
    train_size = int(total_size * train_ratio)
    valid_size = int(total_size * valid_ratio)

    # Split the dataset; the test set is the remainder.
    train_set = dataset[:train_size]
    valid_set = dataset[train_size:train_size + valid_size]
    test_set = dataset[train_size + valid_size:]

    return train_set, valid_set, test_set
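
# A minimal sanity check of split_dataset (illustrative only; never called):
# with the default ratios, a 100-item list yields a 70/10/20 split.
def _demo_split_dataset():
    items = list(range(100))
    train_set, valid_set, test_set = split_dataset(items)
    assert (len(train_set), len(valid_set), len(test_set)) == (70, 10, 20)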

def load_dataset_vizwiz(data_path="/kaggle/input/vizwiz"):
    INPUT_PATH = data_path
    ANNOTATIONS = INPUT_PATH + '/Annotations/Annotations'
    TRAIN_PATH = INPUT_PATH + '/train/train'
    VALIDATION_PATH = INPUT_PATH + '/val/val'
    TEST_PATH = INPUT_PATH + '/test/test'
    ANNOTATIONS_TRAIN_PATH = ANNOTATIONS + '/train.json'
    ANNOTATIONS_VAL_PATH = ANNOTATIONS + '/val.json'
    ANNOTATIONS_TEST_PATH = ANNOTATIONS + '/test.json'

    data_VQA = {d_type: None for d_type in ['train', 'valid', 'test']}
    for d_type, img_dir, ann_file in zip(
            ['train', 'valid', 'test'],
            [TRAIN_PATH, VALIDATION_PATH, TEST_PATH],
            [ANNOTATIONS_TRAIN_PATH, ANNOTATIONS_VAL_PATH, ANNOTATIONS_TEST_PATH]):
        # Initialize the VQA api for the QA annotations of this split.
        vqa = VQA(ann_file)

        # Load QA annotations for the given answer types. ansTypes can be any
        # of: 'yes/no', 'number', 'other', 'unanswerable'. getBestAnns() keeps
        # only the best (most frequent) answer per question, so the plain
        # getAnns() result is not needed here.
        anns = vqa.getBestAnns(ansTypes=['other', 'yes/no', 'number'])

        data_VQA[d_type] = anns

    train_n, valid_n = len(data_VQA['train']), len(data_VQA['valid'])

    # Cap the split sizes to keep training time manageable.
    data_VQA['train'] = data_VQA['train'][:10000]
    data_VQA['valid'] = data_VQA['valid'][:1000]
    print("Training set: {}->{} - Validation set: {}->{}".format(
        train_n, len(data_VQA['train']), valid_n, len(data_VQA['valid'])))

    return data_VQA, TRAIN_PATH, VALIDATION_PATH

# The returned splits can then be turned into dataloaders, e.g.:
#
# train_dataset = VQADataset(dataset=data_VQA['train'],
#                            processor=processor,
#                            img_path=TRAIN_PATH)
# valid_dataset = VQADataset(dataset=data_VQA['valid'],
#                            processor=processor,
#                            img_path=VALIDATION_PATH)
#
# batch_size = 1
# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
# valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
#
# return train_dataloader, valid_dataloader
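
# VQADataset and collate_fn in the block above come from prepare_data.py
# (pulled in by the wildcard import). The class below is NOT that
# implementation; it is a minimal sketch of the usual shape of such a
# dataset, assuming each annotation dict carries 'image', 'question' and
# 'answer' keys (an assumption about the annotation format, not a
# documented API).
class _VQADatasetSketch(torch.utils.data.Dataset):
    def __init__(self, dataset, processor, img_path):
        self.dataset = dataset
        self.processor = processor
        self.img_path = img_path

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        ann = self.dataset[idx]
        # Load the image and let the (BLIP-2) processor handle resizing,
        # normalization and question tokenization.
        image = Image.open(os.path.join(self.img_path, ann['image'])).convert('RGB')
        encoding = self.processor(images=image, text=ann['question'], return_tensors="pt")
        # Drop the batch dimension added by the processor.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        encoding['answer'] = ann['answer']
        return encoding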

def load_dataset_kvqa(data_path: str = "/kaggle/input/vqa-blind-ko"):
    INPUT_PATH = data_path
    TRAIN_PATH = INPUT_PATH + '/VQA_train/images'
    ANNOTATIONS = INPUT_PATH
    TEST_PATH = INPUT_PATH + '/VQA_test/task07_images'
    ANNOTATIONS_TRAIN_PATH = ANNOTATIONS + '/train_en.json'
    ANNOTATIONS_TEST_PATH = TEST_PATH + '/test.json'

    # Initialize the VQA api for QA annotations. KVQA ships a single
    # annotation file, so the splits are carved out of it below.
    vqa = VQA(ANNOTATIONS_TRAIN_PATH)

    # With no ansTypes filter, all answer types ('yes/no', 'number', 'other',
    # 'unanswerable') are loaded; getBestAnns() keeps only the best (most
    # frequent) answer per question.
    anns = vqa.getBestAnns()

    # Split the dataset into train, validation, and test sets.
    train_set, valid_set, test_set = split_dataset(anns)
    train_n, valid_n = len(train_set), len(valid_set)

    # Cap the split sizes to keep training time manageable.
    train_set = train_set[:20000]
    valid_set = valid_set[:2000]
    data_VQA = {
        'train': train_set,
        'valid': valid_set,
        'test': test_set,
    }
    print("Training set: {}->{} - Validation set: {}->{}".format(
        train_n, len(data_VQA['train']), valid_n, len(data_VQA['valid'])))

    # All splits are carved from the same training data, so TRAIN_PATH doubles
    # as the validation image directory.
    return data_VQA, TRAIN_PATH, TRAIN_PATH
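
# A sketch of how these loaders could be consumed together with the BLIP-2
# processor named by BASE_MODEL. The wiring below is an assumption about the
# training script, not part of this module; _VQADatasetSketch stands in for
# the real VQADataset from prepare_data.py.
def _demo_build_dataloaders():
    from transformers import Blip2Processor

    processor = Blip2Processor.from_pretrained(BASE_MODEL)
    data_VQA, train_img_path, valid_img_path = load_dataset_vizwiz()
    train_dataset = _VQADatasetSketch(data_VQA['train'], processor, train_img_path)
    valid_dataset = _VQADatasetSketch(data_VQA['valid'], processor, valid_img_path)
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
    return train_dataloader, valid_dataloader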

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse's type=bool is a trap (bool("False") is True), so the flag uses
    # BooleanOptionalAction (Python 3.9+) instead; required=True is dropped
    # because every argument already has a default.
    parser.add_argument("--kaggle", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--vizwiz_path", type=str, default="/kaggle/input/vizwiz")
    parser.add_argument("--kvqa_path", type=str, default="/kaggle/input/vqa-blind-ko")
    parser.add_argument("--lib_path", type=str, default="/kaggle/input/vizwiz-dataset")
    args = parser.parse_args()

    if args.kaggle:
        copyfile(src=os.path.join(args.lib_path, 'vqa.py'),
                 dst=os.path.join("../working", 'vqa.py'))

    # Smoke-test both loaders.
    load_dataset_vizwiz(args.vizwiz_path)
    load_dataset_kvqa(args.kvqa_path)