From 1427611a20875efe19438dae2f9780c0bad5b1bd Mon Sep 17 00:00:00 2001 From: James McDermott Date: Wed, 23 May 2018 16:29:18 +0100 Subject: [PATCH] Check CUDA is available if specified, and fix the --gpu_mode flag to allow user to turn it off. --- main.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 43f3971..fb5ba5a 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,8 @@ from EBGAN import EBGAN from BEGAN import BEGAN +import torch + """parsing and configuration""" def parse_args(): desc = "Pytorch implementation of GAN collections" @@ -32,7 +34,7 @@ def parse_args(): parser.add_argument('--lrD', type=float, default=0.0002) parser.add_argument('--beta1', type=float, default=0.5) parser.add_argument('--beta2', type=float, default=0.999) - parser.add_argument('--gpu_mode', type=bool, default=True) + parser.add_argument('--gpu_mode', type=str2bool, default=True) return check_args(parser.parse_args()) @@ -62,8 +64,25 @@ def check_args(args): except: print('batch size must be larger than or equal to one') + # --gpu_mode + if args.gpu_mode: + try: + assert torch.cuda.is_available() + except: + print('CUDA is not available. Use --gpu_mode False') + raise + return args +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + """main""" def main(): # parse arguments