diff --git a/modguard/cli.py b/modguard/cli.py index 08dbf4e3..63b93bb4 100644 --- a/modguard/cli.py +++ b/modguard/cli.py @@ -1,8 +1,8 @@ import argparse import os import sys -from modguard.check import check, ErrorInfo -from modguard.init import init_project +from .check import check, ErrorInfo +from .init import init_project class BCOLORS: diff --git a/modguard/parsing/public.py b/modguard/parsing/public.py index 31892fa9..a70d545d 100644 --- a/modguard/parsing/public.py +++ b/modguard/parsing/public.py @@ -14,14 +14,15 @@ def __init__(self, module_name: str): self.import_found = False def visit_ImportFrom(self, node): - is_modguard_module_import = node.module is not None and ( - node.module == "modguard" or node.module.startswith("modguard.") - ) - if is_modguard_module_import and any( - alias.name == self.module_name for alias in node.names - ): - self.import_found = True - return + if self.module_name: + is_modguard_module_import = node.module is not None and ( + node.module == "modguard" or node.module.startswith("modguard.") + ) + if is_modguard_module_import and any( + alias.name == self.module_name for alias in node.names + ): + self.import_found = True + return self.generic_visit(node) def visit_Import(self, node): @@ -32,6 +33,12 @@ def visit_Import(self, node): self.generic_visit(node) +def is_modguard_imported(parsed_ast: ast.AST, module_name: str = "") -> bool: + modguard_import_visitor = ModguardImportVisitor(module_name) + modguard_import_visitor.visit(parsed_ast) + return modguard_import_visitor.import_found + + class PublicMemberVisitor(ast.NodeVisitor): def __init__(self, current_mod_path: str, is_package: bool = False): self.is_modguard_public_imported = False @@ -104,20 +111,41 @@ def visit_Call(self, node): top_level_expr = isinstance(parent_node, ast.Expr) and isinstance( parent_node.parent, ast.Module ) + is_raw_public_call = ( + isinstance(node.func, ast.Name) and node.func.id == "public" + ) + is_modguard_public_call = ( + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "modguard" + and node.func.attr == "public" + ) if ( self.is_modguard_public_imported and (top_level or top_level_expr) - and isinstance(node.func, ast.Name) - and node.func.id == "public" + and (is_raw_public_call or is_modguard_public_call) ): # public() has been called at the top-level, - # so we add it as the sole PublicMember and return - self.public_members = [ - PublicMember( - name="", - allowlist=self._extract_allowlist(node), + if node.args: + # if public is given positional arguments, add each as a public member + self.public_members.extend( + ( + PublicMember( + name=arg.id, + allowlist=self._extract_allowlist(node), + ) + for arg in node.args + if isinstance(arg, ast.Name) + ) ) - ] + else: + # if no positional arguments, we add a PublicMember for the whole module and return + self.public_members = [ + PublicMember( + name="", + allowlist=self._extract_allowlist(node), + ) + ] return self.generic_visit(node) @@ -226,9 +254,7 @@ def mark_as_public(file_path: str, member_name: str = ""): except SyntaxError as e: raise ModguardParseError(f"Syntax error in {file_path}: {e}") - modguard_import_visitor = ModguardImportVisitor("public") - modguard_import_visitor.visit(parsed_ast) - modguard_public_is_imported = modguard_import_visitor.import_found + modguard_public_is_imported = is_modguard_imported(parsed_ast, "public") if not member_name: with open(file_path, "w") as file: