Skip to content

Commit

Permalink
- small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmic committed Sep 15, 2019
1 parent f23819a commit 1570470
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions few_shot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def parser():
parser.add_argument('--only_one_samplefolder', dest='only_one_samplefolder', action='store_true')
parser.add_argument('--load_subnet', dest='load_subnet', action='store_true')
parser.add_argument('--EarlyStop', dest='EarlyStop', type=str, default='EarlyStop')
parser.add_argument('--max_idx', dest='max_idx', type=int, default=-1)
parser.add_argument('--dense_img_num', dest='dense_img_num', type=int, default=-1)

args = parser.parse_args()

parser()
Expand Down Expand Up @@ -85,6 +88,8 @@ def printdeb(*what):
class OurMiniImageNetDataLoader(MiniImageNetDataLoader):
# adding functions we need
def idx_to_big(self, phase, idx):
if args.max_idx>0 and idx >= args.max_idx: #to limit the allowed different batches
return True
if phase=='train':
all_filenames = self.train_filenames
elif phase=='val':
Expand Down Expand Up @@ -368,13 +373,16 @@ def get_config(self):
conv4 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu', name = 'conv_4'))(pool3)
pool4 = TimeDistributed(MaxPooling2D(pool_size = 2, name = 'pool_4'))(conv4)
conv5 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu', name = 'conv_5'))(pool4)
pool5 = TimeDistributed(MaxPooling2D(pool_size = 5, name = 'pool_5'))(conv5)
pool5 = TimeDistributed(MaxPooling2D(pool_size = 2, name = 'pool_5'))(conv5)

flat = TimeDistributed(Flatten())(pool5)
#x = TimeDistributed(Dense(100, activation = 'relu'))(flat)
if args.dense_img_num > 0:
x = TimeDistributed(Dense(args.dense_img_num, activation = 'tanh'))(flat)
else:
x = flat
#predictions = Activation('softmax')(x)

model_img = FindModel(inputs=inputs, outputs=flat)
model_img = FindModel(inputs=inputs, outputs=x)

#model_img.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['categorical_accuracy'])

Expand All @@ -383,7 +391,7 @@ def get_config(self):
input1 = Input(shape=(None,84,84,3))
input2 = Input(shape=(None,84,84,3)) #, tensor = K.variable(episode_train_img[0:0]))

input2b = BiasLayer(shots * cathegories, args.biaslayer1, bias_num = 1, name = 'bias1'+str(cathegories)+'_'+str(args.shots)+'t')(input2)
input2b = BiasLayer(shots * cathegories, args.biaslayer1, bias_num = 1, name = 'bias1_'+str(cathegories)+'_'+str(args.shots)+'t')(input2)
encoded_l = model_img(input1)
encoded_r = model_img(input2b)

Expand Down

0 comments on commit 1570470

Please sign in to comment.