Skip to content

Saving models with a BasicDecoder Layer? #2432

Open
@klaimans

Description

@klaimans

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • TensorFlow version and how it was installed (source or binary): 2.4.1
  • TensorFlow-Addons version and how it was installed (source or binary): 0.11.2/0.12.1
  • Python version: 3.6 (note: the traceback below was produced on Colab, which runs Python 3.7)
  • Is GPU used? (yes/no): yes

Describe the bug

When trying to save a subclassed model that contains a `BasicDecoder` layer, the model cannot be saved with either the `Model.save` method or `tf.saved_model.save`.

Code to reproduce the issue

Using the Colab tutorial:

https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb

Saving the decoder from this tutorial fails (see the traceback below).

Our subclassed model is more involved, but if even the decoder in this simple example cannot be saved, we expect that our model cannot be saved either.

Provide a reproducible test case that is the bare minimum necessary to generate the problem.

Open the Colab notebook above and add the following line after the training cell:

decoder.save("./test")


TypeError Traceback (most recent call last)
in ()
----> 1 decoder.save("./test")

25 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
2000 # pylint: enable=line-too-long
2001 save.save_model(self, filepath, overwrite, include_optimizer, save_format,
-> 2002 signatures, options, save_traces)
2003
2004 def save_weights(self,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
155 else:
156 saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 157 signatures, options, save_traces)
158
159

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)
87 with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
88 with utils.keras_option_scope(save_traces):
---> 89 save_lib.save(model, filepath, signatures, options)
90
91 if not include_optimizer:

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
1031
1032 _, exported_graph, object_saver, asset_info = _build_meta_graph(
-> 1033 obj, signatures, options, meta_graph_def)
1034 saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
1035

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
1196
1197 with save_context.save_context(options):
-> 1198 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
1131 if signatures is None:
1132 signatures = signature_serialization.find_function_to_export(
-> 1133 checkpoint_graph_view)
1134
1135 signatures, wrapped_functions = (

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
73 # If the user did not specify signatures, check the root object for a function
74 # that can be made into a signature.
---> 75 functions = saveable_view.list_functions(saveable_view.root)
76 signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
77 if signature is not None:

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj, extra_functions)
149 if obj_functions is None:
150 obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
--> 151 self._serialization_cache)
152 self._functions[obj] = obj_functions
153 if extra_functions:

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
2611 self.predict_function = None
2612 functions = super(
-> 2613 Model, self)._list_functions_for_serialization(serialization_cache)
2614 self.train_function = train_function
2615 self.test_function = test_function

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
3085 def _list_functions_for_serialization(self, serialization_cache):
3086 return (self._trackable_saved_model_saver
-> 3087 .list_functions_for_serialization(serialization_cache))
3088
3089 def getstate(self):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
92 return {}
93
---> 94 fns = self.functions_to_serialize(serialization_cache)
95
96 # The parent AutoTrackable class saves all user-defined tf.functions, and

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
77 def functions_to_serialize(self, serialization_cache):
78 return (self._get_serialized_attributes(
---> 79 serialization_cache).functions_to_serialize)
80
81 def _get_serialized_attributes(self, serialization_cache):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
93
94 object_dict, function_dict = self._get_serialized_attributes_internal(
---> 95 serialization_cache)
96
97 serialized_attr.set_and_validate_objects(object_dict)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
49 # cache (i.e. this is the root level object).
50 if len(serialization_cache[constants.KERAS_CACHE_KEY]) == 1:
---> 51 default_signature = save_impl.default_save_signature(self.obj)
52
53 # Other than the default signature function, all other attributes match with

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in default_save_signature(layer)
203 original_losses = _reset_layer_losses(layer)
204 fn = saving_utils.trace_model_call(layer)
--> 205 fn.get_concrete_function()
206 _restore_layer_losses(original_losses)
207 return fn

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
1297 ValueError: if this object has not yet been called on concrete values.
1298 """
-> 1299 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1300 concrete._garbage_collector.release() # pylint: disable=protected-access
1301 return concrete

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
1203 if self._stateful_fn is None:
1204 initializers = []
-> 1205 self._initialize(args, kwargs, add_initializers_to=initializers)
1206 self._initialize_uninitialized_variables(initializers)
1207

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
724 self._concrete_stateful_fn = (
725 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 726 *args, **kwds))
727
728 def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
2967 args, kwargs = None, None
2968 with self._lock:
-> 2969 graph_function, _ = self._maybe_define_function(args, kwargs)
2970 return graph_function
2971

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3359
3360 self._function_cache.missed.add(call_context_key)
-> 3361 graph_function = self._create_graph_function(args, kwargs)
3362 self._function_cache.primary[cache_key] = graph_function
3363

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3204 arg_names=arg_names,
3205 override_flat_arg_shapes=override_flat_arg_shapes,
-> 3206 capture_by_value=self._capture_by_value),
3207 self._function_attributes,
3208 function_spec=self.function_spec,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
988 _, original_func = tf_decorator.unwrap(python_func)
989
--> 990 func_outputs = python_func(*func_args, **func_kwargs)
991
992 # invariant: func_outputs contains only Tensors, CompositeTensors,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
632 xla_context.Exit()
633 else:
--> 634 out = weak_wrapped_fn().wrapped(*args, **kwds)
635 return out
636

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/saving/saving_utils.py in _wrapped_model(*args)
133 with base_layer_utils.call_context().enter(
134 model, inputs=inputs, build_graph=False, training=False, saving=True):
--> 135 outputs = model(inputs, training=False)
136
137 # Outputs always has to be a flat dict.

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/base_layer.py in call(self, *args, **kwargs)
1010 with autocast_variable.enable_auto_cast_variables(
1011 self._compute_dtype_object):
-> 1012 outputs = call_fn(inputs, *args, **kwargs)
1013
1014 if self._activity_regularizer:

/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
618 def wrapper(*args, **kwargs):
619 with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
--> 620 return func(*args, **kwargs)
621
622 if inspect.isfunction(func) or inspect.ismethod(func):

TypeError: call() missing 1 required positional argument: 'initial_state'

Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions