Skip to content

Commit

Permalink
Check CUDA is available if specified, and fix the --gpu_mode flag to …
Browse files Browse the repository at this point in the history
…allow user to turn it off.
  • Loading branch information
jmmcd committed May 23, 2018
1 parent 5207b2f commit 1427611
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from EBGAN import EBGAN
from BEGAN import BEGAN

import torch

"""parsing and configuration"""
def parse_args():
desc = "Pytorch implementation of GAN collections"
Expand All @@ -32,7 +34,7 @@ def parse_args():
parser.add_argument('--lrD', type=float, default=0.0002)
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--gpu_mode', type=bool, default=True)
parser.add_argument('--gpu_mode', type=str2bool, default=True)

return check_args(parser.parse_args())

Expand Down Expand Up @@ -62,8 +64,25 @@ def check_args(args):
except:
print('batch size must be larger than or equal to one')

# --gpu_mode
if args.gpu_mode:
try:
assert torch.cuda.is_available()
except:
print('CUDA is not available. Use --gpu_mode False')
raise

return args

def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


"""main"""
def main():
# parse arguments
Expand Down

0 comments on commit 1427611

Please sign in to comment.