diff --git a/tools/convert_to_bizyair.py b/tools/convert_to_bizyair.py index e77f05f5..22d7e2b7 100644 --- a/tools/convert_to_bizyair.py +++ b/tools/convert_to_bizyair.py @@ -11,6 +11,7 @@ from pathlib import Path from typing import List +import yaml from loguru import logger @@ -98,12 +99,13 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("-i", "--input", type=str, required=True) parser.add_argument("-o", "--output", type=str, required=False, default=None) + parser.add_argument("-p", "--patch", type=str, required=False, default=None) return parser.parse_args() def main(): args = get_args() - out = convert_to_bizyair(load_input_file(args.input)) + out = convert_to_bizyair(load_input_file(args.input), args.patch) if args.output is None: args.output = args.input.replace(".json", ".bizyair.json") with open(args.output, "w") as f: @@ -117,13 +119,55 @@ def get_bizyair_display_name(class_type: str) -> str: return f"{bizyair_logo}{bizyair_cls_prefix} {bizyair.NODE_DISPLAY_NAME_MAPPINGS.get(class_type, class_type)}" -def convert_to_bizyair(inputs: dict): - bizyair.NODE_CLASS_MAPPINGS +def get_trans_format(inputs: dict): + if "nodes" in inputs: + return "workflow" + return "workflow_api" + + +def workflow_convert(inputs: dict): + nodes = inputs["nodes"] + for node in nodes: + class_type = node["type"] + node_inputs = node.get("inputs") + node_outputs = node.get("outputs") + + bizyair_cls_type = f"{bizyair.nodes_base.PREFIX}_{class_type}" + is_converted = False + + if bizyair_cls_type in bizyair.NODE_CLASS_MAPPINGS: + node["type"] = bizyair_cls_type + + display_name = get_bizyair_display_name(class_type) + node["properties"]["Node name for S&R"] = display_name + + if node_inputs: + for input_node in node_inputs: + input_type = input_node["type"] + input_node["type"] = f"{bizyair.nodes_base.PREFIX}_{input_type}" + + if node_outputs: + for output_node in node_outputs: + output_type = output_node["type"] + output_node["type"] = f"{bizyair.nodes_base.PREFIX}_{output_type}" + is_converted = True + pprint.pprint( + { + "original_class_type": class_type, + "bizyair_cls_type": bizyair_cls_type, + "is_converted": is_converted, + } + ) + + return inputs + +def workflow_api_convert(inputs: dict): for x in inputs.copy(): class_type = inputs[x]["class_type"] bizyair_cls_type = f"{bizyair.nodes_base.PREFIX}_{class_type}" is_converted = False + if bizyair_cls_type in bizyair.NODE_CLASS_MAPPINGS: inputs[x]["class_type"] = bizyair_cls_type display_name = get_bizyair_display_name(class_type) @@ -137,8 +181,59 @@ def convert_to_bizyair(inputs: dict): "is_converted": is_converted, } ) + return inputs +def patch_apply(inputs: dict, yaml_file): + replacements = load_yaml_replacements(yaml_file) + for replacement in replacements["node_replacements"]: + + original_type = replacement["original_type"] + replace_type = replacement["replace_type"] + + for node in inputs["nodes"]: + if "type" in node and node["type"] == original_type: + node["type"] = replace_type + + display_name = get_bizyair_display_name(replace_type) + node["properties"]["Node name for S&R"] = display_name + + node_inputs = node.get("inputs") + if node_inputs: + for input_node in node_inputs: + input_type = input_node["type"] + input_node["type"] = f"{bizyair.nodes_base.PREFIX}_{input_type}" + + node_outputs = node.get("outputs") + if node_outputs: + for output_node in node_outputs: + output_type = output_node["type"] + output_node["type"] = f"{bizyair.nodes_base.PREFIX}_{output_type}" + + return inputs + + +def convert_to_bizyair(inputs: dict, yaml_file): + bizyair.NODE_CLASS_MAPPINGS + + input_format = get_trans_format(inputs) + if input_format == "workflow_api": + inputs = workflow_api_convert(inputs) + elif input_format == "workflow": + inputs = workflow_convert(inputs) + + if yaml_file: + inputs = patch_apply(inputs, yaml_file) + + return inputs + + +def load_yaml_replacements(yaml_file): + with open(yaml_file, "r") as file: + replacements = yaml.safe_load(file) + return replacements + + if __name__ == "__main__": main() diff --git a/tools/test.yaml b/tools/test.yaml new file mode 100644 index 00000000..e40189e0 --- /dev/null +++ b/tools/test.yaml @@ -0,0 +1,6 @@ +node_replacements: + - original_type: "DualCLIPLoaderGGUF" + replace_type: "BizyAir_DualCLIPLoader" + + - original_type: "UnetLoaderGGUF" + replace_type: "BizyAir_UNETLoader"