diff --git a/Lib/modulefinder.py b/Lib/modulefinder.py --- a/Lib/modulefinder.py +++ b/Lib/modulefinder.py @@ -67,7 +67,8 @@ class Module: class ModuleFinder: - def __init__(self, path=None, debug=0, excludes=[], replace_paths=[]): + def __init__(self, path=None, debug=0, excludes=[], replace_paths=[], + recurse=True): if path is None: path = sys.path self.path = path @@ -78,6 +79,7 @@ class ModuleFinder: self.excludes = excludes self.replace_paths = replace_paths self.processed_paths = [] # Used in debugging only + self.recurse = recurse def msg(self, level, str, *args): if level <= self.debug: @@ -432,9 +434,10 @@ class ModuleFinder: # We don't expect anything else from the generator. raise RuntimeError(what) - for c in co.co_consts: - if isinstance(c, type(co)): - self.scan_code(c, m) + if self.recurse: + for c in co.co_consts: + if isinstance(c, type(co)): + self.scan_code(c, m) def load_package(self, fqname, pathname): self.msgin(2, "load_package", fqname, pathname) diff --git a/Lib/test/test_modulefinder.py b/Lib/test/test_modulefinder.py --- a/Lib/test/test_modulefinder.py +++ b/Lib/test/test_modulefinder.py @@ -196,6 +196,21 @@ a/module.py from . import bar """] +non_recursive_import_test_1 = [ + "a", + ["a", "b"], + [], + [], + """\ +a.py + import b +b.py + import c +c.py + import sys +"""] + + def open_file(path): dirname = os.path.dirname(path) @@ -223,11 +238,11 @@ def create_package(source): class ModuleFinderTest(unittest.TestCase): - def _do_test(self, info, report=False): + def _do_test(self, info, report=False, **kwargs): import_this, modules, missing, maybe_missing, source = info create_package(source) try: - mf = modulefinder.ModuleFinder(path=TEST_PATH) + mf = modulefinder.ModuleFinder(path=TEST_PATH, **kwargs) mf.import_hook(import_this) if report: mf.report() @@ -273,6 +288,9 @@ class ModuleFinderTest(unittest.TestCase def test_relative_imports_3(self): self._do_test(relative_import_test_3) + def test_non_recursive_1(self): + self._do_test(non_recursive_import_test_1, recurse=False, debug=10) + def test_main(): support.run_unittest(ModuleFinderTest)