Skip to content

Implement miscellaneous fixes for partially-defined check #14175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2349,6 +2349,9 @@ def type_check_second_pass(self) -> bool:

def detect_partially_defined_vars(self, type_map: dict[Expression, Type]) -> None:
assert self.tree is not None, "Internal error: method must be called on parsed file only"
if self.tree.is_stub:
# We skip stub files because they aren't actually executed.
return
manager = self.manager
if manager.errors.is_error_code_enabled(
codes.PARTIALLY_DEFINED
Expand Down
43 changes: 43 additions & 0 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
FuncItem,
GeneratorExpr,
IfStmt,
Import,
ImportFrom,
LambdaExpr,
ListExpr,
Lvalue,
MatchStmt,
NameExpr,
RaiseStmt,
RefExpr,
ReturnStmt,
StarExpr,
TupleExpr,
WhileStmt,
WithStmt,
implicit_module_attrs,
)
from mypy.patterns import AsPattern, StarredPattern
from mypy.reachability import ALWAYS_TRUE, infer_pattern_value
Expand Down Expand Up @@ -213,6 +219,10 @@ def is_undefined(self, name: str) -> bool:
return self._scope().branch_stmts[-1].is_undefined(name)


def refers_to_builtin(o: RefExpr) -> bool:
return o.fullname is not None and o.fullname.startswith("builtins.")


class PartiallyDefinedVariableVisitor(ExtendedTraverserVisitor):
"""Detects the following cases:
- A variable that's defined only part of the time.
Expand All @@ -236,6 +246,8 @@ def __init__(self, msg: MessageBuilder, type_map: dict[Expression, Type]) -> Non
self.type_map = type_map
self.loop_depth = 0
self.tracker = DefinedVariableTracker()
for name in implicit_module_attrs:
self.tracker.record_definition(name)

def process_lvalue(self, lvalue: Lvalue | None) -> None:
if isinstance(lvalue, NameExpr):
Expand All @@ -244,6 +256,8 @@ def process_lvalue(self, lvalue: Lvalue | None) -> None:
for ref in refs:
self.msg.var_used_before_def(lvalue.name, ref)
self.tracker.record_definition(lvalue.name)
elif isinstance(lvalue, StarExpr):
self.process_lvalue(lvalue.expr)
elif isinstance(lvalue, (ListExpr, TupleExpr)):
for item in lvalue.items:
self.process_lvalue(item)
Expand Down Expand Up @@ -291,6 +305,7 @@ def visit_match_stmt(self, o: MatchStmt) -> None:
self.tracker.end_branch_statement()

def visit_func_def(self, o: FuncDef) -> None:
self.tracker.record_definition(o.name)
self.tracker.enter_scope()
super().visit_func_def(o)
self.tracker.exit_scope()
Expand Down Expand Up @@ -332,6 +347,11 @@ def visit_return_stmt(self, o: ReturnStmt) -> None:
super().visit_return_stmt(o)
self.tracker.skip_branch()

def visit_lambda_expr(self, o: LambdaExpr) -> None:
self.tracker.enter_scope()
super().visit_lambda_expr(o)
self.tracker.exit_scope()

def visit_assert_stmt(self, o: AssertStmt) -> None:
super().visit_assert_stmt(o)
if checker.is_false_literal(o.expr):
Expand Down Expand Up @@ -377,6 +397,8 @@ def visit_starred_pattern(self, o: StarredPattern) -> None:
super().visit_starred_pattern(o)

def visit_name_expr(self, o: NameExpr) -> None:
if refers_to_builtin(o):
return
if self.tracker.is_partially_defined(o.name):
# A variable is only defined in some branches.
if self.msg.errors.is_error_code_enabled(errorcodes.PARTIALLY_DEFINED):
Expand Down Expand Up @@ -404,3 +426,24 @@ def visit_with_stmt(self, o: WithStmt) -> None:
expr.accept(self)
self.process_lvalue(idx)
o.body.accept(self)

def visit_import(self, o: Import) -> None:
for mod, alias in o.ids:
if alias is not None:
self.tracker.record_definition(alias)
else:
# When you do `import x.y`, only `x` becomes defined.
names = mod.split(".")
if len(names) > 0:
# `names` should always be nonempty, but we don't want mypy
# to crash on invalid code.
self.tracker.record_definition(names[0])
super().visit_import(o)

def visit_import_from(self, o: ImportFrom) -> None:
for mod, alias in o.names:
name = alias
if name is None:
name = mod
self.tracker.record_definition(name)
super().visit_import_from(o)
121 changes: 121 additions & 0 deletions test-data/unit/check-partially-defined.test
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ else:
a = y + x # E: Name "x" may be undefined
a = y + z # E: Name "z" may be undefined

[case testIndexExpr]
# flags: --enable-error-code partially-defined

if int():
*x, y = (1, 2)
else:
x = [1, 2]
a = x # No error.
b = y # E: Name "y" may be undefined

[case testRedefined]
# flags: --enable-error-code partially-defined
y = 3
Expand All @@ -104,6 +114,32 @@ else:

x = y + 2

[case testFunction]
# flags: --enable-error-code partially-defined
def f0() -> None:
if int():
def some_func() -> None:
pass

some_func() # E: Name "some_func" may be undefined

def f1() -> None:
if int():
def some_func() -> None:
pass
else:
def some_func() -> None:
pass

some_func() # No error.

[case testLambda]
# flags: --enable-error-code partially-defined
def f0(b: bool) -> None:
if b:
fn = lambda: 2
y = fn # E: Name "fn" may be undefined

[case testGenerator]
# flags: --enable-error-code partially-defined
if int():
Expand Down Expand Up @@ -460,3 +496,88 @@ def f4() -> None:
y = z # E: Name "z" is used before definition
x = z # E: Name "z" is used before definition
z: int = 2

[case testUseBeforeDefImportsBasic]
# flags: --enable-error-code use-before-def
import foo # type: ignore
import x.y # type: ignore

def f0() -> None:
a = foo # No error.
foo: int = 1

def f1() -> None:
a = y # E: Name "y" is used before definition
y: int = 1

def f2() -> None:
a = x # No error.
x: int = 1

def f3() -> None:
a = x.y # No error.
x: int = 1

[case testUseBeforeDefImportBasicRename]
# flags: --enable-error-code use-before-def
import x.y as z # type: ignore
from typing import Any

def f0() -> None:
a = z # No error.
z: int = 1

def f1() -> None:
a = x # E: Name "x" is used before definition
x: int = 1

def f2() -> None:
a = x.y # E: Name "x" is used before definition
x: Any = 1

def f3() -> None:
a = y # E: Name "y" is used before definition
y: int = 1

[case testUseBeforeDefImportFrom]
# flags: --enable-error-code use-before-def
from foo import x # type: ignore

def f0() -> None:
a = x # No error.
x: int = 1

[case testUseBeforeDefImportFromRename]
# flags: --enable-error-code use-before-def
from foo import x as y # type: ignore

def f0() -> None:
a = y # No error.
y: int = 1

def f1() -> None:
a = x # E: Name "x" is used before definition
x: int = 1

[case testUseBeforeDefFunctionDeclarations]
# flags: --enable-error-code use-before-def

def f0() -> None:
def inner() -> None:
pass

inner() # No error.
inner = lambda: None

[case testUseBeforeDefBuiltins]
# flags: --enable-error-code use-before-def

def f0() -> None:
s = type(123)
type = "abc"
a = type

[case testUseBeforeDefImplicitModuleAttrs]
# flags: --enable-error-code use-before-def
a = __name__ # No error.
__name__ = "abc"