diff --git a/seq2seq/training/utils.py b/seq2seq/training/utils.py index 9451d57c..e73ebac0 100644 --- a/seq2seq/training/utils.py +++ b/seq2seq/training/utils.py @@ -115,7 +115,7 @@ def cell_from_spec(cell_classname, cell_params): cell_class = locate(cell_classname) or getattr(rnn_cell, cell_classname) # Make sure additional arguments are valid - cell_args = set(inspect.getargspec(cell_class.__init__).args[1:]) + cell_args = set(inspect.signature(cell_class.__init__).parameters) for key in cell_params.keys(): if key not in cell_args: raise ValueError(