
Commit f23819a
- some early stopping and change of last pooling layer
dsmic committed Sep 7, 2019
1 parent 01885db commit f23819a
Showing 1 changed file with 28 additions and 5 deletions.
few_shot_tests.py
@@ -11,20 +11,20 @@
from mini_imagenet_dataloader import MiniImageNetDataLoader
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Dense, Input, Flatten, Conv2D, Lambda, TimeDistributed, MaxPooling2D, Layer
-from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
+from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, Callback
import tensorflow.keras.backend as K
import tensorflow.keras
from tensorflow.keras import optimizers as op
import tensorflow as tf
import argparse
import numpy as np
import random, imageio

+import signal

# To support editing of command line parameters, use the fork https://github.com/dsmic/Gooey
# comment out the following 2 lines for standard command line handling without gui
from gooey import Gooey
-@Gooey(load_cmd_args=True)
+@Gooey(load_cmd_args=True, ignore_command='--no-gui', force_command='--gui')

def parser():
    global args
@@ -54,10 +54,20 @@ def parser():
    parser.add_argument('--cathegories', dest='cathegories', type=int, default=5)
    parser.add_argument('--only_one_samplefolder', dest='only_one_samplefolder', action='store_true')
    parser.add_argument('--load_subnet', dest='load_subnet', action='store_true')
+    parser.add_argument('--EarlyStop', dest='EarlyStop', type=str, default='EarlyStop')
    args = parser.parse_args()

parser()

+flag=0
+def CTRL_C(sig, frame):
+    global flag
+    print(flag)
+    #import code; code.interact()
+    import pdb; pdb.set_trace()
+
+signal.signal(signal.SIGINT, CTRL_C)
+
# uncomment the following to disable CuDNN support
#import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
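The SIGINT handler added above turns Ctrl-C into a breakpoint instead of killing the run: training pauses in pdb, state can be inspected, and execution resumes with 'c'. A minimal standalone sketch of the same pattern (the handler name on_sigint is illustrative, not from this commit):

import signal

def on_sigint(sig, frame):
    # break into the interactive debugger at the interrupted frame;
    # resume the training loop from the pdb prompt with 'c'
    import pdb
    pdb.Pdb().set_trace(frame)

signal.signal(signal.SIGINT, on_sigint)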
@@ -358,7 +368,7 @@ def get_config(self):
conv4 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu', name = 'conv_4'))(pool3)
pool4 = TimeDistributed(MaxPooling2D(pool_size = 2, name = 'pool_4'))(conv4)
conv5 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu', name = 'conv_5'))(pool4)
-pool5 = TimeDistributed(MaxPooling2D(pool_size = 2, name = 'pool_5'))(conv5)
+pool5 = TimeDistributed(MaxPooling2D(pool_size = 5, name = 'pool_5'))(conv5)

flat = TimeDistributed(Flatten())(pool5)
#x = TimeDistributed(Dense(100, activation = 'relu'))(flat)
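The pooling change above makes the last pooling layer effectively global: mini-ImageNet images are 84x84, the 'same' convolutions preserve spatial size, and each stride-2 pooling halves it (with flooring), leaving a 5x5 feature map before pool_5. With pool_size = 5 that map collapses to 1x1. A quick shape check (a sketch, assuming the 84x84 mini-ImageNet input size):

size = 84                      # mini-ImageNet images are 84x84
for pool in (2, 2, 2, 2, 5):   # pool_1 .. pool_5, 'valid' pooling, stride == pool_size
    size //= pool
    print(size)                # 42, 21, 10, 5, 1 -> pool_5 acts as global max pooling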
@@ -483,10 +493,18 @@ def all_layers(model):

#print('before fitting', lambda_model_layers[17].get_weights()[0])

+import os
+class TerminateKey(Callback):
+    def on_batch_end(self, batch, logs=None):
+        if os.path.exists(args.EarlyStop):
+            self.model.stop_training = True
+
+terminate_on_key = TerminateKey()
+
checkpointer = ModelCheckpoint(filepath='checkpoints/model-{epoch:02d}.hdf5', verbose=1)
tensorboard = TensorBoard(log_dir = args.tensorboard_logdir)
lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, args.epochs,
-        validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard], workers = 0)
+        validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard, terminate_on_key], workers = 0)
#workers = 0 is a workaround to correct the number of calls to the validation_data generator

#test_lambda = lambda_model([K.expand_dims(K.variable(base_train_img[0:0+1]),axis=0),K.expand_dims(K.variable(base_train_img), axis=0), K.expand_dims(K.variable(base_train_label), axis=0)])
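The TerminateKey callback above checks after every batch whether the marker file named by --EarlyStop (default 'EarlyStop') exists and, if so, stops training cleanly, so a long run can be ended from a second shell by creating that file. A self-contained sketch of the same pattern (the class name StopOnMarkerFile is illustrative):

import os
from tensorflow.keras.callbacks import Callback

class StopOnMarkerFile(Callback):
    """Stop training as soon as a marker file appears on disk."""
    def __init__(self, path='EarlyStop'):
        super().__init__()
        self.path = path

    def on_batch_end(self, batch, logs=None):
        if os.path.exists(self.path):
            # the training loop ends the run once this flag is set
            self.model.stop_training = True

# usage: model.fit(..., callbacks=[StopOnMarkerFile()])
# trigger from another shell:  touch EarlyStop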
Expand Down Expand Up @@ -565,6 +583,11 @@ def print_FindModels(model):
lambda_model.save_weights(args.final_name+'-weights.hdf5')
lambda_model.layers[2].save_weights(args.final_name + '-weights.hdf5' + '_subnet.hdf5')

+
+if os.path.exists(args.EarlyStop) and os.path.getsize(args.EarlyStop)==0:
+    os.remove(args.EarlyStop)
+    print('removed', args.EarlyStop)
+
# tools for debugging
def get_weight_grad(model, inputs, outputs):
""" Gets gradient of model for given inputs and outputs for all weights"""
