diff --git a/modguard/init.py b/modguard/init.py index b236ada8..e9782d19 100644 --- a/modguard/init.py +++ b/modguard/init.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass +from enum import Enum import os from typing import Optional @@ -5,11 +7,23 @@ from .check import check_import from .core import PublicMember from .parsing import utils -from .parsing.boundary import ensure_boundary, build_boundary_trie +from .parsing.boundary import add_boundary, has_boundary, build_boundary_trie from .parsing.imports import get_imports from .parsing.public import mark_as_public +class WriteOperation(Enum): + BOUNDARY = "boundary" + PUBLIC = "public" + + +@dataclass +class FileWriteInformation: + location: str + operation: WriteOperation + member_name: str = "" + + def init_project(root: str, exclude_paths: Optional[list[str]] = None): # Core functionality: # * do nothing in any package already having a Boundary @@ -22,16 +36,23 @@ def init_project(root: str, exclude_paths: Optional[list[str]] = None): root = utils.canonical(root) exclude_paths = list(map(utils.canonical, exclude_paths)) if exclude_paths else None + write_operations: list[FileWriteInformation] = [] + boundary_trie = build_boundary_trie(root, exclude_paths=exclude_paths) initial_boundary_paths = [ boundary.full_path for boundary in boundary_trie if boundary.full_path ] for dirpath in utils.walk_pypackages(root, exclude_paths=exclude_paths): - added_boundary = ensure_boundary(dirpath + "/__init__.py") - if added_boundary: + filepath = dirpath + "/__init__.py" + if not has_boundary(filepath): dir_mod_path = utils.file_to_module_path(dirpath) boundary_trie.insert(dir_mod_path) + write_operations.append( + FileWriteInformation( + location=filepath, operation=WriteOperation.BOUNDARY + ) + ) for file_path in utils.walk_pyfiles(root, exclude_paths=exclude_paths): mod_path = utils.file_to_module_path(file_path) @@ -63,9 +84,24 @@ def init_project(root: str, exclude_paths: Optional[list[str]] = None): file_path, member_name = utils.module_to_file_path(import_mod_path) try: - mark_as_public(file_path, member_name) + write_operations.append( + FileWriteInformation( + location=file_path, + operation=WriteOperation.PUBLIC, + member_name=member_name, + ) + ) violated_boundary.add_public_member(PublicMember(name=import_mod_path)) except errors.ModguardError: print( f"Skipping member {member_name} in {file_path}; could not mark as public" ) + # After we've completed our pass on inserting boundaries and public members, write to files + for write_op in write_operations: + try: + if write_op.operation == WriteOperation.BOUNDARY: + add_boundary(write_op.location) + if write_op.operation == WriteOperation.PUBLIC: + mark_as_public(write_op.location, write_op.member_name) + except errors.ModguardError: + print(f"Error marking {write_op.operation} in {write_op.operation}") diff --git a/modguard/parsing/boundary.py b/modguard/parsing/boundary.py index c6bf1b50..e47b7832 100644 --- a/modguard/parsing/boundary.py +++ b/modguard/parsing/boundary.py @@ -56,7 +56,7 @@ def _has_boundary(file_path: str, file_content: str) -> bool: boundary_finder.visit(parsed_ast) return boundary_finder.found_boundary - +@public def has_boundary(file_path: str) -> bool: with open(file_path, "r") as file: file_content = file.read() @@ -66,24 +66,12 @@ def has_boundary(file_path: str) -> bool: BOUNDARY_PRELUDE = "import modguard\nmodguard.Boundary()\n" - -def _add_boundary(file_path: str, file_content: str): - with open(file_path, "w") as file: - file.write(BOUNDARY_PRELUDE + file_content) - - @public -def ensure_boundary(file_path: str) -> bool: - with open(file_path, "r") as file: +def add_boundary(file_path: str) -> None: + with open(file_path, "r+") as file: file_content = file.read() - - if _has_boundary(file_path, file_content): - # Boundary already exists, don't need to create one - return False - - # Boundary doesn't exist, create one - _add_boundary(file_path, file_content) - return True + file.seek(0) + file.write(BOUNDARY_PRELUDE + file_content) @public diff --git a/modguard/parsing/public.py b/modguard/parsing/public.py index 2815da9d..53ab20b0 100644 --- a/modguard/parsing/public.py +++ b/modguard/parsing/public.py @@ -252,49 +252,45 @@ def _public_module_prelude(should_import: bool = True) -> str: @public def mark_as_public(file_path: str, member_name: str = ""): - with open(file_path, "r") as file: + with open(file_path, "r+") as file: file_content = file.read() - - try: - parsed_ast = ast.parse(file_content) - except SyntaxError as e: - raise ModguardParseError(f"Syntax error in {file_path}: {e}") - - modguard_public_is_imported = is_modguard_imported(parsed_ast, "public") - - if not member_name: - with open(file_path, "w") as file: + file.seek(0) + try: + parsed_ast = ast.parse(file_content) + except SyntaxError as e: + raise ModguardParseError(f"Syntax error in {file_path}: {e}") + modguard_public_is_imported = is_modguard_imported(parsed_ast, "public") + if not member_name: file.write( _public_module_prelude(should_import=not modguard_public_is_imported) + file_content ) - return + return + + member_finder = MemberFinder(member_name) + member_finder.visit(parsed_ast) + if member_finder.matched_lineno is None: + raise ModguardParseError( + f"Failed to find member {member_name} in file {file_path}" + ) - member_finder = MemberFinder(member_name) - member_finder.visit(parsed_ast) - if member_finder.matched_lineno is None: - raise ModguardParseError( - f"Failed to find member {member_name} in file {file_path}" - ) + normal_lineno = member_finder.matched_lineno - 1 + file_lines = file_content.splitlines(keepends=True) + if member_finder.matched_assignment: + # Insert a call to public for the member after the assignment + lines_to_write = [ + *file_lines[: normal_lineno + 1], + f"{PUBLIC_CALL}({member_name})\n", + *file_lines[normal_lineno + 1 :], + ] + else: + # Insert a decorator before the function or class definition + lines_to_write = [ + *file_lines[:normal_lineno], + PUBLIC_DECORATOR + "\n", + *file_lines[normal_lineno:], + ] + if not modguard_public_is_imported: + lines_to_write = [IMPORT_MODGUARD + "\n", *lines_to_write] - normal_lineno = member_finder.matched_lineno - 1 - file_lines = file_content.splitlines(keepends=True) - if member_finder.matched_assignment: - # Insert a call to public for the member after the assignment - lines_to_write = [ - *file_lines[: normal_lineno + 1], - f"{PUBLIC_CALL}({member_name})\n", - *file_lines[normal_lineno + 1 :], - ] - else: - # Insert a decorator before the function or class definition - lines_to_write = [ - *file_lines[:normal_lineno], - PUBLIC_DECORATOR + "\n", - *file_lines[normal_lineno:], - ] - if not modguard_public_is_imported: - lines_to_write = [IMPORT_MODGUARD + "\n", *lines_to_write] - - with open(file_path, "w") as file: file.write("".join(lines_to_write)) diff --git a/tests/__init__.py b/tests/__init__.py index e960b710..5c072682 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,2 +1,3 @@ import modguard + modguard.Boundary()