From 2a07a1e86481c3dbd6bf41f5834701dc9d6370d4 Mon Sep 17 00:00:00 2001 From: JohnWL <34081873+John-WL@users.noreply.github.com> Date: Tue, 9 Jan 2024 15:59:08 -0500 Subject: [PATCH] attempt at making the patch more stable --- lib_comfyui/custom_extension_injector.py | 56 ++++++++++++------------ 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/lib_comfyui/custom_extension_injector.py b/lib_comfyui/custom_extension_injector.py index d00b225..92761d5 100644 --- a/lib_comfyui/custom_extension_injector.py +++ b/lib_comfyui/custom_extension_injector.py @@ -1,3 +1,4 @@ +import functools import os import inspect import ast @@ -30,8 +31,8 @@ def register_custom_scripts(custom_scripts_path_list): parsed_module = ast.parse(inspect.getsource(server.PromptServer)) parsed_class = parsed_module.body[0] patch_prompt_server_init(parsed_class, custom_scripts_path_list) - patch_prompt_server_add_routes(parsed_class, custom_scripts_path_list) exec(compile(parsed_module, '', 'exec'), server.__dict__) + add_server__init__patch(functools.partial(patch_prompt_server_add_routes, custom_scripts_path_list=custom_scripts_path_list)) # patch for https://github.com/comfyanonymous/ComfyUI/blob/490771b7f495c95fb52875cf234fffc367162c7e/server.py#L123 @@ -65,34 +66,23 @@ def generate_prompt_server_init_code_patch(custom_scripts_path): """) -# patch for https://github.com/comfyanonymous/ComfyUI/blob/490771b7f495c95fb52875cf234fffc367162c7e/server.py#L487 -def patch_prompt_server_add_routes(parsed_class: ast.ClassDef, custom_scripts_path_list): - """ - ComfyUI/serever.py - - ... - def add_routes(self): - self.user_manager.add_routes(self.routes) - self.app.add_routes(self.routes) +# patch for https://github.com/comfyanonymous/ComfyUI/blob/6a7bc35db845179a26e62534f3d4b789151e52fe/server.py#L536 +def patch_prompt_server_add_routes(self, *_, custom_scripts_path_list, **__): + from aiohttp import web - [...] + def add_routes_patch(*args, original_function, **kwargs): + new_routes = [ + web.static( + f"/webui_scripts/{os.path.basename(os.path.dirname(custom_scripts_path))}", + fr"{custom_scripts_path}", + follow_symlinks=True + ) + for custom_scripts_path in custom_scripts_path_list + ] + self.app.add_routes(new_routes) + original_function(*args, **kwargs) - self.app.add_routes([ - <- add code right there in the list - web.static('/', self.web_root, follow_symlinks=True), - ]) - ... - """ - add_routes_ast_function = get_ast_function(parsed_class, 'add_routes') - for custom_scripts_path in custom_scripts_path_list: - code_patch = generate_prompt_server_add_routes_code_patch(custom_scripts_path) - extra_line_of_code = ast.parse(code_patch) - try: - add_routes_ast_function.body[3].value.args[0].elts[0:0] = [extra_line_of_code.body[0].value] - except: - raise RuntimeError("ComfyUI was probably updated with breaking changes. " - "If A1111, ComfyUI and sd-webui-comfyui are up to date, " - "please notify the authors of sd-webui-comfyui.") + self.add_routes = functools.partial(add_routes_patch, original_function=self.add_routes) def generate_prompt_server_add_routes_code_patch(custom_scripts_path): @@ -105,3 +95,15 @@ def get_ast_function(parsed_object, function_name): raise RuntimeError(f'Cannot find function {function_name} in parsed ast') return res[0] + + +@ipc.restrict_to_process('comfyui') +def add_server__init__patch(callback): + import server + original_init = server.PromptServer.__init__ + + def patched_PromptQueue__init__(*args, **kwargs): + callback(*args, **kwargs) + original_init(*args, **kwargs) + + server.PromptServer.__init__ = patched_PromptQueue__init__