{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BLIP-2 visual question answering: checkpoint comparison\n",
    "\n",
    "Scores two BLIP-2 checkpoints on a 100-item slice of the `test` split of the VizWiz and KVQA datasets:\n",
    "\n",
    "- `sooh-j/VQA-for-VIP`\n",
    "- `Salesforce/blip2-opt-2.7b`\n",
    "\n",
    "Each model answers every question with sampled generation. Answers are scored with sentence-level BLEU and a TF-IDF cosine-similarity hit rate.\n",
    "\n",
    "Run the notebook top to bottom on a fresh kernel (Restart & Run All); later cells rely on names defined by earlier ones."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "# TODO: pin package versions so the environment is reproducible.\n",
    "%pip install -q git+https://github.com/huggingface/peft.git transformers bitsandbytes datasets accelerate\n",
    "%pip install -i https://pypi.org/simple/ bitsandbytes\n",
    "%pip install evaluate\n",
    "%pip install nltk\n",
    "%pip install rouge\n",
    "%pip install pycocoevalcap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the `sooh-j/VQA-for-VIP` checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "from transformers import AutoProcessor, Blip2ForConditionalGeneration\n",
    "\n",
    "# `processor` is used as a notebook-level global by predict() further down.\n",
    "processor = AutoProcessor.from_pretrained(\"sooh-j/VQA-for-VIP\")\n",
    "model_v = Blip2ForConditionalGeneration.from_pretrained(\n",
    "    \"sooh-j/VQA-for-VIP\",\n",
    "    device_map=\"auto\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the evaluation data\n",
    "\n",
    "The helper modules are copied from a Kaggle input dataset into the working directory before they are imported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from shutil import copyfile\n",
    "\n",
    "from os import listdir\n",
    "from os.path import isfile, join\n",
    "\n",
    "# Interface for accessing the VQA dataset.\n",
    "lib_PATH = '/kaggle/input/vizwiz-dataset'\n",
    "\n",
    "# Number of items sliced from each dataset's 'test' split for evaluation.\n",
    "N_EVAL_SAMPLES = 100\n",
    "\n",
    "# Copy every regular file found in lib_PATH next to the notebook; the modules\n",
    "# imported below are presumably among them (verify against the dataset contents).\n",
    "lib_files = [f for f in listdir(lib_PATH) if isfile(join(lib_PATH, f))]\n",
    "\n",
    "for lib_f in lib_files:\n",
    "    copyfile(src=os.path.join(lib_PATH, lib_f),\n",
    "             dst=os.path.join(\"../working\", lib_f))\n",
    "\n",
    "# import all our functions\n",
    "# NOTE(review): the star imports hide which module provides load_dataset_vizwiz,\n",
    "# load_dataset_kvqa and VQADataset, and may shadow earlier names. Replace them with\n",
    "# explicit imports once the providing modules are confirmed.\n",
    "from preprocessing import *\n",
    "from prepare_data_eval import *\n",
    "from vqa import *\n",
    "\n",
    "# Only evaluation slices are built here; the train splits (e.g. vizwiz_data['train']\n",
    "# with VIZWIZ_TRAIN_PATH) are returned by the loaders but not used in this notebook.\n",
    "\n",
    "#-------------------------------load VIZWIZ dataset--------------------------#\n",
    "\n",
    "vizwiz_data, VIZWIZ_TRAIN_PATH, VIZWIZ_VALIDATION_PATH = load_dataset_vizwiz(\"/kaggle/input/vizwiz\")\n",
    "\n",
    "vizwiz_valid_dataset = VQADataset(dataset=vizwiz_data['test'][:N_EVAL_SAMPLES],\n",
    "                                  processor=processor,\n",
    "                                  img_path=VIZWIZ_VALIDATION_PATH)\n",
    "\n",
    "#-------------------------------load KVQA dataset--------------------------#\n",
    "\n",
    "kvqa_data, KVQA_TRAIN_PATH, KVQA_VALIDATION_PATH = load_dataset_kvqa(\"/kaggle/input/vqa-blind-ko\")\n",
    "\n",
    "kvqa_valid_dataset = VQADataset(dataset=kvqa_data['test'][:N_EVAL_SAMPLES],\n",
    "                                processor=processor,\n",
    "                                img_path=KVQA_VALIDATION_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import ConcatDataset\n",
    "from nltk.translate.bleu_score import sentence_bleu\n",
    "import nltk\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score\n",
    "from tqdm import tqdm\n",
    "\n",
    "# predict() samples tokens (do_sample=True); seeding torch before each evaluation\n",
    "# pass makes the reported numbers repeatable and starts both models from the same RNG state.\n",
    "SEED = 42\n",
    "\n",
    "model = model_v\n",
    "nltk.download('wordnet')\n",
    "\n",
    "\n",
    "# BLEU Score\n",
    "def calculate_bleu(reference, candidate, smoothing_function=None):\n",
    "    \"\"\"Return the sentence-level BLEU score of one candidate answer.\n",
    "\n",
    "    Both sides are tokenised by splitting on whitespace.\n",
    "\n",
    "    Args:\n",
    "        reference: iterable of reference answer strings.\n",
    "        candidate: predicted answer string.\n",
    "        smoothing_function: optional NLTK smoothing callable forwarded to\n",
    "            sentence_bleu; the default (None) keeps NLTK's unsmoothed score.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by nltk.translate.bleu_score.sentence_bleu.\n",
    "    \"\"\"\n",
    "    reference_tokens = [ref.split() for ref in reference]\n",
    "    candidate_tokens = candidate.split()\n",
    "    return sentence_bleu(reference_tokens, candidate_tokens,\n",
    "                         smoothing_function=smoothing_function)\n",
    "\n",
    "\n",
    "def predict(model, image, question):\n",
    "    \"\"\"Generate an answer to `question` about `image` with `model`.\n",
    "\n",
    "    Uses the notebook-level `processor` to build the inputs and to decode the\n",
    "    output. Generation samples tokens, so repeated calls can return different\n",
    "    text unless the torch RNG is seeded.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing generate() (model_v or model_blip in this notebook).\n",
    "        image: image passed to `processor` as `images`.\n",
    "        question: question text inserted into the prompt.\n",
    "\n",
    "    Returns:\n",
    "        Decoded text of the first generated sequence, special tokens skipped.\n",
    "    \"\"\"\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "    prompt = f\"Question: {question}, Answer:\"\n",
    "    processed = processor(images=image, text=prompt, return_tensors=\"pt\").to(device)\n",
    "    # The generated ids are only decoded, so they are not moved to another device afterwards.\n",
    "    output = model.generate(**processed,\n",
    "                            max_new_tokens=20,\n",
    "                            temperature=0.5,\n",
    "                            do_sample=True,\n",
    "                            top_k=50,\n",
    "                            top_p=0.9,\n",
    "                            repetition_penalty=1.2)\n",
    "    return processor.decode(output[0], skip_special_tokens=True)\n",
    "\n",
    "\n",
    "def _collect_predictions(model, dataset, seed=SEED):\n",
    "    \"\"\"Run predict() on every (image, question, answer) item of `dataset`.\n",
    "\n",
    "    Returns:\n",
    "        (references, candidates): ground-truth answers and predicted answers,\n",
    "        both in dataset order.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    references, candidates = [], []\n",
    "    for image, question, answer in tqdm(dataset):\n",
    "        references.append(answer)\n",
    "        candidates.append(predict(model, image, question))\n",
    "    return references, candidates\n",
    "\n",
    "\n",
    "def _score_bleu(references, candidates):\n",
    "    \"\"\"Return (per-item BLEU scores, their mean); the mean is 0 for empty input.\"\"\"\n",
    "    scores = [calculate_bleu([ref], cand) for ref, cand in zip(references, candidates)]\n",
    "    average = sum(scores) / len(scores) if scores else 0\n",
    "    return scores, average"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate `sooh-j/VQA-for-VIP`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "combined_dataset = ConcatDataset([vizwiz_valid_dataset, kvqa_valid_dataset])\n",
    "\n",
    "references, candidates = _collect_predictions(model, combined_dataset)\n",
    "\n",
    "# BLEU Score\n",
    "bleu_scores, avg_bleu_score = _score_bleu(references, candidates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def calculate_metrics_with_similarity(predicted_list, ground_truth_list, threshold=0.3):\n",
    "    \"\"\"Score predictions by TF-IDF cosine similarity against the ground truths.\n",
    "\n",
    "    A prediction counts as a hit when its similarity to at least one\n",
    "    ground-truth text reaches `threshold`.\n",
    "\n",
    "    NOTE(review): every prediction is compared with ALL ground truths, not only\n",
    "    the one at the same index -- confirm this is the intended matching rule.\n",
    "    Hits are counted per prediction, so 'accuracy' uses the same formula as\n",
    "    'precision'.\n",
    "\n",
    "    Args:\n",
    "        predicted_list: list of predicted answer strings.\n",
    "        ground_truth_list: list of reference answer strings.\n",
    "        threshold: minimum cosine similarity for a hit.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'precision', 'recall', 'f1_score' and 'accuracy'; every value\n",
    "        is 0 when either list is empty.\n",
    "    \"\"\"\n",
    "    if not predicted_list or not ground_truth_list:\n",
    "        # Nothing to compare. Returning early keeps empty input from raising\n",
    "        # inside the vectorizer / similarity call.\n",
    "        return {'precision': 0, 'recall': 0, 'f1_score': 0, 'accuracy': 0}\n",
    "\n",
    "    vectorizer = TfidfVectorizer().fit(predicted_list + ground_truth_list)\n",
    "    predicted_vectors = vectorizer.transform(predicted_list)\n",
    "    ground_truth_vectors = vectorizer.transform(ground_truth_list)\n",
    "\n",
    "    # One (n_predictions x n_ground_truths) similarity matrix instead of a\n",
    "    # Python loop over the prediction rows.\n",
    "    similarities = cosine_similarity(predicted_vectors, ground_truth_vectors)\n",
    "    true_positive = int(np.count_nonzero((similarities >= threshold).any(axis=1)))\n",
    "\n",
    "    precision = true_positive / len(predicted_list)\n",
    "    recall = true_positive / len(ground_truth_list)\n",
    "    f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0\n",
    "    accuracy = true_positive / len(predicted_list)\n",
    "\n",
    "    return {\n",
    "        'precision': precision,\n",
    "        'recall': recall,\n",
    "        'f1_score': f1_score,\n",
    "        'accuracy': accuracy\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "metrics = calculate_metrics_with_similarity(candidates, references)\n",
    "print(metrics)\n",
    "print(f\"Average BLEU Score: {avg_bleu_score}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate `Salesforce/blip2-opt-2.7b`\n",
    "\n",
    "The same dataset, prompt, generation settings and seed are used as for the first checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "from transformers import AutoProcessor, Blip2ForConditionalGeneration\n",
    "\n",
    "# Same processor checkpoint as in the first model cell; it is loaded again here.\n",
    "processor = AutoProcessor.from_pretrained(\"sooh-j/VQA-for-VIP\")\n",
    "model_blip = Blip2ForConditionalGeneration.from_pretrained(\n",
    "    \"Salesforce/blip2-opt-2.7b\",\n",
    "    device_map=\"auto\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "references_blip, candidates_blip = _collect_predictions(model_blip, combined_dataset)\n",
    "\n",
    "# BLEU Score\n",
    "bleu_scores_blip, avg_bleu_score_blip = _score_bleu(references_blip, candidates_blip)\n",
    "\n",
    "metrics_blip = calculate_metrics_with_similarity(candidates_blip, references_blip)\n",
    "print(metrics_blip)\n",
    "print(f\"Average BLEU Score: {avg_bleu_score_blip}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Qualitative check on individual images\n",
    "\n",
    "Uses `model` (bound to `model_v` above) through `predict()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import requests\n",
    "from PIL import Image\n",
    "import os\n",
    "import skimage.io as io\n",
    "import matplotlib.pyplot as plt\n",
    "from io import BytesIO\n",
    "import base64\n",
    "\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "input_images = [\n",
    "    \"/kaggle/input/dataset-fortest/VizWiz_train_00000000.jpg\",\n",
    "    \"/kaggle/input/dataset-fortest/airpod.jpeg\"\n",
    "]\n",
    "input_questions = [\n",
    "    \"What's the name of this product?\",\n",
    "    \"what is this?\"\n",
    "]\n",
    "\n",
    "for img, question in zip(input_images, input_questions):\n",
    "    if os.path.isfile(img):\n",
    "        image = Image.open(img).convert('RGB')\n",
    "    else:\n",
    "        # Not a local file: treat the entry as a URL and fail loudly on HTTP errors.\n",
    "        response = requests.get(img, stream=True, timeout=30)\n",
    "        response.raise_for_status()\n",
    "        image = Image.open(response.raw).convert('RGB')\n",
    "\n",
    "    plt.imshow(image)\n",
    "    plt.axis('off')\n",
    "\n",
    "    # Reuse predict() so the demo shares the prompt and generation settings of the evaluation.\n",
    "    text_output = predict(model, image, question)\n",
    "    print(f\"Q : {question}, A : {text_output}\")\n",
    "    plt.figtext(1, 0.5, f\"Q : {question}\\nA : {text_output}\", fontsize=14)\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "nvidiaTeslaT4",
   "dataSources": [
    {
     "datasetId": 2310141,
     "sourceId": 3887986,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 5253046,
     "sourceId": 8755764,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 5265861,
     "sourceId": 8764428,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 4884402,
     "sourceId": 8765206,
     "sourceType": "datasetVersion"
    }
   ],
   "dockerImageVersionId": 30733,
   "isGpuEnabled": true,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
0 commit comments