Skip to content

Commit

Permalink
Merge pull request #6 from Never-Over/handle-missing-dynamic-imports
Browse files Browse the repository at this point in the history
Handle missing dynamic imports
  • Loading branch information
emdoyle authored Feb 13, 2024
2 parents c17c602 + f6e0f22 commit 5cd9d15
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 63 deletions.
13 changes: 10 additions & 3 deletions modguard/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ def init_project(root: str, exclude_paths: Optional[list[str]] = None):
# This import is fine, no need to mark anything as public
continue

file_path, member_name = fs.module_to_file_path(import_mod_path)
member_name = ""
try:
file_path, member_name = fs.module_to_file_path(
import_mod_path, find_package_init=True
)
write_operations.append(
FileWriteInformation(
location=file_path,
Expand All @@ -93,7 +96,7 @@ def init_project(root: str, exclude_paths: Optional[list[str]] = None):
violated_boundary.add_public_member(PublicMember(name=import_mod_path))
except errors.ModguardError:
print(
f"Skipping member {member_name} in {file_path}; could not mark as public"
f"Skipping member {member_name or import_mod_path} in {file_path}; could not mark as public"
)
# After we've completed our pass on inserting boundaries and public members, write to files
for write_op in write_operations:
Expand All @@ -103,4 +106,8 @@ def init_project(root: str, exclude_paths: Optional[list[str]] = None):
elif write_op.operation == WriteOperation.PUBLIC:
mark_as_public(write_op.location, write_op.member_name)
except errors.ModguardError:
print(f"Error marking {write_op.operation} in {write_op.location}")
print(
f"Error marking {write_op.operation.value}"
f"{'({member})'.format(member=write_op.member_name) if write_op.member_name else ''}"
f" in {write_op.location}"
)
121 changes: 61 additions & 60 deletions modguard/parsing/public.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import ast
import re
from typing import Optional, Union

from modguard import public, filesystem as fs
from modguard.core.public import PublicMember
from modguard.errors import ModguardParseError


class ModguardImportVisitor(ast.NodeVisitor):
Expand Down Expand Up @@ -99,12 +99,6 @@ def visit_ClassDef(self, node: ast.ClassDef):
self._add_public_member_from_decorator(node=node, decorator=decorator)

def visit_Call(self, node: ast.Call):
parent_node = getattr(node, "parent")
grandparent_node = getattr(parent_node, "parent")
top_level = isinstance(parent_node, ast.Module)
top_level_expr = isinstance(parent_node, ast.Expr) and isinstance(
grandparent_node, ast.Module
)
is_raw_public_call = (
isinstance(node.func, ast.Name) and node.func.id == "public"
)
Expand All @@ -114,22 +108,20 @@ def visit_Call(self, node: ast.Call):
and node.func.value.id == "modguard"
and node.func.attr == "public"
)
if (
self.is_modguard_public_imported
and (top_level or top_level_expr)
and (is_raw_public_call or is_modguard_public_call)
if self.is_modguard_public_imported and (
is_raw_public_call or is_modguard_public_call
):
# public() has been called at the top-level,
if node.args:
# if public is given positional arguments, add each as a public member
self.public_members.extend(
(
PublicMember(
name=arg.id,
name=arg.id if isinstance(arg, ast.Name) else arg.value,
allowlist=self._extract_allowlist(node),
)
for arg in node.args
if isinstance(arg, ast.Name)
if isinstance(arg, ast.Name) or isinstance(arg, ast.Constant)
)
)
else:
Expand All @@ -142,12 +134,6 @@ def visit_Call(self, node: ast.Call):
]
return

def visit(self, node: ast.AST):
# Inject a 'parent' attribute to each node for easier parent tracking
for child in ast.iter_child_nodes(node):
setattr(child, "parent", node)
super().visit(node)


def get_public_members(file_path: str) -> list[PublicMember]:
parsed_ast = fs.parse_ast(file_path)
Expand All @@ -159,102 +145,117 @@ def get_public_members(file_path: str) -> list[PublicMember]:
return public_member_visitor.public_members


# TODO: handle re-exported members (follow imports?)
class MemberFinder(ast.NodeVisitor):
def __init__(self, member_name: str):
self.member_name = member_name
# For functions and classes, matched_lineno is the start of the definition
# because a decorator can be inserted directly before the definition
# For assignments, matched_lineno is the end of the assignment
# because a public(...) call can be inserted directly after the assignment
self.matched_lineno: Optional[int] = None
self.start_lineno: Optional[int] = None
# For assignments, end_lineno is the end of the assignment value expression
# because a public(...) call should be inserted directly after the assignment
self.end_lineno: Optional[int] = None
self.matched_assignment = False

def _check_assignment_target(
self, target: Union[ast.expr, ast.Name, ast.Attribute, ast.Subscript]
def _check_assignment(
self,
target: Union[ast.expr, ast.Name, ast.Attribute, ast.Subscript],
value: ast.expr,
):
if isinstance(target, ast.Name) and target.id == self.member_name:
self.matched_lineno = target.end_lineno
self.start_lineno = target.lineno
self.end_lineno = value.end_lineno
self.matched_assignment = True
return
elif isinstance(target, ast.List) or isinstance(target, ast.Tuple):
for elt in target.elts:
if isinstance(elt, ast.Name) and elt.id == self.member_name:
self.matched_lineno = target.end_lineno
self.start_lineno = target.lineno
self.matched_lineno = value.end_lineno
self.matched_assignment = True
return

def visit_Assign(self, node: ast.Assign):
for target in node.targets:
self._check_assignment_target(target)
self._check_assignment(target, node.value)

def visit_AnnAssign(self, node: ast.AnnAssign):
self._check_assignment_target(node.target)
# If node.value is none, can use target itself for end_lineno
self._check_assignment(node.target, node.value or node.target)

def visit_Global(self, node: ast.Global):
if self.member_name in node.names:
self.matched_lineno = node.end_lineno
self.start_lineno = node.lineno
self.matched_assignment = True
return

def visit_FunctionDef(self, node: ast.FunctionDef):
if node.name == self.member_name:
self.matched_lineno = node.lineno
self.start_lineno = node.lineno
return

def visit_ClassDef(self, node: ast.ClassDef):
if node.name == self.member_name:
self.matched_lineno = node.lineno
self.start_lineno = node.lineno
return


def _public_module_end(should_import: bool = True) -> str:
if should_import:
return "import modguard\nmodguard.public()\n"
return "modguard.public()\n"


IMPORT_REGEX = re.compile(r"^(from |import )")
WHITESPACE_REGEX = re.compile(r"^((\s)*)")
IMPORT_MODGUARD = "import modguard"
PUBLIC_DECORATOR = "@modguard.public"
PUBLIC_CALL = "modguard.public"
MODGUARD_PUBLIC = "modguard.public"


@public
def mark_as_public(file_path: str, member_name: str = ""):
file_content = fs.read_file(file_path)
parsed_ast = fs.parse_ast(file_path)
modguard_public_is_imported = is_modguard_imported(parsed_ast, "public")
if not member_name:
if not member_name or member_name == "*":
fs.write_file(
file_path,
file_content
+ _public_module_end(should_import=not modguard_public_is_imported),
(f"{IMPORT_MODGUARD}\n" if not modguard_public_is_imported else "")
+ file_content
+ f"{MODGUARD_PUBLIC}()\n",
)
return

member_finder = MemberFinder(member_name)
member_finder.visit(parsed_ast)
if member_finder.matched_lineno is None:
raise ModguardParseError(
f"Failed to find member {member_name} in file {file_path}"
)

normal_lineno = member_finder.matched_lineno - 1
file_lines = file_content.splitlines(keepends=True)
if member_finder.matched_assignment:
# Insert a call to public for the member after the assignment
lines_to_write: list[str]
if member_finder.start_lineno is None:
# The member name was not found, which probably means it is dynamic
# Add a public call with the member name as a string
lines_to_write = [
*file_lines[: normal_lineno + 1],
f"{PUBLIC_CALL}({member_name})\n",
*file_lines[normal_lineno + 1 :],
*file_lines,
f'{MODGUARD_PUBLIC}("{member_name}")\n',
]
else:
# Insert a decorator before the function or class definition
lines_to_write = [
*file_lines[:normal_lineno],
PUBLIC_DECORATOR + "\n",
*file_lines[normal_lineno:],
]
starting_line = file_lines[member_finder.start_lineno - 1]
starting_whitespace_match = WHITESPACE_REGEX.match(starting_line)
assert (
starting_whitespace_match
), f"Whitespace regex should always match.\n{starting_line}"

# The member name was found
if member_finder.matched_assignment:
assert (
member_finder.end_lineno is not None
), f"Expected to find end_lineno on matched assignment. [{file_path}, {member_name}]"

# Insert a call to public for the member after the assignment
lines_to_write = [
*file_lines[: member_finder.end_lineno],
f"{starting_whitespace_match.group(1) or ''}{MODGUARD_PUBLIC}({member_name})\n",
*file_lines[member_finder.end_lineno :],
]
else:
# Insert a decorator before the function or class definition
lines_to_write = [
*file_lines[: member_finder.start_lineno - 1],
f"{starting_whitespace_match.group(1) or ''}{PUBLIC_DECORATOR}\n",
*file_lines[member_finder.start_lineno - 1 :],
]
if not modguard_public_is_imported:
lines_to_write = [IMPORT_MODGUARD + "\n", *lines_to_write]

Expand Down

0 comments on commit 5cd9d15

Please sign in to comment.