Skip to content

Commit

Permalink
Add type annotations for importinfo.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lieryan committed Apr 4, 2024
1 parent ba54cab commit 7d5ffbe
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions rope/refactor/importutils/importinfo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Tuple
from abc import abstractmethod, ABC
from typing import List, Tuple, Optional, Protocol


class ImportStatement:
Expand Down Expand Up @@ -64,9 +65,12 @@ def accept(self, visitor):
return visitor.dispatch(self)


class ImportInfo:
def get_imported_primaries(self, context):
pass
class ImportInfo(ABC):
names_and_aliases: List[tuple[str, Optional[str]]]

@abstractmethod
def get_imported_primaries(self, context) -> List[str]:
...

def get_imported_names(self, context):
return [
Expand All @@ -76,8 +80,9 @@ def get_imported_names(self, context):
def __repr__(self):
return f'<{self.__class__.__name__} "{self.get_import_statement()}">'

def get_import_statement(self):
pass
@abstractmethod
def get_import_statement(self) -> str:
...

def is_empty(self):
pass
Expand Down Expand Up @@ -108,10 +113,13 @@ def get_empty_import():


class NormalImport(ImportInfo):
def __init__(self, names_and_aliases):
def __init__(
self,
names_and_aliases: List[tuple[str, Optional[str]]],
) -> None:
self.names_and_aliases = names_and_aliases

def get_imported_primaries(self, context):
def get_imported_primaries(self, context) -> List[str]:
result = []
for name, alias in self.names_and_aliases:
if alias:
Expand All @@ -120,7 +128,7 @@ def get_imported_primaries(self, context):
result.append(name)
return result

def get_import_statement(self):
def get_import_statement(self) -> str:
result = "import "
for name, alias in self.names_and_aliases:
result += name
Expand All @@ -134,12 +142,20 @@ def is_empty(self):


class FromImport(ImportInfo):
def __init__(self, module_name, level, names_and_aliases):
module_name: str
level: int

def __init__(
self,
module_name: str,
level: int,
names_and_aliases: List[tuple[str, Optional[str]]],
):
self.module_name = module_name
self.level = level
self.names_and_aliases = names_and_aliases

def get_imported_primaries(self, context):
def get_imported_primaries(self, context) -> List[str]:
if self.names_and_aliases[0][0] == "*":
module = self.get_imported_module(context)
return [name for name in module if not name.startswith("_")]
Expand Down Expand Up @@ -176,7 +192,7 @@ def get_imported_module(self, context):
self.module_name, context.folder, self.level
)

def get_import_statement(self):
def get_import_statement(self) -> str:
result = "from " + "." * self.level + self.module_name + " import "
for name, alias in self.names_and_aliases:
result += name
Expand All @@ -193,14 +209,17 @@ def is_star_import(self):


class EmptyImport(ImportInfo):
names_and_aliases: List[Tuple[str, str]] = []
names_and_aliases = []

def is_empty(self):
return True

def get_imported_primaries(self, context):
def get_imported_primaries(self, context) -> List[str]:
return []

def get_import_statement(self) -> str:
raise NotImplementedError()


class ImportContext:
def __init__(self, project, folder):
Expand Down

0 comments on commit 7d5ffbe

Please sign in to comment.