diff --git a/few_shot_tests.py b/few_shot_tests.py index 0899c8a..3fb579e 100755 --- a/few_shot_tests.py +++ b/few_shot_tests.py @@ -20,33 +20,43 @@ import numpy as np import random, imageio -parser = argparse.ArgumentParser(description='train recurrent net.') -parser.add_argument('--pretrained_name', dest='pretrained_name', type=str, default=None) -parser.add_argument('--dataset', dest='dataset', type=str, default='train') -parser.add_argument('--lr', dest='lr', type=float, default=1e-3) -parser.add_argument('--epochs', dest='epochs', type=int, default=10) -parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=64) -parser.add_argument('--final_name', dest='final_name', type=str, default='final_model') -parser.add_argument('--shuffle_images', dest='shuffle_images', action='store_true') -parser.add_argument('--enable_idx_increase', dest='enable_idx_increase', action='store_true') -parser.add_argument('--use_independent_base', dest='use_independent_base', action='store_true') -parser.add_argument('--train_indep_and_dependent', dest='train_indep_and_dependent', action='store_true') -parser.add_argument('--tensorboard_logdir', dest='tensorboard_logdir', type=str, default='./logs') -parser.add_argument('--enable_only_layers_of_list', dest='enable_only_layers_of_list', type=str, default=None) -parser.add_argument('--episode_test_sample_num', dest='episode_test_sample_num', type=int, default=15) -parser.add_argument('--biaslayer1', dest='biaslayer1', action='store_true') -parser.add_argument('--biaslayer2', dest='biaslayer2', action='store_true') -parser.add_argument('--shots', dest='shots', type=int, default=5) -parser.add_argument('--debug', dest='debug', action='store_true') -parser.add_argument('--set_model_img_to_weights', dest='set_model_img_to_weights', action='store_true') -parser.add_argument('--load_weights_name', dest='load_weights_name', type=str, default=None) -parser.add_argument('--scale_gradient_layer', dest='scale_gradient_layer', type=float, default=1.0) -parser.add_argument('--increase_idx_every', dest='increase_idx_every', type=int, default=1) -parser.add_argument('--dont_shuffle_batch', dest='dont_shuffle_batch', action='store_true') -parser.add_argument('--cathegories', dest='cathegories', type=int, default=5) -parser.add_argument('--only_one_samplefolder', dest='only_one_samplefolder', action='store_true') -parser.add_argument('--load_subnet', dest='load_subnet', action='store_true') -args = parser.parse_args() + +# To support editing of command line parameters use the fork https://github.com/dsmic/Gooey +# uncomment the following 2 lines for standard command line handling without gui +from gooey import Gooey +@Gooey(load_cmd_args=True) + +def parser(): + global args + parser = argparse.ArgumentParser(description='train recurrent net.') + parser.add_argument('--pretrained_name', dest='pretrained_name', type=str, default=None) + parser.add_argument('--dataset', dest='dataset', type=str, default='train') + parser.add_argument('--lr', dest='lr', type=float, default=1e-3) + parser.add_argument('--epochs', dest='epochs', type=int, default=10) + parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=64) + parser.add_argument('--final_name', dest='final_name', type=str, default='final_model') + parser.add_argument('--shuffle_images', dest='shuffle_images', action='store_true') + parser.add_argument('--enable_idx_increase', dest='enable_idx_increase', action='store_true') + parser.add_argument('--use_independent_base', dest='use_independent_base', action='store_true') + parser.add_argument('--train_indep_and_dependent', dest='train_indep_and_dependent', action='store_true') + parser.add_argument('--tensorboard_logdir', dest='tensorboard_logdir', type=str, default='./logs') + parser.add_argument('--enable_only_layers_of_list', dest='enable_only_layers_of_list', type=str, default=None) + parser.add_argument('--episode_test_sample_num', dest='episode_test_sample_num', type=int, default=15) + parser.add_argument('--biaslayer1', dest='biaslayer1', action='store_true') + parser.add_argument('--biaslayer2', dest='biaslayer2', action='store_true') + parser.add_argument('--shots', dest='shots', type=int, default=5) + parser.add_argument('--debug', dest='debug', action='store_true') + parser.add_argument('--set_model_img_to_weights', dest='set_model_img_to_weights', action='store_true') + parser.add_argument('--load_weights_name', dest='load_weights_name', type=str, default=None) + parser.add_argument('--scale_gradient_layer', dest='scale_gradient_layer', type=float, default=1.0) + parser.add_argument('--increase_idx_every', dest='increase_idx_every', type=int, default=1) + parser.add_argument('--dont_shuffle_batch', dest='dont_shuffle_batch', action='store_true') + parser.add_argument('--cathegories', dest='cathegories', type=int, default=5) + parser.add_argument('--only_one_samplefolder', dest='only_one_samplefolder', action='store_true') + parser.add_argument('--load_subnet', dest='load_subnet', action='store_true') + args = parser.parse_args() + +parser() # uncomment the following to disable CuDNN support #import os