Skip to content

Commit

Permalink
Merge pull request #787 from python-rope/lieryan-typing-import-info
Browse files Browse the repository at this point in the history
Add type hints to importinfo.py and add repr to ImportInfo
  • Loading branch information
lieryan authored Apr 4, 2024
2 parents 14d6fd2 + 3f89161 commit 9f146b3
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# **Upcoming release**

- #787 Add type hints to importinfo.py and add repr to ImportInfo (@lieryan)
- #786 Upgrade Actions used in Github Workflows (@lieryan)
- #785 Refactoring movetest.py (@lieryan)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ force-exclude = 'ropetest|rope/base/prefs.py'
[tool.coverage.report]
exclude_also = [
"if TYPE_CHECKING:",
"raise NotImplementedError()",
]

[tool.isort]
Expand Down
48 changes: 34 additions & 14 deletions rope/refactor/importutils/importinfo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Tuple
from abc import abstractmethod, ABC
from typing import List, Tuple, Optional, Protocol


class ImportStatement:
Expand Down Expand Up @@ -64,17 +65,22 @@ def accept(self, visitor):
return visitor.dispatch(self)


class ImportInfo:
def get_imported_primaries(self, context):
pass
class ImportInfo(ABC):
names_and_aliases: List[Tuple[str, Optional[str]]]

@abstractmethod
def get_imported_primaries(self, context) -> List[str]: ...

def get_imported_names(self, context):
return [
primary.split(".")[0] for primary in self.get_imported_primaries(context)
]

def get_import_statement(self):
pass
def __repr__(self):
return f'<{self.__class__.__name__} "{self.get_import_statement()}">'

@abstractmethod
def get_import_statement(self) -> str: ...

def is_empty(self):
pass
Expand Down Expand Up @@ -105,10 +111,13 @@ def get_empty_import():


class NormalImport(ImportInfo):
def __init__(self, names_and_aliases):
def __init__(
self,
names_and_aliases: List[Tuple[str, Optional[str]]],
) -> None:
self.names_and_aliases = names_and_aliases

def get_imported_primaries(self, context):
def get_imported_primaries(self, context) -> List[str]:
result = []
for name, alias in self.names_and_aliases:
if alias:
Expand All @@ -117,7 +126,7 @@ def get_imported_primaries(self, context):
result.append(name)
return result

def get_import_statement(self):
def get_import_statement(self) -> str:
result = "import "
for name, alias in self.names_and_aliases:
result += name
Expand All @@ -131,12 +140,20 @@ def is_empty(self):


class FromImport(ImportInfo):
def __init__(self, module_name, level, names_and_aliases):
module_name: str
level: int

def __init__(
self,
module_name: str,
level: int,
names_and_aliases: List[Tuple[str, Optional[str]]],
):
self.module_name = module_name
self.level = level
self.names_and_aliases = names_and_aliases

def get_imported_primaries(self, context):
def get_imported_primaries(self, context) -> List[str]:
if self.names_and_aliases[0][0] == "*":
module = self.get_imported_module(context)
return [name for name in module if not name.startswith("_")]
Expand Down Expand Up @@ -173,7 +190,7 @@ def get_imported_module(self, context):
self.module_name, context.folder, self.level
)

def get_import_statement(self):
def get_import_statement(self) -> str:
result = "from " + "." * self.level + self.module_name + " import "
for name, alias in self.names_and_aliases:
result += name
Expand All @@ -190,14 +207,17 @@ def is_star_import(self):


class EmptyImport(ImportInfo):
names_and_aliases: List[Tuple[str, str]] = []
names_and_aliases = []

def is_empty(self):
return True

def get_imported_primaries(self, context):
def get_imported_primaries(self, context) -> List[str]:
return []

def get_import_statement(self) -> str:
raise NotImplementedError()


class ImportContext:
def __init__(self, project, folder):
Expand Down
22 changes: 22 additions & 0 deletions ropetest/reprtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rope.contrib import findit
from rope.contrib.autoimport import models
from rope.refactor import occurrences
from rope.refactor.importutils import importinfo
from ropetest import testutils


Expand Down Expand Up @@ -165,3 +166,24 @@ def test_autoimport_models_finalquery(project, mod1):
obj = models.Package.delete_by_package_name
assert isinstance(obj, models.FinalQuery)
assert repr(obj) == expected_repr


def test_repr_normal_import(project):
obj = importinfo.NormalImport([("abc", None), ("ghi", "jkl")])
expected_repr = '<NormalImport "import abc, ghi as jkl">'
assert isinstance(obj, importinfo.NormalImport)
assert repr(obj) == expected_repr


def test_repr_from_import(project):
obj = importinfo.FromImport("pkg1.pkg2", 0, [("abc", None), ("ghi", "jkl")])
expected_repr = '<FromImport "from pkg1.pkg2 import abc, ghi as jkl">'
assert isinstance(obj, importinfo.FromImport)
assert repr(obj) == expected_repr


def test_repr_from_import_with_level(project):
obj = importinfo.FromImport("pkg1.pkg2", 3, [("abc", None), ("ghi", "jkl")])
expected_repr = '<FromImport "from ...pkg1.pkg2 import abc, ghi as jkl">'
assert isinstance(obj, importinfo.FromImport)
assert repr(obj) == expected_repr

0 comments on commit 9f146b3

Please sign in to comment.