Skip to content

Commit

Permalink
Detect and add individual public members in check
Browse files Browse the repository at this point in the history
  • Loading branch information
emdoyle committed Feb 8, 2024
1 parent 3106cc1 commit e8df3f9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 21 deletions.
4 changes: 2 additions & 2 deletions modguard/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import argparse
import os
import sys
from modguard.check import check, ErrorInfo
from modguard.init import init_project
from .check import check, ErrorInfo
from .init import init_project


class BCOLORS:
Expand Down
64 changes: 45 additions & 19 deletions modguard/parsing/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ def __init__(self, module_name: str):
self.import_found = False

def visit_ImportFrom(self, node):
is_modguard_module_import = node.module is not None and (
node.module == "modguard" or node.module.startswith("modguard.")
)
if is_modguard_module_import and any(
alias.name == self.module_name for alias in node.names
):
self.import_found = True
return
if self.module_name:
is_modguard_module_import = node.module is not None and (
node.module == "modguard" or node.module.startswith("modguard.")
)
if is_modguard_module_import and any(
alias.name == self.module_name for alias in node.names
):
self.import_found = True
return
self.generic_visit(node)

def visit_Import(self, node):
Expand All @@ -32,6 +33,12 @@ def visit_Import(self, node):
self.generic_visit(node)


def is_modguard_imported(parsed_ast: ast.AST, module_name: str = "") -> bool:
modguard_import_visitor = ModguardImportVisitor(module_name)
modguard_import_visitor.visit(parsed_ast)
return modguard_import_visitor.import_found


class PublicMemberVisitor(ast.NodeVisitor):
def __init__(self, current_mod_path: str, is_package: bool = False):
self.is_modguard_public_imported = False
Expand Down Expand Up @@ -104,20 +111,41 @@ def visit_Call(self, node):
top_level_expr = isinstance(parent_node, ast.Expr) and isinstance(
parent_node.parent, ast.Module
)
is_raw_public_call = (
isinstance(node.func, ast.Name) and node.func.id == "public"
)
is_modguard_public_call = (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "modguard"
and node.func.attr == "public"
)
if (
self.is_modguard_public_imported
and (top_level or top_level_expr)
and isinstance(node.func, ast.Name)
and node.func.id == "public"
and (is_raw_public_call or is_modguard_public_call)
):
# public() has been called at the top-level,
# so we add it as the sole PublicMember and return
self.public_members = [
PublicMember(
name="",
allowlist=self._extract_allowlist(node),
if node.args:
# if public is given positional arguments, add each as a public member
self.public_members.extend(
(
PublicMember(
name=arg.id,
allowlist=self._extract_allowlist(node),
)
for arg in node.args
if isinstance(arg, ast.Name)
)
)
]
else:
# if no positional arguments, we add a PublicMember for the whole module and return
self.public_members = [
PublicMember(
name="",
allowlist=self._extract_allowlist(node),
)
]
return
self.generic_visit(node)

Expand Down Expand Up @@ -226,9 +254,7 @@ def mark_as_public(file_path: str, member_name: str = ""):
except SyntaxError as e:
raise ModguardParseError(f"Syntax error in {file_path}: {e}")

modguard_import_visitor = ModguardImportVisitor("public")
modguard_import_visitor.visit(parsed_ast)
modguard_public_is_imported = modguard_import_visitor.import_found
modguard_public_is_imported = is_modguard_imported(parsed_ast, "public")

if not member_name:
with open(file_path, "w") as file:
Expand Down

0 comments on commit e8df3f9

Please sign in to comment.