Skip to content

Refactor: use context manager for Scope #11053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 40 additions & 44 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,29 +297,26 @@ def check_first_pass(self) -> None:
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional):
self.errors.set_file(self.path, self.tree.fullname, scope=self.tscope)
self.tscope.enter_file(self.tree.fullname)
with self.enter_partial_types():
with self.binder.top_frame_context():
with self.tscope.module_scope(self.tree.fullname):
with self.enter_partial_types(), self.binder.top_frame_context():
for d in self.tree.defs:
self.accept(d)

assert not self.current_node_deferred
assert not self.current_node_deferred

all_ = self.globals.get('__all__')
if all_ is not None and all_.type is not None:
all_node = all_.node
assert all_node is not None
seq_str = self.named_generic_type('typing.Sequence',
[self.named_type('builtins.str')])
if self.options.python_version[0] < 3:
all_ = self.globals.get('__all__')
if all_ is not None and all_.type is not None:
all_node = all_.node
assert all_node is not None
seq_str = self.named_generic_type('typing.Sequence',
[self.named_type('builtins.unicode')])
if not is_subtype(all_.type, seq_str):
str_seq_s, all_s = format_type_distinctly(seq_str, all_.type)
self.fail(message_registry.ALL_MUST_BE_SEQ_STR.format(str_seq_s, all_s),
all_node)

self.tscope.leave()
[self.named_type('builtins.str')])
if self.options.python_version[0] < 3:
seq_str = self.named_generic_type('typing.Sequence',
[self.named_type('builtins.unicode')])
if not is_subtype(all_.type, seq_str):
str_seq_s, all_s = format_type_distinctly(seq_str, all_.type)
self.fail(message_registry.ALL_MUST_BE_SEQ_STR.format(str_seq_s, all_s),
all_node)

def check_second_pass(self,
todo: Optional[Sequence[Union[DeferredNode,
Expand All @@ -334,25 +331,24 @@ def check_second_pass(self,
if not todo and not self.deferred_nodes:
return False
self.errors.set_file(self.path, self.tree.fullname, scope=self.tscope)
self.tscope.enter_file(self.tree.fullname)
self.pass_num += 1
if not todo:
todo = self.deferred_nodes
else:
assert not self.deferred_nodes
self.deferred_nodes = []
done: Set[Union[DeferredNodeType, FineGrainedDeferredNodeType]] = set()
for node, active_typeinfo in todo:
if node in done:
continue
# This is useful for debugging:
# print("XXX in pass %d, class %s, function %s" %
# (self.pass_num, type_name, node.fullname or node.name))
done.add(node)
with self.tscope.class_scope(active_typeinfo) if active_typeinfo else nothing():
with self.scope.push_class(active_typeinfo) if active_typeinfo else nothing():
self.check_partial(node)
self.tscope.leave()
with self.tscope.module_scope(self.tree.fullname):
self.pass_num += 1
if not todo:
todo = self.deferred_nodes
else:
assert not self.deferred_nodes
self.deferred_nodes = []
done: Set[Union[DeferredNodeType, FineGrainedDeferredNodeType]] = set()
for node, active_typeinfo in todo:
if node in done:
continue
# This is useful for debugging:
# print("XXX in pass %d, class %s, function %s" %
# (self.pass_num, type_name, node.fullname or node.name))
done.add(node)
with self.tscope.class_scope(active_typeinfo) if active_typeinfo else nothing():
with self.scope.push_class(active_typeinfo) if active_typeinfo else nothing():
self.check_partial(node)
return True

def check_partial(self, node: Union[DeferredNodeType, FineGrainedDeferredNodeType]) -> None:
Expand Down Expand Up @@ -874,7 +870,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
if isinstance(typ.ret_type, TypeVarType):
if typ.ret_type.variance == CONTRAVARIANT:
self.fail(message_registry.RETURN_TYPE_CANNOT_BE_CONTRAVARIANT,
typ.ret_type)
typ.ret_type)

# Check that Generator functions have the appropriate return type.
if defn.is_generator:
Expand Down Expand Up @@ -992,7 +988,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
self.accept(item.body)
unreachable = self.binder.is_unreachable()

if (self.options.warn_no_return and not unreachable):
if self.options.warn_no_return and not unreachable:
if (defn.is_generator or
is_named_instance(self.return_types[-1], 'typing.AwaitableGenerator')):
return_type = self.get_generator_return_type(self.return_types[-1],
Expand Down Expand Up @@ -1083,7 +1079,7 @@ def is_unannotated_any(t: Type) -> bool:
code=codes.NO_UNTYPED_DEF)
elif fdef.is_generator:
if is_unannotated_any(self.get_generator_return_type(ret_type,
fdef.is_coroutine)):
fdef.is_coroutine)):
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef,
code=codes.NO_UNTYPED_DEF)
elif fdef.is_coroutine and isinstance(ret_type, Instance):
Expand Down Expand Up @@ -2641,8 +2637,7 @@ def check_rvalue_count_in_assignment(self, lvalues: List[Lvalue], rvalue_count:
len(lvalues) - 1, context)
return False
elif rvalue_count != len(lvalues):
self.msg.wrong_number_values_to_unpack(rvalue_count,
len(lvalues), context)
self.msg.wrong_number_values_to_unpack(rvalue_count, len(lvalues), context)
return False
return True

Expand Down Expand Up @@ -2896,8 +2891,7 @@ def check_lvalue(self, lvalue: Lvalue) -> Tuple[Optional[Type],
elif isinstance(lvalue, IndexExpr):
index_lvalue = lvalue
elif isinstance(lvalue, MemberExpr):
lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue,
True)
lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True)
self.store_type(lvalue, lvalue_type)
elif isinstance(lvalue, NameExpr):
lvalue_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True)
Expand Down Expand Up @@ -4144,6 +4138,7 @@ def is_type_call(expr: CallExpr) -> bool:
"""Is expr a call to type with one argument?"""
return (refers_to_fullname(expr.callee, 'builtins.type')
and len(expr.args) == 1)

# exprs that are being passed into type
exprs_in_type_calls: List[Expression] = []
# type that is being compared to type(expr)
Expand Down Expand Up @@ -4194,6 +4189,7 @@ def combine_maps(list_maps: List[TypeMap]) -> TypeMap:
if d is not None:
result_map.update(d)
return result_map

if_map = combine_maps(if_maps)
# type(x) == T is only true when x has the same type as T, meaning
# that it can be false if x is an instance of a subclass of T. That means
Expand Down
10 changes: 0 additions & 10 deletions mypy/nullcontext.py

This file was deleted.

Empty file removed mypy/ordered_dict.py
Empty file.
66 changes: 30 additions & 36 deletions mypy/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from contextlib import contextmanager
from typing import List, Optional, Iterator, Tuple

from mypy.backports import nullcontext
from mypy.nodes import TypeInfo, FuncBase


Expand Down Expand Up @@ -51,18 +52,30 @@ def current_function_name(self) -> Optional[str]:
"""Return the current function's short name if it exists"""
return self.function.name if self.function else None

def enter_file(self, prefix: str) -> None:
@contextmanager
def module_scope(self, prefix: str) -> Iterator[None]:
self.module = prefix
self.classes = []
self.function = None
self.ignored = 0
yield
assert self.module
self.module = None

def enter_function(self, fdef: FuncBase) -> None:
@contextmanager
def function_scope(self, fdef: FuncBase) -> Iterator[None]:
if not self.function:
self.function = fdef
else:
# Nested functions are part of the topmost function target.
self.ignored += 1
yield
if self.ignored:
# Leave a scope that's included in the enclosing target.
self.ignored -= 1
else:
assert self.function
self.function = None

def enter_class(self, info: TypeInfo) -> None:
"""Enter a class target scope."""
Expand All @@ -72,53 +85,34 @@ def enter_class(self, info: TypeInfo) -> None:
# Classes within functions are part of the enclosing function target.
self.ignored += 1

def leave(self) -> None:
"""Leave the innermost scope (can be any kind of scope)."""
def leave_class(self) -> None:
"""Leave a class target scope."""
if self.ignored:
# Leave a scope that's included in the enclosing target.
self.ignored -= 1
elif self.function:
# Function is always the innermost target.
self.function = None
elif self.classes:
else:
assert self.classes
# Leave the innermost class.
self.classes.pop()
else:
# Leave module.
assert self.module
self.module = None

@contextmanager
def class_scope(self, info: TypeInfo) -> Iterator[None]:
self.enter_class(info)
yield
self.leave_class()

def save(self) -> SavedScope:
"""Produce a saved scope that can be entered with saved_scope()"""
assert self.module
# We only save the innermost class, which is sufficient since
# the rest are only needed for when classes are left.
cls = self.classes[-1] if self.classes else None
return (self.module, cls, self.function)

@contextmanager
def function_scope(self, fdef: FuncBase) -> Iterator[None]:
self.enter_function(fdef)
yield
self.leave()

@contextmanager
def class_scope(self, info: TypeInfo) -> Iterator[None]:
self.enter_class(info)
yield
self.leave()
return self.module, cls, self.function

@contextmanager
def saved_scope(self, saved: SavedScope) -> Iterator[None]:
module, info, function = saved
self.enter_file(module)
if info:
self.enter_class(info)
if function:
self.enter_function(function)
yield
if function:
self.leave()
if info:
self.leave()
self.leave()
with self.module_scope(module):
with self.class_scope(info) if info else nullcontext():
with self.function_scope(function) if function else nullcontext():
yield
59 changes: 29 additions & 30 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,36 +529,35 @@ def file_context(self,
self.errors.set_file(file_node.path, file_node.fullname, scope=scope)
self.cur_mod_node = file_node
self.cur_mod_id = file_node.fullname
scope.enter_file(self.cur_mod_id)
self._is_stub_file = file_node.path.lower().endswith('.pyi')
self._is_typeshed_stub_file = is_typeshed_file(file_node.path)
self.globals = file_node.names
self.tvar_scope = TypeVarLikeScope()

self.named_tuple_analyzer = NamedTupleAnalyzer(options, self)
self.typed_dict_analyzer = TypedDictAnalyzer(options, self, self.msg)
self.enum_call_analyzer = EnumCallAnalyzer(options, self)
self.newtype_analyzer = NewTypeAnalyzer(options, self, self.msg)

# Counter that keeps track of references to undefined things potentially caused by
# incomplete namespaces.
self.num_incomplete_refs = 0

if active_type:
self.incomplete_type_stack.append(False)
scope.enter_class(active_type)
self.enter_class(active_type.defn.info)
for tvar in active_type.defn.type_vars:
self.tvar_scope.bind_existing(tvar)

yield

if active_type:
scope.leave()
self.leave_class()
self.type = None
self.incomplete_type_stack.pop()
scope.leave()
with scope.module_scope(self.cur_mod_id):
self._is_stub_file = file_node.path.lower().endswith('.pyi')
self._is_typeshed_stub_file = is_typeshed_file(file_node.path)
self.globals = file_node.names
self.tvar_scope = TypeVarLikeScope()

self.named_tuple_analyzer = NamedTupleAnalyzer(options, self)
self.typed_dict_analyzer = TypedDictAnalyzer(options, self, self.msg)
self.enum_call_analyzer = EnumCallAnalyzer(options, self)
self.newtype_analyzer = NewTypeAnalyzer(options, self, self.msg)

# Counter that keeps track of references to undefined things potentially caused by
# incomplete namespaces.
self.num_incomplete_refs = 0

if active_type:
self.incomplete_type_stack.append(False)
scope.enter_class(active_type)
self.enter_class(active_type.defn.info)
for tvar in active_type.defn.type_vars:
self.tvar_scope.bind_existing(tvar)

yield

if active_type:
scope.leave_class()
self.leave_class()
self.type = None
self.incomplete_type_stack.pop()
del self.options

#
Expand Down
5 changes: 2 additions & 3 deletions mypy/semanal_typeargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ def __init__(self, errors: Errors, options: Options, is_typeshed_file: bool) ->

def visit_mypy_file(self, o: MypyFile) -> None:
self.errors.set_file(o.path, o.fullname, scope=self.scope)
self.scope.enter_file(o.fullname)
super().visit_mypy_file(o)
self.scope.leave()
with self.scope.module_scope(o.fullname):
super().visit_mypy_file(o)

def visit_func(self, defn: FuncItem) -> None:
if not self.recurse_into_functions:
Expand Down
Loading