Skip to content

Make parent patching occur only on imports #2061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 29, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,19 +1399,6 @@ def parse_file(self) -> None:
self.dep_line_map = dep_line_map
self.check_blockers()

def patch_parent(self) -> None:
# Include module in the symbol table of the enclosing package.
if '.' not in self.id:
return
manager = self.manager
modules = manager.modules
parent, child = self.id.rsplit('.', 1)
if parent in modules:
manager.trace("Added %s.%s" % (parent, child))
modules[parent].names[child] = SymbolTableNode(MODULE_REF, self.tree, parent)
else:
manager.log("Hm... couldn't add %s.%s" % (parent, child))

def semantic_analysis(self) -> None:
with self.wrap_context():
self.manager.semantic_analyzer.visit_file(self.tree, self.xpath)
Expand Down Expand Up @@ -1682,8 +1669,6 @@ def process_fresh_scc(graph: Graph, scc: List[str]) -> None:
"""Process the modules in one SCC from their cached data."""
for id in scc:
graph[id].load_tree()
for id in scc:
graph[id].patch_parent()
for id in scc:
graph[id].fix_cross_refs()
for id in scc:
Expand All @@ -1697,8 +1682,6 @@ def process_stale_scc(graph: Graph, scc: List[str]) -> None:
# If the former, parse_file() is a no-op.
graph[id].parse_file()
graph[id].fix_suppressed_dependencies(graph)
for id in scc:
graph[id].patch_parent()
for id in scc:
graph[id].semantic_analysis()
for id in scc:
Expand Down
5 changes: 2 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from mypy.constraints import get_actual_type
from mypy.checkstrformat import StringFormatterChecker
from mypy.expandtype import expand_type
import mypy.checkexpr

from mypy import experiments

Expand All @@ -61,15 +60,15 @@ class ExpressionChecker:
# This is shared with TypeChecker, but stored also here for convenience.
msg = None # type: MessageBuilder

strfrm_checker = None # type: mypy.checkstrformat.StringFormatterChecker
strfrm_checker = None # type: StringFormatterChecker

def __init__(self,
chk: 'mypy.checker.TypeChecker',
msg: MessageBuilder) -> None:
"""Construct an expression type checker."""
self.chk = chk
self.msg = msg
self.strfrm_checker = mypy.checkexpr.StringFormatterChecker(self, self.chk, self.msg)
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)

def visit_name_expr(self, e: NameExpr) -> Type:
"""Type check a name expression.
Expand Down
1 change: 1 addition & 0 deletions mypy/checkstrformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
if False:
# break import cycle only needed for mypy
import mypy.checker
import mypy.checkexpr
from mypy import messages
from mypy.messages import MessageBuilder

Expand Down
36 changes: 36 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,30 @@ def visit_import(self, i: Import) -> None:
base = id.split('.')[0]
self.add_module_symbol(base, base, module_public=module_public,
context=i)
self.add_submodules_to_parent_modules(id, module_public)

def add_submodules_to_parent_modules(self, id: str, module_public: bool) -> None:
"""Recursively adds a reference to a newly loaded submodule to its parent.

When you import a submodule in any way, Python will add a reference to that
submodule to its parent. So, if you do something like `import A.B` or
`from A import B` or `from A.B import Foo`, Python will add a reference to
module A.B to A's namespace.

Note that this "parent patching" process is completely independent from any
changes made to the *importer's* namespace. For example, if you have a file
named `foo.py` where you do `from A.B import Bar`, then foo's namespace will
be modified to contain a reference to only Bar. Independently, A's namespace
will be modified to contain a reference to `A.B`.
"""
while '.' in id:
parent, child = id.rsplit('.', 1)
modules_loaded = parent in self.modules and id in self.modules
if modules_loaded and child not in self.modules[parent].names:
sym = SymbolTableNode(MODULE_REF, self.modules[id], parent,
module_public=module_public)
self.modules[parent].names[child] = sym
id = parent

def add_module_symbol(self, id: str, as_id: str, module_public: bool,
context: Context) -> None:
Expand All @@ -908,8 +932,19 @@ def visit_import_from(self, imp: ImportFrom) -> None:
import_id = self.correct_relative_import(imp)
if import_id in self.modules:
module = self.modules[import_id]
self.add_submodules_to_parent_modules(import_id, True)
for id, as_id in imp.names:
node = module.names.get(id)

# If the module does not contain a symbol with the name 'id',
# try checking if it's a module instead.
if id not in module.names or node.kind == UNBOUND_IMPORTED:
possible_module_id = import_id + '.' + id
mod = self.modules.get(possible_module_id)
if mod is not None:
node = SymbolTableNode(MODULE_REF, mod, import_id)
self.add_submodules_to_parent_modules(possible_module_id, True)

if node and node.kind != UNBOUND_IMPORTED:
node = self.normalize_type_alias(node, imp)
if not node:
Expand Down Expand Up @@ -991,6 +1026,7 @@ def visit_import_all(self, i: ImportAll) -> None:
i_id = self.correct_relative_import(i)
if i_id in self.modules:
m = self.modules[i_id]
self.add_submodules_to_parent_modules(i_id, True)
for name, node in m.names.items():
node = self.normalize_type_alias(node, i)
if not name.startswith('_') and node.module_public:
Expand Down
163 changes: 163 additions & 0 deletions test-data/unit/check-modules.test
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,169 @@ import m.a
[out]


-- Checks dealing with submodules and different kinds of imports
-- -------------------------------------------------------------

[case testSubmoduleRegularImportAddsAllParents]
import a.b.c
reveal_type(a.value) # E: Revealed type is 'builtins.int'
reveal_type(a.b.value) # E: Revealed type is 'builtins.str'
reveal_type(a.b.c.value) # E: Revealed type is 'builtins.float'
b.value # E: Name 'b' is not defined
c.value # E: Name 'c' is not defined

[file a/__init__.py]
value = 3
[file a/b/__init__.py]
value = "a"
[file a/b/c.py]
value = 3.2
[out]

[case testSubmoduleImportAsDoesNotAddParents]
import a.b.c as foo
reveal_type(foo.value) # E: Revealed type is 'builtins.float'
a.value # E: Name 'a' is not defined
b.value # E: Name 'b' is not defined
c.value # E: Name 'c' is not defined

[file a/__init__.py]
value = 3
[file a/b/__init__.py]
value = "a"
[file a/b/c.py]
value = 3.2
[out]

[case testSubmoduleImportFromDoesNotAddParents]
from a import b
reveal_type(b.value) # E: Revealed type is 'builtins.str'
b.c.value # E: "module" has no attribute "c"
a.value # E: Name 'a' is not defined

[file a/__init__.py]
value = 3
[file a/b/__init__.py]
value = "a"
[file a/b/c.py]
value = 3.2
[builtins fixtures/module.pyi]
[out]

[case testSubmoduleImportFromDoesNotAddParents2]
from a.b import c
reveal_type(c.value) # E: Revealed type is 'builtins.float'
a.value # E: Name 'a' is not defined
b.value # E: Name 'b' is not defined

[file a/__init__.py]
value = 3
[file a/b/__init__.py]
value = "a"
[file a/b/c.py]
value = 3.2
[out]

[case testSubmoduleRegularImportNotDirectlyAddedToParent]
import a.b.c
def accept_float(x: float) -> None: pass
accept_float(a.b.c.value)

[file a/__init__.py]
value = 3
b.value
a.b.value

[file a/b/__init__.py]
value = "a"
c.value
a.b.c.value

[file a/b/c.py]
value = 3.2
[out]
main:1: note: In module imported here:
tmp/a/b/__init__.py:2: error: Name 'c' is not defined
tmp/a/b/__init__.py:3: error: Name 'a' is not defined
tmp/a/__init__.py:2: error: Name 'b' is not defined
tmp/a/__init__.py:3: error: Name 'a' is not defined

[case testSubmoduleMixingLocalAndQualifiedNames]
from a.b import MyClass
val1 = None # type: a.b.MyClass # E: Name 'a' is not defined
val2 = None # type: MyClass

[file a/__init__.py]
[file a/b.py]
class MyClass: pass
[out]

[case testSubmoduleMixingImportFrom]
import parent.child

[file parent/__init__.py]

[file parent/common.py]
class SomeClass: pass

[file parent/child.py]
from parent.common import SomeClass
from parent import common
foo = parent.common.SomeClass()

[builtins fixtures/module.pyi]
[out]
main:1: note: In module imported here:
tmp/parent/child.py:3: error: Name 'parent' is not defined

[case testSubmoduleMixingImportFromAndImport]
import parent.child

[file parent/__init__.py]

[file parent/common.py]
class SomeClass: pass

[file parent/unrelated.py]
class ShouldNotLoad: pass

[file parent/child.py]
from parent.common import SomeClass
import parent

# Note, since this might be unintuitive -- when `parent.common` is loaded in any way,
# shape, or form, it's added to `parent`'s namespace, which is why the below line
# succeeds.
foo = parent.common.SomeClass()
reveal_type(foo)
bar = parent.unrelated.ShouldNotLoad()

[builtins fixtures/module.pyi]
[out]
main:1: note: In module imported here:
tmp/parent/child.py:8: error: Revealed type is 'parent.common.SomeClass'
tmp/parent/child.py:9: error: "module" has no attribute "unrelated"

[case testSubmoduleMixingImportFromAndImport2]
import parent.child

[file parent/__init__.py]

[file parent/common.py]
class SomeClass: pass

[file parent/child.py]
from parent import common
import parent
foo = parent.common.SomeClass()
reveal_type(foo)

[builtins fixtures/module.pyi]
[out]
main:1: note: In module imported here:
tmp/parent/child.py:4: error: Revealed type is 'parent.common.SomeClass'


-- Misc

[case testInheritFromBadImport]
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/module.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ class type: pass
class function: pass
class int: pass
class str: pass
class bool: pass