Skip to content

Commit 5f1c4d7

Browse files
committed
Implement foundation for detecting partially defined vars
1 parent dfbaff7 commit 5f1c4d7

File tree

6 files changed

+319
-0
lines changed

6 files changed

+319
-0
lines changed

mypy/build.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,11 @@
4747
from mypy.checker import TypeChecker
4848
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
4949
from mypy.indirection import TypeIndirectionVisitor
50+
from mypy.messages import MessageBuilder
5051
from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable
5152
from mypy.semanal import SemanticAnalyzer
5253
from mypy.semanal_pass1 import SemanticAnalyzerPreAnalysis
54+
from mypy.undefined_vars import UndefinedVariableVisitor
5355
from mypy.util import (
5456
DecodeError,
5557
decode_python_encoding,
@@ -2340,6 +2342,11 @@ def finish_passes(self) -> None:
23402342
manager = self.manager
23412343
if self.options.semantic_analysis_only:
23422344
return
2345+
if manager.options.disallow_undefined_vars:
2346+
manager.errors.set_file(self.xpath, self.tree.fullname, options=manager.options)
2347+
self.tree.accept(
2348+
UndefinedVariableVisitor(MessageBuilder(manager.errors, manager.modules))
2349+
)
23432350
t0 = time_ref()
23442351
with self.wrap_context():
23452352
# Some tests (and tools) want to look at the set of all types.

mypy/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,8 @@ def add_invertible_flag(
823823
help="Make arguments prepended via Concatenate be truly positional-only",
824824
group=strictness_group,
825825
)
826+
# Experimental flag to detect uses of variables that may be undefined.
827+
add_invertible_flag("--disallow-undefined-vars", default=False, help=argparse.SUPPRESS)
826828

827829
strict_help = "Strict mode; enables the following flags: {}".format(
828830
", ".join(strict_flag_names)

mypy/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,9 @@ def invalid_keyword_var_arg(self, typ: Type, is_mapping: bool, context: Context)
12161216
def undefined_in_superclass(self, member: str, context: Context) -> None:
12171217
self.fail(f'"{member}" undefined in superclass', context)
12181218

1219+
def variable_may_be_undefined(self, name: str, context: Context) -> None:
    """Report that `name` may be undefined, passing `context` through to `self.fail`."""
    message = f'Name "{name}" may be undefined'
    self.fail(message, context)
1221+
12191222
def first_argument_for_super_must_be_type(self, actual: Type, context: Context) -> None:
12201223
actual = get_proper_type(actual)
12211224
if isinstance(actual, Instance):

mypy/options.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ def __init__(self) -> None:
180180
# Make arguments prepended via Concatenate be truly positional-only.
181181
self.strict_concatenate = False
182182

183+
# Disallow using vars that could be undefined.
184+
self.disallow_undefined_vars = False
185+
183186
# Report an error for any branches inferred to be unreachable as a result of
184187
# type analysis.
185188
self.warn_unreachable = False

mypy/undefined_vars.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from __future__ import annotations
2+
3+
from typing import NamedTuple
4+
5+
from mypy.messages import MessageBuilder
6+
from mypy.nodes import (
7+
AssignmentStmt,
8+
FuncDef,
9+
FuncItem,
10+
IfStmt,
11+
ListExpr,
12+
Lvalue,
13+
NameExpr,
14+
TupleExpr,
15+
WhileStmt,
16+
)
17+
from mypy.traverser import TraverserVisitor
18+
19+
20+
class DefinedVariables(NamedTuple):
    """Summary of variable definitions once a branching statement has finished.

    A branching statement is one with alternative execution paths, for example
    `if` or `match`.

    Fields:
      * `may_be_defined`: names that were defined in only some of the branches.
      * `must_be_defined`: names that were defined in every branch.
    """

    # Names defined in some, but not all, branches.
    may_be_defined: set[str]
    # Names defined in all branches.
    must_be_defined: set[str]
31+
32+
class BranchingTracker:
    """Collect variable definitions for each branch of one branching statement.

    `defined_by_branch` holds one `DefinedVariables` entry per branch seen so far;
    the last entry always describes the branch currently being visited.
    """

    def __init__(self) -> None:
        initial = DefinedVariables(may_be_defined=set(), must_be_defined=set())
        self.defined_by_branch: list[DefinedVariables] = [initial]

    def next_branch(self) -> None:
        """Start a new branch in which nothing has been defined yet."""
        fresh = DefinedVariables(may_be_defined=set(), must_be_defined=set())
        self.defined_by_branch.append(fresh)

    def record_definition(self, name: str) -> None:
        """Mark `name` as definitely defined in the current branch."""
        assert self.defined_by_branch
        current = self.defined_by_branch[-1]
        current.must_be_defined.add(name)
        # A definite definition supersedes any earlier "maybe" status.
        current.may_be_defined.discard(name)

    def record_nested_branch(self, vars: DefinedVariables) -> None:
        """Merge the outcome of a finished nested branching statement into the current branch."""
        assert self.defined_by_branch
        current = self.defined_by_branch[-1]
        # Bind the sets to locals and mutate them in place: the tuple fields
        # themselves cannot be reassigned.
        must = current.must_be_defined
        may = current.may_be_defined
        must |= vars.must_be_defined
        may |= vars.may_be_defined
        # Anything that is definitely defined is no longer merely "maybe".
        may -= must

    def is_possibly_undefined(self, name: str) -> bool:
        """Return True if `name` is in the current branch's `may_be_defined` set."""
        assert self.defined_by_branch
        current = self.defined_by_branch[-1]
        return name in current.may_be_defined

    def done(self) -> DefinedVariables:
        """Combine all recorded branches into a single `DefinedVariables` result."""
        branches = self.defined_by_branch
        assert branches
        if len(branches) == 1:
            # Only one branch was recorded, so its state is the result as is.
            # This differs from an omitted branch such as an `if` without `else`,
            # where an extra (empty) branch is still recorded via next_branch().
            return branches[0]

        first, *rest = branches
        # A name must be defined only if every branch defines it, i.e. the
        # intersection of the per-branch must_be_defined sets.
        must = set(first.must_be_defined).intersection(
            *(branch.must_be_defined for branch in rest)
        )
        # Every other name mentioned by any branch is only possibly defined.
        seen: set[str] = set()
        for branch in branches:
            seen |= branch.may_be_defined
            seen |= branch.must_be_defined
        return DefinedVariables(may_be_defined=seen - must, must_be_defined=must)
77+
78+
79+
class DefinedVariableTracker:
    """Manage the definition state and scopes for UndefinedVariableVisitor.

    `scopes` is a stack of scopes. Each scope is itself a stack of
    BranchingTracker objects: the first entry covers the scope's top level and
    every further entry corresponds to a branching statement that is currently
    being visited, innermost last.
    """

    def __init__(self) -> None:
        # todo(stas): we should initialize this with some variables.
        # There's always at least one scope. Within each scope, there's at least
        # one "global" BranchingTracker.
        self.scopes: list[list[BranchingTracker]] = [[BranchingTracker()]]

    def _scope(self) -> list[BranchingTracker]:
        """Return the tracker stack of the innermost scope."""
        assert len(self.scopes) > 0
        return self.scopes[-1]

    def enter_scope(self) -> None:
        """Push a new scope that starts with a single, empty tracker."""
        self.scopes.append([BranchingTracker()])

    def exit_scope(self) -> None:
        """Discard the innermost scope together with everything recorded in it."""
        self.scopes.pop()

    def start_branch_statement(self) -> None:
        """Begin tracking a new branching statement in the current scope."""
        self._scope().append(BranchingTracker())

    def next_branch(self) -> None:
        """Move on to the next branch of the innermost branching statement."""
        assert len(self._scope()) > 1
        self._scope()[-1].next_branch()

    def end_branch_statement(self) -> None:
        """Finish the innermost branching statement and merge its result into its parent."""
        assert len(self._scope()) > 1
        result = self._scope().pop().done()
        self._scope()[-1].record_nested_branch(result)

    def record_declaration(self, name: str) -> None:
        """Record `name` as defined in the branch currently being visited."""
        assert len(self.scopes) > 0
        assert len(self.scopes[-1]) > 0
        self._scope()[-1].record_definition(name)

    def is_possibly_undefined(self, name: str) -> bool:
        """Return True if `name` is defined on only some paths reaching the current point.

        Every tracker on the current scope's stack describes a branch that
        execution is currently inside of, so all of them are consulted:

        * if any enclosing branch definitely defines `name`, it is defined here,
          even when a nested branching statement only defined it on some paths;
        * otherwise `name` is possibly undefined if any enclosing branch has it
          in its `may_be_defined` set.

        Names that appear in neither set are not reported here; cases where a
        variable is not defined altogether are handled by the semantic analyzer.
        """
        assert len(self._scope()) > 0
        possibly_undefined = False
        for tracker in self._scope():
            current_branch = tracker.defined_by_branch[-1]
            if name in current_branch.must_be_defined:
                return False
            if name in current_branch.may_be_defined:
                possibly_undefined = True
        return possibly_undefined
119+
120+
121+
class UndefinedVariableVisitor(TraverserVisitor):
    """Detect variables that are defined only part of the time.

    This visitor detects the following case:
    if foo():
        x = 1
    print(x)  # Error: Name "x" may be undefined

    Note that this code does not detect variables not defined in any of the branches -- that is
    handled by the semantic analyzer.
    """

    def __init__(self, msg: MessageBuilder) -> None:
        self.msg = msg
        self.tracker = DefinedVariableTracker()

    def process_lvalue(self, lvalue: Lvalue) -> None:
        """Record every plain name bound by `lvalue`, descending into list/tuple targets."""
        if isinstance(lvalue, NameExpr):
            self.tracker.record_declaration(lvalue.name)
        elif isinstance(lvalue, (ListExpr, TupleExpr)):
            for item in lvalue.items:
                self.process_lvalue(item)

    def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
        """Check the right-hand side, then record and visit the assignment targets."""
        # The right-hand side is evaluated before any target is bound, so it must
        # be checked against the state *before* this assignment. Otherwise
        # `x = x + 1` with a partially defined `x` would never be reported.
        o.rvalue.accept(self)
        for lvalue in o.lvalues:
            self.process_lvalue(lvalue)
        # Targets are visited only after being recorded so that the names being
        # assigned are not themselves reported as possibly undefined.
        # NOTE(review): this replaces the call to super().visit_assignment_stmt(o)
        # and assumes the base traversal only visits the rvalue and the lvalues
        # -- confirm against TraverserVisitor.
        for lvalue in o.lvalues:
            lvalue.accept(self)

    def visit_if_stmt(self, o: IfStmt) -> None:
        """Visit all conditions, then treat each body and the `else` as separate branches."""
        for e in o.expr:
            e.accept(self)
        self.tracker.start_branch_statement()
        for b in o.body:
            b.accept(self)
            self.tracker.next_branch()
        # When there is no `else`, the branch opened by the last next_branch()
        # stays empty, so names defined in the bodies end up only "maybe" defined.
        if o.else_body:
            o.else_body.accept(self)
        self.tracker.end_branch_statement()

    def visit_func_def(self, o: FuncDef) -> None:
        """Visit a function definition inside a fresh tracking scope."""
        self.tracker.enter_scope()
        super().visit_func_def(o)
        self.tracker.exit_scope()

    def visit_func(self, o: FuncItem) -> None:
        """Record the function's arguments as defined before visiting the function."""
        if o.arguments is not None:
            for arg in o.arguments:
                self.tracker.record_declaration(arg.variable.name)
        super().visit_func(o)

    def visit_while_stmt(self, o: WhileStmt) -> None:
        """Visit the condition, then treat the loop body and `else` as two branches."""
        o.expr.accept(self)
        self.tracker.start_branch_statement()
        o.body.accept(self)
        self.tracker.next_branch()
        # As with `if`, a missing `else` leaves the second branch empty.
        if o.else_body:
            o.else_body.accept(self)
        self.tracker.end_branch_statement()

    def visit_name_expr(self, o: NameExpr) -> None:
        """Report `o` if the tracker says its name may be undefined at this point."""
        if self.tracker.is_possibly_undefined(o.name):
            self.msg.variable_may_be_undefined(o.name, o)
        super().visit_name_expr(o)
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
[case testDefinedInOneBranch]
2+
# flags: --disallow-undefined-vars
3+
if int():
4+
a = 1
5+
else:
6+
x = 2
7+
z = a + 1 # E: Name "a" may be undefined
8+
[case testElif]
9+
# flags: --disallow-undefined-vars
10+
if int():
11+
a = 1
12+
elif int():
13+
a = 2
14+
else:
15+
x = 3
16+
17+
z = a + 1 # E: Name "a" may be undefined
18+
19+
[case testDefinedInAllBranches]
20+
# flags: --disallow-undefined-vars
21+
if int():
22+
a = 1
23+
elif int():
24+
a = 2
25+
else:
26+
a = 3
27+
z = a + 1
28+
29+
[case testOmittedElse]
30+
# flags: --disallow-undefined-vars
31+
if int():
32+
a = 1
33+
z = a + 1 # E: Name "a" may be undefined
34+
35+
[case testUpdatedInIf]
36+
# flags: --disallow-undefined-vars
37+
# Variable a is already defined. Just updating it in an "if" is acceptable.
38+
a = 1
39+
if int():
40+
a = 2
41+
z = a + 1
42+
43+
[case testNestedIf]
44+
# flags: --disallow-undefined-vars
45+
if int():
46+
if int():
47+
a = 1
48+
x = 1
49+
x = x + 1
50+
else:
51+
a = 2
52+
b = a + x # E: Name "x" may be undefined
53+
b = b + 1
54+
else:
55+
b = 2
56+
z = a + b # E: Name "a" may be undefined
57+
58+
[case testVeryNestedIf]
59+
# flags: --disallow-undefined-vars
60+
if int():
61+
if int():
62+
if int():
63+
a = 1
64+
else:
65+
a = 2
66+
x = a
67+
else:
68+
a = 2
69+
b = a
70+
else:
71+
b = 2
72+
z = a + b # E: Name "a" may be undefined
73+
74+
[case testTupleUnpack]
75+
# flags: --disallow-undefined-vars
76+
77+
if int():
78+
(x, y) = (1, 2)
79+
else:
80+
[y, z] = [1, 2]
81+
a = y + x # E: Name "x" may be undefined
82+
a = y + z # E: Name "z" may be undefined
83+
84+
[case testRedefined]
85+
# flags: --disallow-undefined-vars
86+
87+
if int():
88+
y = 2
89+
y = 3
90+
x = y + 2
91+
92+
[case testScope]
93+
# flags: --disallow-undefined-vars
94+
def foo() -> None:
95+
if int():
96+
y = 2
97+
98+
if int():
99+
y = 3
100+
x = y # E: Name "y" may be undefined
101+
102+
[case testFuncParams]
103+
# flags: --disallow-undefined-vars
104+
def foo(a: int) -> None:
105+
if int():
106+
a = 2
107+
x = a
108+
109+
[case testWhile]
110+
# flags: --disallow-undefined-vars
111+
while int():
112+
x = 1
113+
114+
y = x # E: Name "x" may be undefined
115+
116+
while int():
117+
z = 1
118+
else:
119+
z = 2
120+
121+
y = z # No error.

0 commit comments

Comments
 (0)