Skip to content

Implement foundation for detecting partially defined vars #13601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
from mypy.checker import TypeChecker
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
from mypy.indirection import TypeIndirectionVisitor
from mypy.messages import MessageBuilder
from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable
from mypy.partially_defined import PartiallyDefinedVariableVisitor
from mypy.semanal import SemanticAnalyzer
from mypy.semanal_pass1 import SemanticAnalyzerPreAnalysis
from mypy.util import (
Expand Down Expand Up @@ -2335,6 +2337,15 @@ def type_check_second_pass(self) -> bool:
self.time_spent_us += time_spent_us(t0)
return result

def detect_partially_defined_vars(self) -> None:
assert self.tree is not None, "Internal error: method must be called on parsed file only"
manager = self.manager
if manager.errors.is_error_code_enabled(codes.PARTIALLY_DEFINED):
manager.errors.set_file(self.xpath, self.tree.fullname, options=manager.options)
self.tree.accept(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this was the right place to plug this in. It needed to be done before self.free_state() is called.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that you could add another method, such as detect_partially_undefined_vars (feel free to alter the name) that performs the analysis if enabled, and call it from process_stale_scc after calling type_check_second_pass and before calling finish_passes. I.e., this would be an additional pass after type checking.

PartiallyDefinedVariableVisitor(MessageBuilder(manager.errors, manager.modules))
)

def finish_passes(self) -> None:
assert self.tree is not None, "Internal error: method must be called on parsed file only"
manager = self.manager
Expand Down Expand Up @@ -3364,6 +3375,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
graph[id].type_check_first_pass()
if not graph[id].type_checker().deferred_nodes:
unfinished_modules.discard(id)
graph[id].detect_partially_defined_vars()
graph[id].finish_passes()

while unfinished_modules:
Expand All @@ -3372,6 +3384,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
continue
if not graph[id].type_check_second_pass():
unfinished_modules.discard(id)
graph[id].detect_partially_defined_vars()
graph[id].finish_passes()
for id in stale:
graph[id].generate_unused_ignore_notes()
Expand Down
6 changes: 6 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def __str__(self) -> str:
UNREACHABLE: Final = ErrorCode(
"unreachable", "Warn about unreachable statements or expressions", "General"
)
PARTIALLY_DEFINED: Final[ErrorCode] = ErrorCode(
"partially-defined",
"Warn about variables that are defined only in some execution paths",
"General",
default_enabled=False,
)
REDUNDANT_EXPR: Final = ErrorCode(
"redundant-expr", "Warn about redundant expressions", "General", default_enabled=False
)
Expand Down
3 changes: 3 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,9 @@ def invalid_keyword_var_arg(self, typ: Type, is_mapping: bool, context: Context)
def undefined_in_superclass(self, member: str, context: Context) -> None:
self.fail(f'"{member}" undefined in superclass', context)

def variable_may_be_undefined(self, name: str, context: Context) -> None:
self.fail(f'Name "{name}" may be undefined', context, code=codes.PARTIALLY_DEFINED)

def first_argument_for_super_must_be_type(self, actual: Type, context: Context) -> None:
actual = get_proper_type(actual)
if isinstance(actual, Instance):
Expand Down
201 changes: 201 additions & 0 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from __future__ import annotations

from typing import NamedTuple

from mypy.messages import MessageBuilder
from mypy.nodes import (
AssignmentStmt,
ForStmt,
FuncDef,
FuncItem,
IfStmt,
ListExpr,
Lvalue,
NameExpr,
TupleExpr,
WhileStmt,
)
from mypy.traverser import TraverserVisitor


class DefinedVars(NamedTuple):
"""DefinedVars contains information about variable definition at the end of a branching statement.
`if` and `match` are examples of branching statements.

`may_be_defined` contains variables that were defined in only some branches.
`must_be_defined` contains variables that were defined in all branches.
"""

may_be_defined: set[str]
must_be_defined: set[str]


class BranchStatement:
def __init__(self, already_defined: DefinedVars) -> None:
self.already_defined = already_defined
self.defined_by_branch: list[DefinedVars] = [
DefinedVars(may_be_defined=set(), must_be_defined=set(already_defined.must_be_defined))
]

def next_branch(self) -> None:
self.defined_by_branch.append(
DefinedVars(
may_be_defined=set(), must_be_defined=set(self.already_defined.must_be_defined)
)
)

def record_definition(self, name: str) -> None:
assert len(self.defined_by_branch) > 0
self.defined_by_branch[-1].must_be_defined.add(name)
self.defined_by_branch[-1].may_be_defined.discard(name)

def record_nested_branch(self, vars: DefinedVars) -> None:
assert len(self.defined_by_branch) > 0
current_branch = self.defined_by_branch[-1]
current_branch.must_be_defined.update(vars.must_be_defined)
current_branch.may_be_defined.update(vars.may_be_defined)
current_branch.may_be_defined.difference_update(current_branch.must_be_defined)

def is_possibly_undefined(self, name: str) -> bool:
assert len(self.defined_by_branch) > 0
return name in self.defined_by_branch[-1].may_be_defined

def done(self) -> DefinedVars:
assert len(self.defined_by_branch) > 0
if len(self.defined_by_branch) == 1:
# If there's only one branch, then we just return current.
# Note that this case is a different case when an empty branch is omitted (e.g. `if` without `else`).
return self.defined_by_branch[0]

# must_be_defined is a union of must_be_defined of all branches.
must_be_defined = set(self.defined_by_branch[0].must_be_defined)
for branch_vars in self.defined_by_branch[1:]:
must_be_defined.intersection_update(branch_vars.must_be_defined)
# may_be_defined are all variables that are not must be defined.
all_vars = set()
for branch_vars in self.defined_by_branch:
all_vars.update(branch_vars.may_be_defined)
all_vars.update(branch_vars.must_be_defined)
may_be_defined = all_vars.difference(must_be_defined)
return DefinedVars(may_be_defined=may_be_defined, must_be_defined=must_be_defined)


class DefinedVariableTracker:
"""DefinedVariableTracker manages the state and scope for the UndefinedVariablesVisitor."""

def __init__(self) -> None:
# There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
self.scopes: list[list[BranchStatement]] = [
[BranchStatement(DefinedVars(may_be_defined=set(), must_be_defined=set()))]
]

def _scope(self) -> list[BranchStatement]:
assert len(self.scopes) > 0
return self.scopes[-1]

def enter_scope(self) -> None:
assert len(self._scope()) > 0
self.scopes.append([BranchStatement(self._scope()[-1].defined_by_branch[-1])])

def exit_scope(self) -> None:
self.scopes.pop()

def start_branch_statement(self) -> None:
assert len(self._scope()) > 0
self._scope().append(BranchStatement(self._scope()[-1].defined_by_branch[-1]))

def next_branch(self) -> None:
assert len(self._scope()) > 1
self._scope()[-1].next_branch()

def end_branch_statement(self) -> None:
assert len(self._scope()) > 1
result = self._scope().pop().done()
self._scope()[-1].record_nested_branch(result)

def record_declaration(self, name: str) -> None:
assert len(self.scopes) > 0
assert len(self.scopes[-1]) > 0
self._scope()[-1].record_definition(name)

def is_possibly_undefined(self, name: str) -> bool:
assert len(self._scope()) > 0
# A variable is undefined if it's in a set of `may_be_defined` but not in `must_be_defined`.
# Cases where a variable is not defined altogether are handled by semantic analyzer.
return self._scope()[-1].is_possibly_undefined(name)


class PartiallyDefinedVariableVisitor(TraverserVisitor):
"""Detect variables that are defined only part of the time.

This visitor detects the following case:
if foo():
x = 1
print(x) # Error: "x" may be undefined.

Note that this code does not detect variables not defined in any of the branches -- that is
handled by the semantic analyzer.
"""

def __init__(self, msg: MessageBuilder) -> None:
self.msg = msg
self.tracker = DefinedVariableTracker()

def process_lvalue(self, lvalue: Lvalue) -> None:
if isinstance(lvalue, NameExpr):
self.tracker.record_declaration(lvalue.name)
elif isinstance(lvalue, (ListExpr, TupleExpr)):
for item in lvalue.items:
self.process_lvalue(item)

def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
for lvalue in o.lvalues:
self.process_lvalue(lvalue)
super().visit_assignment_stmt(o)

def visit_if_stmt(self, o: IfStmt) -> None:
for e in o.expr:
e.accept(self)
self.tracker.start_branch_statement()
for b in o.body:
b.accept(self)
self.tracker.next_branch()
if o.else_body:
o.else_body.accept(self)
self.tracker.end_branch_statement()

def visit_func_def(self, o: FuncDef) -> None:
self.tracker.enter_scope()
super().visit_func_def(o)
self.tracker.exit_scope()

def visit_func(self, o: FuncItem) -> None:
if o.arguments is not None:
for arg in o.arguments:
self.tracker.record_declaration(arg.variable.name)
super().visit_func(o)

def visit_for_stmt(self, o: ForStmt) -> None:
o.expr.accept(self)
self.process_lvalue(o.index)
o.index.accept(self)
self.tracker.start_branch_statement()
o.body.accept(self)
self.tracker.next_branch()
if o.else_body:
o.else_body.accept(self)
self.tracker.end_branch_statement()

def visit_while_stmt(self, o: WhileStmt) -> None:
o.expr.accept(self)
self.tracker.start_branch_statement()
o.body.accept(self)
self.tracker.next_branch()
if o.else_body:
o.else_body.accept(self)
self.tracker.end_branch_statement()

def visit_name_expr(self, o: NameExpr) -> None:
if self.tracker.is_possibly_undefined(o.name):
self.msg.variable_may_be_undefined(o.name, o)
super().visit_name_expr(o)
1 change: 1 addition & 0 deletions mypy/server/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ def restore(ids: list[str]) -> None:
state.type_checker().reset()
state.type_check_first_pass()
state.type_check_second_pass()
state.detect_partially_defined_vars()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is necessary for mypyd?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah.

t2 = time.time()
state.finish_passes()
t3 = time.time()
Expand Down
Loading