diff --git a/README.rst b/README.rst index 274a3e4..a46da95 100644 --- a/README.rst +++ b/README.rst @@ -112,6 +112,8 @@ Below is the full listing of options:: this only triggers if there is only one star import in the file; this is skipped if there are any uses of `__all__` or `del` in the file + --populate-all populate `__all__` with unused import found in the + code. --remove-all-unused-imports remove all unused imports (not just those from the standard library) diff --git a/autoflake.py b/autoflake.py index 46d7a88..27274fe 100755 --- a/autoflake.py +++ b/autoflake.py @@ -294,7 +294,8 @@ def break_up_import(line): def filter_code(source, additional_imports=None, expand_star_imports=False, remove_all_unused_imports=False, - remove_unused_variables=False): + remove_unused_variables=False, + populate_dunder_all=False): """Yield code with unused imports removed.""" imports = SAFE_IMPORTS if additional_imports: @@ -335,6 +336,10 @@ def filter_code(source, additional_imports=None, else: marked_variable_line_numbers = frozenset() + if populate_dunder_all: + marked_import_line_numbers = frozenset() + source = populate_dunder_all_with_modules(source, marked_unused_module) + sio = io.StringIO(source) previous_line = '' for line_number, line in enumerate(sio.readlines(), start=1): @@ -478,6 +483,49 @@ def filter_useless_pass(source): yield line +def populate_dunder_all_with_modules(source, marked_unused_module): + """Return source with `__all__` properly populated.""" + if re.search(r'\b__all__\b', source): + # If there are existing `__all__`, don't mess with it. + return source + + insert_position = len(source) + end_position = -1 + all_modules = [] + + for modules in marked_unused_module.values(): + # Get the imported name, `a.b.Foo` -> Foo + all_modules += [get_imported_name(name) for name in modules] + + if all_modules: + new_all_syntax = '__all__ = ' + str(all_modules) + return ( + source[:insert_position] + + new_all_syntax + + source[end_position:] + ) + else: + return source + + +def get_imported_name(module): + """Return only imported name from pyflakes full module path. + + Example: + - `a.b.Foo` -> `Foo` + - `a as b` -> b + + """ + if '.' in module: + name = module.split('.')[-1] + elif re.search(r'\bas\b', module): + name = re.split(r'\bas\b', module)[-1] + else: + name = module + # str() to force python 2 to not use unicode + return str(name.strip()) + + def get_indentation(line): """Return leading whitespace.""" if line.strip(): @@ -497,7 +545,8 @@ def get_line_ending(line): def fix_code(source, additional_imports=None, expand_star_imports=False, - remove_all_unused_imports=False, remove_unused_variables=False): + remove_all_unused_imports=False, remove_unused_variables=False, + populate_dunder_all=False): """Return code with all filtering run on it.""" if not source: return source @@ -515,9 +564,10 @@ def fix_code(source, additional_imports=None, expand_star_imports=False, additional_imports=additional_imports, expand_star_imports=expand_star_imports, remove_all_unused_imports=remove_all_unused_imports, - remove_unused_variables=remove_unused_variables)))) + remove_unused_variables=remove_unused_variables, + populate_dunder_all=populate_dunder_all)))) - if filtered_source == source: + if filtered_source == source or populate_dunder_all: break source = filtered_source @@ -537,7 +587,9 @@ def fix_file(filename, args, standard_out): additional_imports=args.imports.split(',') if args.imports else None, expand_star_imports=args.expand_star_imports, remove_all_unused_imports=args.remove_all_unused_imports, - remove_unused_variables=args.remove_unused_variables) + remove_unused_variables=args.remove_unused_variables, + populate_dunder_all=args.populate_modules_dunder_all, + ) if original_source != filtered_source: if args.in_place: @@ -692,6 +744,9 @@ def _main(argv, standard_out, standard_error): 'one star import in the file; this is skipped if ' 'there are any uses of `__all__` or `del` in the ' 'file') + parser.add_argument('--populate-modules-dunder-all', action='store_true', + help='populate `__all__` with unused import found in ' + 'the code.') parser.add_argument('--remove-all-unused-imports', action='store_true', help='remove all unused imports (not just those from ' 'the standard library)') diff --git a/test_autoflake.py b/test_autoflake.py index ca54190..a97667d 100755 --- a/test_autoflake.py +++ b/test_autoflake.py @@ -485,6 +485,72 @@ def foo(): """ self.assertEqual(line, ''.join(autoflake.filter_code(line))) + def test_filter_code_populate_dunder_all(self): + self.assertEqual(""" +import math +import sys +__all__ = ['math', 'sys'] +""", ''.join(autoflake.filter_code(""" +import math +import sys +""", populate_dunder_all=True))) + + def test_filter_code_populate_dunder_all_should_not_create_a_mess(self): + code = """ +import math +import sys +__all__ = [ + 'math', 'sys' +] +import abc +""" + self.assertEqual( + code, + ''.join(autoflake.filter_code(code, populate_dunder_all=True))) + + def test_filter_code_populate_dunder_all_should_ignore_dotted_import(self): + code = """ +import foo.bar +""" + self.assertEqual( + code, + ''.join(autoflake.filter_code(code, populate_dunder_all=True))) + + def test_filter_code_populate_dunder_all_from_import(self): + self.assertEqual(""" +from a.b import Foo +from a.c import Bar +__all__ = ['Foo', 'Bar'] +""", ''.join(autoflake.filter_code(""" +from a.b import Foo +from a.c import Bar +""", populate_dunder_all=True))) + + def test_filter_code_populate_dunder_all_as(self): + self.assertEqual(""" +import math as m +__all__ = ['m'] +""", ''.join(autoflake.filter_code(""" +import math as m +""", populate_dunder_all=True))) + + def test_filter_code_populate_dunder_all_with_tab(self): + self.assertEqual(""" +import math\tas\tm +__all__ = ['m'] +""", ''.join(autoflake.filter_code(""" +import math\tas\tm +""", populate_dunder_all=True))) + + def test_filter_code_populate_dunder_all_with_no_change(self): + code = """ +def foo(): + bar = 0 +""" + self.assertEqual( + code, + ''.join(autoflake.filter_code(code, populate_dunder_all=True))) + def test_fix_code(self): self.assertEqual( """\ @@ -987,7 +1053,7 @@ def test_exclude(self): temp_directory = tempfile.mkdtemp(dir='.') try: with open(os.path.join(temp_directory, 'a.py'), 'w') as output: - output.write("import re\n") + output.write('import re\n') os.mkdir(os.path.join(temp_directory, 'd')) with open(os.path.join(temp_directory, 'd', 'b.py'), @@ -995,8 +1061,8 @@ def test_exclude(self): output.write('import os\n') p = subprocess.Popen(list(AUTOFLAKE_COMMAND) + - [temp_directory, '--recursive', '--exclude=a*'], - stdout=subprocess.PIPE) + [temp_directory, '--recursive', '--exclude=a*'], + stdout=subprocess.PIPE) result = p.communicate()[0].decode('utf-8') self.assertNotIn('import re', result) diff --git a/test_fuzz.py b/test_fuzz.py index 7fd8246..d2a5656 100755 --- a/test_fuzz.py +++ b/test_fuzz.py @@ -149,6 +149,10 @@ def process_args(): parser.add_argument('--imports', help='pass to the autoflake "--imports" option') + parser.add_argument('--populate-modules-dunder-all', action='store_true', + help='populate `__all__` with unused import found in ' + 'the code.') + parser.add_argument('--remove-all-unused-imports', action='store_true', help='pass "--remove-all-unused-imports" option to ' 'autoflake') @@ -190,6 +194,9 @@ def check(args): if args.remove_unused_variables: options.append('--remove-unused-variables') + if args.populate_modules_dunder_all: + options.append('--populate-modules-dunder-all') + filenames = dir_paths completed_filenames = set()