From f23819a5d02be0754941aac7e7ee776a60f5bf31 Mon Sep 17 00:00:00 2001 From: detlef Date: Sat, 7 Sep 2019 16:29:22 +0200 Subject: [PATCH] - some early stopping and change of last pooling layer --- few_shot_tests.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/few_shot_tests.py b/few_shot_tests.py index 3fb579e..4c415e5 100755 --- a/few_shot_tests.py +++ b/few_shot_tests.py @@ -11,7 +11,7 @@ from mini_imagenet_dataloader import MiniImageNetDataLoader from tensorflow.keras.models import Model from tensorflow.keras.layers import Activation, Dense, Input, Flatten, Conv2D, Lambda, TimeDistributed, MaxPooling2D, Layer -from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard +from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, Callback import tensorflow.keras.backend as K import tensorflow.keras from tensorflow.keras import optimizers as op @@ -19,12 +19,12 @@ import argparse import numpy as np import random, imageio - +import signal # To support editing of command line parameters use the fork https://github.com/dsmic/Gooey # uncomment the following 2 lines for standard command line handling without gui from gooey import Gooey -@Gooey(load_cmd_args=True) +@Gooey(load_cmd_args=True, ignore_command='--no-gui', force_command='--gui') def parser(): global args @@ -54,10 +54,20 @@ def parser(): parser.add_argument('--cathegories', dest='cathegories', type=int, default=5) parser.add_argument('--only_one_samplefolder', dest='only_one_samplefolder', action='store_true') parser.add_argument('--load_subnet', dest='load_subnet', action='store_true') + parser.add_argument('--EarlyStop', dest='EarlyStop', type=str, default='EarlyStop') args = parser.parse_args() parser() +flag=0 +def CTRL_C(sig, frame): + global flag + print(flag) + #import code; code.interact() + import pdb; pdb.set_trace() + +signal.signal(signal.SIGINT, CTRL_C) + # uncomment the following to disable CuDNN support #import os #os.environ["CUDA_VISIBLE_DEVICES"] = "-1" @@ -358,7 +368,7 @@ def get_config(self): conv4 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu', name = 'conv_4'))(pool3) pool4 = TimeDistributed(MaxPooling2D(pool_size = 2, name = 'pool_4'))(conv4) conv5 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu', name = 'conv_5'))(pool4) -pool5 = TimeDistributed(MaxPooling2D(pool_size = 2, name = 'pool_5'))(conv5) +pool5 = TimeDistributed(MaxPooling2D(pool_size = 5, name = 'pool_5'))(conv5) flat = TimeDistributed(Flatten())(pool5) #x = TimeDistributed(Dense(100, activation = 'relu'))(flat) @@ -483,10 +493,18 @@ def all_layers(model): #print('vor fitting', lambda_model_layers[17].get_weights()[0]) +import os +class TerminateKey(Callback): + def on_batch_end(self, batch, logs=None): + if os.path.exists(args.EarlyStop): + self.model.stop_training = True + +terminate_on_key = TerminateKey() + checkpointer = ModelCheckpoint(filepath='checkpoints/model-{epoch:02d}.hdf5', verbose=1) tensorboard = TensorBoard(log_dir = args.tensorboard_logdir) lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, args.epochs, - validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard], workers = 0) + validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard, terminate_on_key], workers = 0) #workers = 0 is a work around to correct the number of calls to the validation_data generator #test_lambda = lambda_model([K.expand_dims(K.variable(base_train_img[0:0+1]),axis=0),K.expand_dims(K.variable(base_train_img), axis=0), K.expand_dims(K.variable(base_train_label), axis=0)]) @@ -565,6 +583,11 @@ def print_FindModels(model): lambda_model.save_weights(args.final_name+'-weights.hdf5') lambda_model.layers[2].save_weights(args.final_name + '-weights.hdf5' + '_subnet.hdf5') + +if os.path.exists(args.EarlyStop) and os.path.getsize(args.EarlyStop)==0: + os.remove(args.EarlyStop) + print('removed',args.EarlyStop) + # tools for debugging def get_weight_grad(model, inputs, outputs): """ Gets gradient of model for given inputs and outputs for all weights"""