diff --git a/tf2onnx/graph.py b/tf2onnx/graph.py
index c1a07958b..a9d6c2ee6 100644
--- a/tf2onnx/graph.py
+++ b/tf2onnx/graph.py
@@ -486,6 +486,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
         self.contained_graphs = {}  # {node_name: {node_attribute_name: Graph}}
 
         ops = [Node(node, self) for node in nodes]
+
         if input_names is not None:
             input_names_set = set(input_names)
             for n in ops:
@@ -740,7 +741,6 @@ def reset_nodes(self, ops):
             if op.name in self.contained_graphs:
                 remained_sub_graphs[op.name] = self.contained_graphs[op.name]
 
-
         self._nodes = ops
         self.contained_graphs = remained_sub_graphs
         self._nodes_by_name = {op.name: op for op in ops}
@@ -758,7 +758,7 @@ def reset_nodes(self, ops):
                 raise ValueError("graph input '" + n.name + "' not exist")
         for o in self.outputs:
             if o not in self._output_to_node_name:
-                raise ValueError("graph output '" + o.name + "' not exist")
+                raise ValueError("graph output '" + str(o) + "' not exist")
 
         self._dtypes = remained_dtypes
         self._output_shapes = remained_shapes
diff --git a/tf2onnx/optimizer/__init__.py b/tf2onnx/optimizer/__init__.py
index f874b1275..e45166654 100644
--- a/tf2onnx/optimizer/__init__.py
+++ b/tf2onnx/optimizer/__init__.py
@@ -21,7 +21,8 @@
 
 # optimizer sequence need to be considered carefully
 _optimizers = OrderedDict([
-    ("optimize_transpose", TransposeOptimizer),
+    ("remove_identity", IdentityOptimizer),
+    ("optimize_transpose", TransposeOptimizer),  # transpose to reshape or add reshape
     ("remove_redundant_upsample", UpsampleOptimizer),
     ("fold_constants", ConstFoldOptimizer),
     ("const_dequantize_optimizer", ConstDequantizeOptimizer),
@@ -32,7 +33,6 @@
     ("reshape_optimizer", ReshapeOptimizer),
     ("global_pool_optimizer", GlobalPoolOptimizer),
     ("q_dq_optimizer", QDQOptimizer),
-    ("remove_identity", IdentityOptimizer),
     ("remove_back_to_back", BackToBackOptimizer),
     ("einsum_optimizer", EinsumOptimizer),
 ])
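Note on the reordering above: moving `remove_identity` to the front of the pipeline presumably lets the Identity nodes that this patch inserts in `transpose_outputs` (see the tfonnx.py hunks below) be folded away before `TransposeOptimizer` looks for cancellable Transpose pairs. A toy sketch of why the order matters; the functions and the list-of-op-types "graph" are hypothetical stand-ins, not the real optimizer classes:

```python
from collections import OrderedDict

def remove_identity(nodes):
    # Drop "Identity" nodes so neighbouring ops become directly adjacent.
    return [n for n in nodes if n != "Identity"]

def optimize_transpose(nodes):
    # Cancel directly adjacent "Transpose" pairs (assumed inverse perms).
    out = []
    for n in nodes:
        if out and out[-1] == "Transpose" and n == "Transpose":
            out.pop()  # the pair cancels out
        else:
            out.append(n)
    return out

pipeline = OrderedDict([
    ("remove_identity", remove_identity),        # now runs first
    ("optimize_transpose", optimize_transpose),
])

graph = ["Transpose", "Identity", "Transpose", "Conv"]
for _name, opt in pipeline.items():
    graph = opt(graph)
print(graph)  # ['Conv']; with the old order the Identity blocks the cancellation
```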
diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py
index f149ae88d..b95b06e08 100644
--- a/tf2onnx/tfonnx.py
+++ b/tf2onnx/tfonnx.py
@@ -246,7 +246,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
     exceptions = []
     if initialized_tables is None:
         initialized_tables = {}
-
+
     ops = list(g.get_nodes())
     for node in ops:
         logger.debug("Process node: %s\n%s", node.name, node.summary)
@@ -263,7 +263,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
             logger.error("Tensorflow op [%s: %s] is not supported", node.name, op)
             continue
         mapped_op[op] += 1
-
+
         func, kwargs = map_info
         if kwargs:
             # if there is a tf_op/onnx_op key we'll map the old type to a new type
@@ -273,6 +273,7 @@
                 kwargs["tfl_op" if is_tflite else "tf_op"] = op
                 node.type = converted_op
         body_graphs = node.get_body_graphs()
+
         if body_graphs:
             for attr, b_g in body_graphs.items():
                 logger.debug("start handling subgraph of %s's attribute %s", node.name, attr)
@@ -287,7 +288,7 @@
                 b_g.topological_sort(b_g.get_nodes())
                 exceptions.extend(body_exceptions)
                 logger.debug("finish handling subgraph of %s's attribute %s", node.name, attr)
-
+
         try:
             func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
             if not is_tflite:
@@ -302,7 +303,6 @@
             logger.error("Failed to convert node %r (fct=%r)\n%r", node.name, func, summary, exc_info=1)
             exceptions.append(ex)
-
     return mapped_op, unmapped_op, exceptions
@@ -332,26 +332,96 @@ def transpose_inputs(ctx, inputs_as_nchw):
 def transpose_outputs(ctx, outputs_as_nchw):
     """Insert a transpose from NHWC to NCHW on model output on users request."""
     ops = []
+
+    # First pass: Find and handle edge cases in original nodes
+    edge_case_handled = set()
+
     for node in ctx.get_nodes():
         for output_name in node.output:
-            if output_name in outputs_as_nchw:
+            # Check if this output is used to create a model output
+            consumers = ctx.find_output_consumers(output_name)
+
+            # Look for edge case: output consumed by both model output node and other nodes
+            model_output_consumers = []
+            other_consumers = []
+
+            for consumer in consumers:
+                if consumer.output and any(out in outputs_as_nchw for out in consumer.output):
+                    model_output_consumers.append(consumer)
+                else:
+                    other_consumers.append(consumer)
+
+            # Edge case: original node output goes to both model output and other layers
+            if model_output_consumers and other_consumers:
+                # Get shape for validation
+                shape = ctx.get_shape(output_name)
+                if len(shape) != len(constants.NHWC_TO_NCHW):
+                    continue
+
+                # Handle edge case: Use insert_node_on_output for proper structure
+                # Step 1: Create Identity node and insert it on the original output
+                identity_name = utils.make_name(node.name + "_identity")
+                identity = ctx.make_node("Identity", [output_name],
+                                         outputs=[identity_name + ":0"], name=identity_name)
+
+                # Copy shape information
+                ctx.copy_shape(output_name, identity.output[0])
+                ctx.set_shape(identity.output[0], shape)
+
+                # Insert the identity on the original output - this will redirect ALL consumers
+                ctx.insert_node_on_output(identity, output_name)
+
+                # Step 2: Create Transpose node and connect it to Identity
+                transpose_name = utils.make_name(identity.name + "_transpose")
+                transpose = ctx.make_node("Transpose", [identity.output[0]],
+                                          outputs=[transpose_name + ":0"], name=transpose_name)
+                transpose.set_attr("perm", constants.NHWC_TO_NCHW)
+                ctx.copy_shape(identity.output[0], transpose.output[0])
+                ctx.set_shape(transpose.output[0], np.array(shape)[constants.NHWC_TO_NCHW])
+
+                # Step 3: Manually redirect ONLY the model output consumers to use transpose
+                for consumer in model_output_consumers:
+                    ctx.replace_all_inputs(identity.output[0], transpose.output[0], ops=[consumer])
+
+                # Mark this output as handled
+                edge_case_handled.add(output_name)
+
+                ops.append(node)
+                ops.append(identity)
+                ops.append(transpose)
+                break  # Only handle one edge case per node
+
+        # If no edge case was handled for this node, add it normally
+        if not any(out in edge_case_handled for out in node.output):
+            ops.append(node)
+
+    # Second pass: Handle normal cases (nodes that directly output to model outputs)
+    final_ops = []
+    for node in ops:
+        handled = False
+        for output_name in node.output:
+            if output_name in outputs_as_nchw and output_name not in edge_case_handled:
+                # Get shape for validation
                 shape = ctx.get_shape(output_name)
                 if len(shape) != len(constants.NHWC_TO_NCHW):
                     logger.warning("transpose_output for %s: shape must be rank 4, ignored" % output_name)
-                    ops.append(node)
                     continue
+
                 # insert transpose
                 op_name = utils.make_name(node.name)
-                transpose = ctx.insert_new_node_on_output("Transpose", node.input[0], name=op_name)
+                transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
                 transpose.set_attr("perm", constants.NHWC_TO_NCHW)
-                ctx.copy_shape(node.output[0], transpose.output[0])
-                ctx.set_shape(transpose.output[0], np.array(shape)[constants.NHWC_TO_NCHW])
+                ctx.copy_shape(output_name, transpose.output[0])
                 ctx.set_shape(output_name, np.array(shape)[constants.NHWC_TO_NCHW])
-                ops.append(transpose)
-                ops.append(node)
-                continue
-        ops.append(node)
-    ctx.reset_nodes(ops)
+                final_ops.append(transpose)
+                final_ops.append(node)
+                handled = True
+                break
+
+        if not handled:
+            final_ops.append(node)
+
+    ctx.reset_nodes(final_ops)
 
 def topological_sort(g, continue_on_error):
     ops = g.get_nodes()
@@ -522,7 +592,7 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
                          initialized_tables, is_tflite=False, dequantize=False):
 
     op_cnt, attr_cnt = g.dump_node_statistics(include_attrs=True, include_subgraphs=False)
-
+
     if is_tflite:
         tfl_rewriters = []
         if dequantize:
@@ -531,13 +601,16 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
             tfl_rewriters.append(rewrite_tfl_select_zero)
         tfl_rewriters.append(rewrite_tfl_rfft)
         run_rewriters(g, tfl_rewriters, continue_on_error)
+
         tfl_ops_mapping = handler.tfl_op.create_tfl_to_tf_mapping()
         _, _, exceptions = tensorflow_onnx_mapping(g, tfl_ops_mapping, is_tflite=True, dequantize=False)
+
         if exceptions and not continue_on_error:
             raise exceptions[0]
 
     # create ops mapping for the desired opsets
     ops_mapping = handler.tf_op.create_mapping(g.opset, g.extra_opset)
+
     # apply custom ops on top of the assembled opset. We can either complement the opset
     # or override existing ops with a custom op.
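The `transpose_outputs` hunk carries two fixes. First, the one-line bug fix in the second pass: the Transpose is now inserted on the requested output tensor (`output_name`) rather than on `node.input[0]`. Second, the new first pass handles the case where a tensor feeds both a model output requested as NCHW and internal consumers: an Identity takes over all consumers of the tensor, a Transpose hangs off the Identity, and only the model-output consumers are rewired to the Transpose, so internal consumers keep seeing NHWC data. A toy sketch of that rewiring using plain dicts (not tf2onnx's Graph API); all names are illustrative:

```python
# consumer name -> list of input tensor names; "conv:0" feeds both the
# model output (requested as NCHW) and an internal Relu (must stay NHWC).
graph = {
    "relu":   ["conv:0"],
    "output": ["conv:0"],
}

# Step 1: an Identity takes over every consumer of "conv:0".
for consumer, inputs in graph.items():
    graph[consumer] = ["conv_identity:0" if i == "conv:0" else i for i in inputs]
graph["conv_identity"] = ["conv:0"]

# Step 2: a Transpose (perm = NHWC -> NCHW) hangs off the Identity.
graph["conv_transpose"] = ["conv_identity:0"]

# Step 3: only the model-output consumer is rewired to the Transpose.
graph["output"] = ["conv_transpose:0"]

print(graph["relu"])    # ['conv_identity:0']  -- still NHWC data
print(graph["output"])  # ['conv_transpose:0'] -- NCHW data
```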
diff --git a/tf2onnx/version.py b/tf2onnx/version.py
index 3727236be..3b62043af 100644
--- a/tf2onnx/version.py
+++ b/tf2onnx/version.py
@@ -2,4 +2,4 @@
 
 version = '1.16.1'
-git_version = '13bab8a91e17ccd87541b2f361ab60e8e38359d3'
+git_version = 'None'
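For reference, a conversion that exercises the changed path. This is a minimal sketch assuming the `outputs_as_nchw` keyword of `tf2onnx.convert.from_keras` (present in recent tf2onnx releases); the output tensor name passed in is illustrative and depends on the model, so convert once without it and inspect `model_proto.graph.output` to find the right name:

```python
import tensorflow as tf
import tf2onnx

model = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(8, 3, padding="same", name="conv")])
spec = (tf.TensorSpec((1, 32, 32, 3), tf.float32, name="input"),)

# Request the rank-4 NHWC output in NCHW; tf2onnx inserts the Transpose
# that this patch fixes and, if needed, the Identity split shown above.
model_proto, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=spec,
    opset=13,
    outputs_as_nchw=["conv:0"],  # hypothetical name, see note above
    output_path="model_nchw.onnx")

# The requested output should now be (1, 8, 32, 32) instead of (1, 32, 32, 8).
print([(o.name, [d.dim_value for d in o.type.tensor_type.shape.dim])
       for o in model_proto.graph.output])
```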