Skip to content

Commit

Permalink
- enable command line gui to check all parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmic committed Sep 7, 2019
1 parent 986e70d commit 01885db
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions few_shot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,43 @@
import numpy as np
import random, imageio

parser = argparse.ArgumentParser(description='train recurrent net.')
parser.add_argument('--pretrained_name', dest='pretrained_name', type=str, default=None)
parser.add_argument('--dataset', dest='dataset', type=str, default='train')
parser.add_argument('--lr', dest='lr', type=float, default=1e-3)
parser.add_argument('--epochs', dest='epochs', type=int, default=10)
parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=64)
parser.add_argument('--final_name', dest='final_name', type=str, default='final_model')
parser.add_argument('--shuffle_images', dest='shuffle_images', action='store_true')
parser.add_argument('--enable_idx_increase', dest='enable_idx_increase', action='store_true')
parser.add_argument('--use_independent_base', dest='use_independent_base', action='store_true')
parser.add_argument('--train_indep_and_dependent', dest='train_indep_and_dependent', action='store_true')
parser.add_argument('--tensorboard_logdir', dest='tensorboard_logdir', type=str, default='./logs')
parser.add_argument('--enable_only_layers_of_list', dest='enable_only_layers_of_list', type=str, default=None)
parser.add_argument('--episode_test_sample_num', dest='episode_test_sample_num', type=int, default=15)
parser.add_argument('--biaslayer1', dest='biaslayer1', action='store_true')
parser.add_argument('--biaslayer2', dest='biaslayer2', action='store_true')
parser.add_argument('--shots', dest='shots', type=int, default=5)
parser.add_argument('--debug', dest='debug', action='store_true')
parser.add_argument('--set_model_img_to_weights', dest='set_model_img_to_weights', action='store_true')
parser.add_argument('--load_weights_name', dest='load_weights_name', type=str, default=None)
parser.add_argument('--scale_gradient_layer', dest='scale_gradient_layer', type=float, default=1.0)
parser.add_argument('--increase_idx_every', dest='increase_idx_every', type=int, default=1)
parser.add_argument('--dont_shuffle_batch', dest='dont_shuffle_batch', action='store_true')
parser.add_argument('--cathegories', dest='cathegories', type=int, default=5)
parser.add_argument('--only_one_samplefolder', dest='only_one_samplefolder', action='store_true')
parser.add_argument('--load_subnet', dest='load_subnet', action='store_true')
args = parser.parse_args()

# To support editing of command line parameters use the fork https://github.com/dsmic/Gooey
# uncomment the following 2 lines for standard command line handling without gui
from gooey import Gooey
@Gooey(load_cmd_args=True)

def parser():
global args
parser = argparse.ArgumentParser(description='train recurrent net.')
parser.add_argument('--pretrained_name', dest='pretrained_name', type=str, default=None)
parser.add_argument('--dataset', dest='dataset', type=str, default='train')
parser.add_argument('--lr', dest='lr', type=float, default=1e-3)
parser.add_argument('--epochs', dest='epochs', type=int, default=10)
parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=64)
parser.add_argument('--final_name', dest='final_name', type=str, default='final_model')
parser.add_argument('--shuffle_images', dest='shuffle_images', action='store_true')
parser.add_argument('--enable_idx_increase', dest='enable_idx_increase', action='store_true')
parser.add_argument('--use_independent_base', dest='use_independent_base', action='store_true')
parser.add_argument('--train_indep_and_dependent', dest='train_indep_and_dependent', action='store_true')
parser.add_argument('--tensorboard_logdir', dest='tensorboard_logdir', type=str, default='./logs')
parser.add_argument('--enable_only_layers_of_list', dest='enable_only_layers_of_list', type=str, default=None)
parser.add_argument('--episode_test_sample_num', dest='episode_test_sample_num', type=int, default=15)
parser.add_argument('--biaslayer1', dest='biaslayer1', action='store_true')
parser.add_argument('--biaslayer2', dest='biaslayer2', action='store_true')
parser.add_argument('--shots', dest='shots', type=int, default=5)
parser.add_argument('--debug', dest='debug', action='store_true')
parser.add_argument('--set_model_img_to_weights', dest='set_model_img_to_weights', action='store_true')
parser.add_argument('--load_weights_name', dest='load_weights_name', type=str, default=None)
parser.add_argument('--scale_gradient_layer', dest='scale_gradient_layer', type=float, default=1.0)
parser.add_argument('--increase_idx_every', dest='increase_idx_every', type=int, default=1)
parser.add_argument('--dont_shuffle_batch', dest='dont_shuffle_batch', action='store_true')
parser.add_argument('--cathegories', dest='cathegories', type=int, default=5)
parser.add_argument('--only_one_samplefolder', dest='only_one_samplefolder', action='store_true')
parser.add_argument('--load_subnet', dest='load_subnet', action='store_true')
args = parser.parse_args()

parser()

# uncomment the following to disable CuDNN support
#import os
Expand Down

0 comments on commit 01885db

Please sign in to comment.