Skip to content

Commit

Permalink
Argument specifications for each CLI are now stored in a static struc…
Browse files Browse the repository at this point in the history
…ture. (#399)

* Argument specifications for each CLI are now stored in a static structure.
  • Loading branch information
fhieber authored May 17, 2018
1 parent de627d6 commit 5a3bf5f
Show file tree
Hide file tree
Showing 8 changed files with 913 additions and 855 deletions.
1,736 changes: 884 additions & 852 deletions sockeye/arguments.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion sockeye/average.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,14 @@ def main():
"""
Commandline interface to average parameters.
"""
log_sockeye_version(logger)
params = argparse.ArgumentParser(description="Averages parameters from multiple models.")
arguments.add_average_args(params)
args = params.parse_args()
average_parameters(args)


def average_parameters(args: argparse.Namespace):
log_sockeye_version(logger)

if len(args.inputs) > 1:
avg_params = average(args.inputs)
Expand Down
3 changes: 3 additions & 0 deletions sockeye/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def main():
params.add_argument('-k', type=int, default=5, help='Number of neighbours to print')
params.add_argument('--gamma', '-g', type=float, default=1.0, help='Softmax distribution steepness.')
args = params.parse_args()
embeddings(args)


def embeddings(args: argparse.Namespace):
logger.info("Arguments: %s", args)

config = model.SockeyeModel.load_config(os.path.join(args.model, C.CONFIG_NAME))
Expand Down
6 changes: 5 additions & 1 deletion sockeye/extract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,14 @@ def main():
"""
Commandline interface to extract parameters.
"""
log_sockeye_version(logger)
params = argparse.ArgumentParser(description="Extract specific parameters.")
arguments.add_extract_args(params)
args = params.parse_args()
extract_parameters(args)


def extract_parameters(args: argparse.Namespace):
log_sockeye_version(logger)

if os.path.isdir(args.input):
param_path = os.path.join(args.input, C.PARAMS_BEST_NAME)
Expand Down
6 changes: 5 additions & 1 deletion sockeye/init_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def main():
"""
Commandline interface to initialize Sockeye embedding weights with pretrained word representations.
"""
log_sockeye_version(logger)
params = argparse.ArgumentParser(description='Quick usage: python3 -m sockeye.init_embedding '
'-w embed-in-src.npy embed-in-tgt.npy '
'-i vocab-in-src.json vocab-in-tgt.json '
Expand All @@ -132,6 +131,11 @@ def main():
'-f params.init')
arguments.add_init_embedding_args(params)
args = params.parse_args()
init_embeddings(args)


def init_embeddings(args: argparse.Namespace):
log_sockeye_version(logger)

if len(args.weight_files) != len(args.vocabularies_in) or \
len(args.weight_files) != len(args.vocabularies_out) or \
Expand Down
4 changes: 4 additions & 0 deletions sockeye/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def main():
params = argparse.ArgumentParser(description='Preprocesses and shards training data.')
arguments.add_prepare_data_cli_args(params)
args = params.parse_args()
prepare_data(args)


def prepare_data(args: argparse.Namespace):

output_folder = os.path.abspath(args.output)
os.makedirs(output_folder, exist_ok=True)
Expand Down
3 changes: 3 additions & 0 deletions sockeye/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,10 @@ def main():
params = argparse.ArgumentParser(description='Train Sockeye sequence-to-sequence models.')
arguments.add_train_cli_args(params)
args = params.parse_args()
train(args)


def train(args: argparse.Namespace):
if args.dry_run:
# Modify arguments so that we write to a temporary directory and
# perform 0 training iterations
Expand Down
4 changes: 4 additions & 0 deletions sockeye/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def main():
params = argparse.ArgumentParser(description='Translate CLI')
arguments.add_translate_cli_args(params)
args = params.parse_args()
run_translate(args)


def run_translate(args: argparse.Namespace):

if args.output is not None:
global logger
Expand Down

0 comments on commit 5a3bf5f

Please sign in to comment.