
Commit

- cleaned up to keep mini_imagenet_dataloader.py in its original form, therefore two functions have to be overwritten :(
dsmic committed Aug 25, 2019
1 parent 86a6347 commit e860c5d
Showing 2 changed files with 67 additions and 6 deletions.
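The two overridden methods are added to few_shot_tests.py so that mini_imagenet_dataloader.py itself stays untouched. Below is a minimal sketch of how such an override can be wired up, assuming the loader class in mini_imagenet_dataloader.py is called MiniImageNetDataLoader; the subclass name OurMiniImageNetDataLoader is an illustrative assumption, not taken from this commit:

# Sketch only: subclass the unchanged loader and override the two methods
# that need the extra dont_shuffle_batch behaviour. Everything except the
# method signatures shown in the diff below is a hypothetical placeholder.
from mini_imagenet_dataloader import MiniImageNetDataLoader

class OurMiniImageNetDataLoader(MiniImageNetDataLoader):  # hypothetical name
    def process_batch(self, input_filename_list, input_label_list, batch_sample_num,
                      reshape_with_one=True, dont_shuffle_batch=False):
        ...  # overridden copy of the upstream method, see the diff below

    def get_batch(self, phase='train', idx=0, dont_shuffle_batch=False):
        ...  # overridden copy of the upstream method, see the diff below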
62 changes: 62 additions & 0 deletions few_shot_tests.py
@@ -18,6 +18,7 @@
import tensorflow as tf
import argparse
import numpy as np
import random, imageio

parser = argparse.ArgumentParser(description='train recurrent net.')
parser.add_argument('--pretrained_name', dest='pretrained_name', type=str, default=None)
@@ -72,6 +73,67 @@ def idx_to_big(self, phase, idx):
        one_episode_sample_num = self.num_samples_per_class*self.way_num
        return ((idx+1)*one_episode_sample_num >= len(all_filenames))

    def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True, dont_shuffle_batch = False):
        # Re-order the episode: for each sample position, walk through the
        # classes in (optionally shuffled) order.
        new_path_list = []
        new_label_list = []
        for k in range(batch_sample_num):
            class_idxs = list(range(0, self.way_num))
            if not dont_shuffle_batch:
                random.shuffle(class_idxs)
            for class_idx in class_idxs:
                true_idx = class_idx*batch_sample_num + k
                new_path_list.append(input_filename_list[true_idx])
                new_label_list.append(input_label_list[true_idx])

        # Load the images and scale pixel values to [0, 1].
        img_list = []
        for filepath in new_path_list:
            this_img = imageio.imread(filepath)
            this_img = this_img / 255.0
            img_list.append(this_img)

        if reshape_with_one:
            img_array = np.array(img_list)
            label_array = self.one_hot(np.array(new_label_list)).reshape([1, self.way_num*batch_sample_num, -1])
        else:
            img_array = np.array(img_list)
            label_array = self.one_hot(np.array(new_label_list)).reshape([self.way_num*batch_sample_num, -1])
        return img_array, label_array

    def get_batch(self, phase='train', idx=0, dont_shuffle_batch = False):
        if phase=='train':
            all_filenames = self.train_filenames
            labels = self.train_labels
        elif phase=='val':
            all_filenames = self.val_filenames
            labels = self.val_labels
        elif phase=='test':
            all_filenames = self.test_filenames
            labels = self.test_labels
        else:
            print('Please select a valid phase')

        # Slice out the filenames that belong to episode `idx`.
        one_episode_sample_num = self.num_samples_per_class*self.way_num
        this_task_filenames = all_filenames[idx*one_episode_sample_num:(idx+1)*one_episode_sample_num]
        epitr_sample_num = self.shot_num
        epite_sample_num = self.episode_test_sample_num

        # Split every class into a support (train) and a query (test) part.
        this_task_tr_filenames = []
        this_task_tr_labels = []
        this_task_te_filenames = []
        this_task_te_labels = []
        for class_k in range(self.way_num):
            this_class_filenames = this_task_filenames[class_k*self.num_samples_per_class:(class_k+1)*self.num_samples_per_class]
            this_class_label = labels[class_k*self.num_samples_per_class:(class_k+1)*self.num_samples_per_class]
            this_task_tr_filenames += this_class_filenames[0:epitr_sample_num]
            this_task_tr_labels += this_class_label[0:epitr_sample_num]
            this_task_te_filenames += this_class_filenames[epitr_sample_num:]
            this_task_te_labels += this_class_label[epitr_sample_num:]

        this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch)
        this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch)

        return this_inputa, this_labela, this_inputb, this_labelb


cathegories = 5
shots = args.shots
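For reference, the new dont_shuffle_batch flag threads from get_batch into process_batch: when it is set, the class order inside an episode stays fixed instead of being reshuffled for every sample position, which makes episodes reproducible for evaluation. A rough usage sketch, continuing the assumptions above (the loader construction and the episode sizes are illustrative; only get_batch and load_list are taken from the code shown here):

# Hypothetical 5-way, 1-shot setup with 15 query images per class.
dataloader = OurMiniImageNetDataLoader(shot_num=1, way_num=5, episode_test_sample_num=15)
dataloader.load_list(phase='train')

# With dont_shuffle_batch=True the class order inside the episode stays fixed.
inputa, labela, inputb, labelb = dataloader.get_batch(phase='train', idx=0, dont_shuffle_batch=True)

# inputa: support images, shape roughly (way_num*shot_num, H, W, C)
# labela: one-hot support labels, shape (way_num*shot_num, way_num)
# inputb, labelb: the query split, sized by episode_test_sample_num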
11 changes: 5 additions & 6 deletions mini_imagenet_dataloader.py
@@ -141,13 +141,12 @@ def load_list(self, phase='train'):
        else:
            print('Please select a valid phase')

-    def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True, dont_shuffle_batch = False):
+    def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True):
        new_path_list = []
        new_label_list = []
        for k in range(batch_sample_num):
            class_idxs = list(range(0, self.way_num))
-            if not dont_shuffle_batch:
-                random.shuffle(class_idxs)
+            random.shuffle(class_idxs)
            for class_idx in class_idxs:
                true_idx = class_idx*batch_sample_num + k
                new_path_list.append(input_filename_list[true_idx])
@@ -175,7 +174,7 @@ def one_hot(self, inp):
            out[idx, inp[idx]] = 1
        return out

-    def get_batch(self, phase='train', idx=0, dont_shuffle_batch = False):
+    def get_batch(self, phase='train', idx=0):
        if phase=='train':
            all_filenames = self.train_filenames
            labels = self.train_labels
@@ -205,7 +204,7 @@ def get_batch(self, phase='train', idx=0, dont_shuffle_batch = False):
            this_task_te_filenames += this_class_filenames[epitr_sample_num:]
            this_task_te_labels += this_class_label[epitr_sample_num:]

-        this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch)
-        this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch)
+        this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False)
+        this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False)

        return this_inputa, this_labela, this_inputb, this_labelb
