From e860c5d9d95d1c435fd493d290dac4fb1bfff2ed Mon Sep 17 00:00:00 2001 From: detlef Date: Sun, 25 Aug 2019 15:22:30 +0200 Subject: [PATCH] - cleaned up to keep mini_imagenet_dataloader.py originally, therefore two functions have to be overwritten :( --- few_shot_tests.py | 62 +++++++++++++++++++++++++++++++++++++ mini_imagenet_dataloader.py | 11 +++---- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/few_shot_tests.py b/few_shot_tests.py index bdd0b07..38597e4 100755 --- a/few_shot_tests.py +++ b/few_shot_tests.py @@ -18,6 +18,7 @@ import tensorflow as tf import argparse import numpy as np +import random, imageio parser = argparse.ArgumentParser(description='train recurrent net.') parser.add_argument('--pretrained_name', dest='pretrained_name', type=str, default=None) @@ -72,6 +73,67 @@ def idx_to_big(self, phase, idx): one_episode_sample_num = self.num_samples_per_class*self.way_num return ((idx+1)*one_episode_sample_num >= len(all_filenames)) + def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True, dont_shuffle_batch = False): + new_path_list = [] + new_label_list = [] + for k in range(batch_sample_num): + class_idxs = list(range(0, self.way_num)) + if not dont_shuffle_batch: + random.shuffle(class_idxs) + for class_idx in class_idxs: + true_idx = class_idx*batch_sample_num + k + new_path_list.append(input_filename_list[true_idx]) + new_label_list.append(input_label_list[true_idx]) + + img_list = [] + for filepath in new_path_list: + this_img = imageio.imread(filepath) + this_img = this_img / 255.0 + img_list.append(this_img) + + if reshape_with_one: + img_array = np.array(img_list) + label_array = self.one_hot(np.array(new_label_list)).reshape([1, self.way_num*batch_sample_num, -1]) + else: + img_array = np.array(img_list) + label_array = self.one_hot(np.array(new_label_list)).reshape([self.way_num*batch_sample_num, -1]) + return img_array, label_array + + def get_batch(self, phase='train', idx=0, dont_shuffle_batch = False): + if phase=='train': + all_filenames = self.train_filenames + labels = self.train_labels + elif phase=='val': + all_filenames = self.val_filenames + labels = self.val_labels + elif phase=='test': + all_filenames = self.test_filenames + labels = self.test_labels + else: + print('Please select vaild phase') + + one_episode_sample_num = self.num_samples_per_class*self.way_num + this_task_filenames = all_filenames[idx*one_episode_sample_num:(idx+1)*one_episode_sample_num] + epitr_sample_num = self.shot_num + epite_sample_num = self.episode_test_sample_num + + this_task_tr_filenames = [] + this_task_tr_labels = [] + this_task_te_filenames = [] + this_task_te_labels = [] + for class_k in range(self.way_num): + this_class_filenames = this_task_filenames[class_k*self.num_samples_per_class:(class_k+1)*self.num_samples_per_class] + this_class_label = labels[class_k*self.num_samples_per_class:(class_k+1)*self.num_samples_per_class] + this_task_tr_filenames += this_class_filenames[0:epitr_sample_num] + this_task_tr_labels += this_class_label[0:epitr_sample_num] + this_task_te_filenames += this_class_filenames[epitr_sample_num:] + this_task_te_labels += this_class_label[epitr_sample_num:] + + this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch) + this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch) + + return this_inputa, this_labela, this_inputb, this_labelb + cathegories = 5 shots = args.shots diff --git a/mini_imagenet_dataloader.py b/mini_imagenet_dataloader.py index 59c2372..f86b400 100644 --- a/mini_imagenet_dataloader.py +++ b/mini_imagenet_dataloader.py @@ -141,13 +141,12 @@ def load_list(self, phase='train'): else: print('Please select vaild phase') - def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True, dont_shuffle_batch = False): + def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True): new_path_list = [] new_label_list = [] for k in range(batch_sample_num): class_idxs = list(range(0, self.way_num)) - if not dont_shuffle_batch: - random.shuffle(class_idxs) + random.shuffle(class_idxs) for class_idx in class_idxs: true_idx = class_idx*batch_sample_num + k new_path_list.append(input_filename_list[true_idx]) @@ -175,7 +174,7 @@ def one_hot(self, inp): out[idx, inp[idx]] = 1 return out - def get_batch(self, phase='train', idx=0, dont_shuffle_batch = False): + def get_batch(self, phase='train', idx=0): if phase=='train': all_filenames = self.train_filenames labels = self.train_labels @@ -205,7 +204,7 @@ def get_batch(self, phase='train', idx=0, dont_shuffle_batch = False): this_task_te_filenames += this_class_filenames[epitr_sample_num:] this_task_te_labels += this_class_label[epitr_sample_num:] - this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch) - this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch) + this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False) + this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False) return this_inputa, this_labela, this_inputb, this_labelb