2
2
import glob
3
3
import inspect
4
4
import json
5
- import logging
6
5
import os
7
6
import random
8
7
import sys
8
+ import re
9
9
from typing import Dict , List , Any , Callable , Tuple
10
10
11
11
import black
12
12
13
- from utils import import_custom_nodes , add_comfyui_directory_to_sys_path , get_value_at_index
14
13
15
- sys . path . append ( '../' )
14
+ from utils import import_custom_nodes , find_path , add_comfyui_directory_to_sys_path , add_extra_model_paths , get_value_at_index
16
15
16
+ sys .path .append ('../' )
17
17
from nodes import NODE_CLASS_MAPPINGS
18
18
19
19
20
- logging .basicConfig (level = logging .INFO )
21
-
22
-
23
20
class FileHandler :
24
21
"""Handles reading and writing files.
25
22
@@ -217,7 +214,7 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
217
214
continue
218
215
219
216
class_type , import_statement , class_code = self .get_class_info (class_type )
220
- initialized_objects [class_type ] = class_type . lower (). strip ( )
217
+ initialized_objects [class_type ] = self . clean_variable_name ( class_type )
221
218
if class_type in self .base_node_class_mappings .keys ():
222
219
import_statements .add (import_statement )
223
220
if class_type not in self .base_node_class_mappings .keys ():
@@ -234,9 +231,9 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
234
231
inputs ['unique_id' ] = random .randint (1 , 2 ** 64 )
235
232
236
233
# Create executed variable and generate code
237
- executed_variables [idx ] = f'{ class_type . lower (). strip ( )} _{ idx } '
234
+ executed_variables [idx ] = f'{ self . clean_variable_name ( class_type )} _{ idx } '
238
235
inputs = self .update_inputs (inputs , executed_variables )
239
-
236
+
240
237
if is_special_function :
241
238
special_functions_code .append (self .create_function_call_code (initialized_objects [class_type ], class_def .FUNCTION , executed_variables [idx ], is_special_function , ** inputs ))
242
239
else :
@@ -306,11 +303,11 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L
306
303
"""
307
304
# Get the source code of the utils functions as a string
308
305
func_strings = []
309
- for func in [add_comfyui_directory_to_sys_path , get_value_at_index ]:
306
+ for func in [get_value_at_index , find_path , add_comfyui_directory_to_sys_path , add_extra_model_paths ]:
310
307
func_strings .append (f'\n { inspect .getsource (func )} ' )
311
308
# Define static import statements required for the script
312
309
static_imports = ['import os' , 'import random' , 'import sys' , 'from typing import Sequence, Mapping, Any, Union' ,
313
- 'import torch' ] + func_strings + ['\n \n add_comfyui_directory_to_sys_path()' ]
310
+ 'import torch' ] + func_strings + ['\n \n add_comfyui_directory_to_sys_path()\n add_extra_model_paths() \n ' ]
314
311
# Check if custom nodes should be included
315
312
if custom_nodes :
316
313
static_imports .append (f'\n { inspect .getsource (import_custom_nodes )} \n ' )
@@ -328,7 +325,7 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L
328
325
final_code = black .format_str (final_code , mode = black .Mode ())
329
326
330
327
return final_code
331
-
328
+
332
329
def get_class_info (self , class_type : str ) -> Tuple [str , str , str ]:
333
330
"""Generates and returns necessary information about class type.
334
331
@@ -339,12 +336,36 @@ def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
339
336
Tuple[str, str, str]: Updated class type, import statement string, class initialization code.
340
337
"""
341
338
import_statement = class_type
339
+ variable_name = self .clean_variable_name (class_type )
342
340
if class_type in self .base_node_class_mappings .keys ():
343
- class_code = f'{ class_type . lower (). strip () } = { class_type .strip ()} ()'
341
+ class_code = f'{ variable_name } = { class_type .strip ()} ()'
344
342
else :
345
- class_code = f'{ class_type . lower (). strip () } = NODE_CLASS_MAPPINGS["{ class_type } "]()'
343
+ class_code = f'{ variable_name } = NODE_CLASS_MAPPINGS["{ class_type } "]()'
346
344
347
345
return class_type , import_statement , class_code
346
+
347
+ @staticmethod
348
+ def clean_variable_name (class_type : str ) -> str :
349
+ """
350
+ Remove any characters from variable name that could cause errors running the Python script.
351
+
352
+ Args:
353
+ class_type (str): Class type.
354
+
355
+ Returns:
356
+ str: Cleaned variable name with no special characters or spaces
357
+ """
358
+ # Convert to lowercase and replace spaces with underscores
359
+ clean_name = class_type .lower ().strip ().replace ("-" , "_" ).replace (" " , "_" )
360
+
361
+ # Remove characters that are not letters, numbers, or underscores
362
+ clean_name = re .sub (r'[^a-z0-9_]' , '' , clean_name )
363
+
364
+ # Ensure that it doesn't start with a number
365
+ if clean_name [0 ].isdigit ():
366
+ clean_name = "_" + clean_name
367
+
368
+ return clean_name
348
369
349
370
def get_function_parameters (self , func : Callable ) -> List :
350
371
"""Get the names of a function's parameters.
0 commit comments