Skip to content

Commit

Permalink
Find non-function, non-class members being imported; TODO public mark…
Browse files Browse the repository at this point in the history
…er logic
  • Loading branch information
emdoyle committed Feb 8, 2024
1 parent 6b488f3 commit 1dca014
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 13 deletions.
2 changes: 1 addition & 1 deletion modguard/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def init_project(root: str, exclude_paths: list[str] = None):
try:
mark_as_public(file_path, member_name)
violated_boundary.add_public_member(PublicMember(name=import_mod_path))
except errors.ModguardParseError:
except errors.ModguardError:
print(
f"Skipping member {member_name} in {file_path}; could not mark as public"
)
57 changes: 47 additions & 10 deletions modguard/parsing/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from modguard import public
from modguard.core.public import PublicMember
from modguard.errors import ModguardParseError
from modguard.errors import ModguardParseError, ModguardError

from .utils import file_to_module_path

Expand Down Expand Up @@ -145,12 +145,43 @@ def get_public_members(file_path: str) -> list[PublicMember]:
return public_member_visitor.public_members


# TODO: handle re-exported members (follow imports?)
class MemberFinder(ast.NodeVisitor):
def __init__(self, member_name: str):
self.member_name = member_name
self.matched_lineno: Optional[int] = None
self.matched_assignment = False
self.depth = 0

def _check_assignment(self, node):
if self.depth == 0:
for target in node.targets:
if isinstance(target, ast.Name) and target.id == self.member_name:
self.matched_lineno = target.lineno
self.matched_assignment = True
return
elif isinstance(target, ast.List) or isinstance(target, ast.Tuple):
for elt in target.elts:
if isinstance(elt, ast.Name) and elt.id == self.member_name:
self.matched_lineno = elt.lineno
self.matched_assignment = True
return

def visit_Assign(self, node):
self._check_assignment(node)
self.generic_visit(node)

def visit_AnnAssign(self, node):
self._check_assignment(node)
self.generic_visit(node)

def visit_Global(self, node):
if self.member_name in node.names:
self.matched_lineno = node.lineno
self.matched_assignment = True
return
self.generic_visit(node)

def visit_FunctionDef(self, node):
if self.depth == 0 and node.name == self.member_name:
self.matched_lineno = node.lineno
Expand Down Expand Up @@ -209,14 +240,20 @@ def mark_as_public(file_path: str, member_name: str = ""):
f"Failed to find member {member_name} in file {file_path}"
)

with open(file_path, "w") as file:
file_lines = file_content.splitlines(keepends=True)
lines_to_write = [
*file_lines[: member_finder.matched_lineno - 1],
PUBLIC_DECORATOR + "\n",
*file_lines[member_finder.matched_lineno - 1 :],
]
if not modguard_public_is_imported:
lines_to_write = [IMPORT_MODGUARD + "\n", *lines_to_write]
normal_lineno = member_finder.matched_lineno - 1
file_lines = file_content.splitlines(keepends=True)
if member_finder.matched_assignment:
raise ModguardError(
f"Failed to mark {member_name} as public in file {file_path}."
)

lines_to_write = [
*file_lines[:normal_lineno],
PUBLIC_DECORATOR + "\n",
*file_lines[normal_lineno:],
]
if not modguard_public_is_imported:
lines_to_write = [IMPORT_MODGUARD + "\n", *lines_to_write]

with open(file_path, "w") as file:
file.write("".join(lines_to_write))
4 changes: 2 additions & 2 deletions modguard/public.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
def public(fn: callable = None, *, allowlist: list[str] = None):
return fn
def public(obj: object = None, *, allowlist: list[str] = None):
return obj

0 comments on commit 1dca014

Please sign in to comment.