Constant fold initializers of final variables #14283

Merged 20 commits on Dec 15, 2022

116 changes: 116 additions & 0 deletions mypy/constant_fold.py
@@ -0,0 +1,116 @@
"""Constant folding of expressions.

For example, 3 + 5 can be constant folded into 8.
"""

from __future__ import annotations

from typing import Union
from typing_extensions import Final

from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var

# All possible result types of constant folding
ConstantValue = Union[int, bool, float, str]
CONST_TYPES: Final = (int, bool, float, str)


def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None:
    """Return the constant value of an expression for supported operations.

    Among other things, support int arithmetic and string
    concatenation. For example, the expression 3 + 5 has the constant
    value 8.

    Also bind simple references to final constants defined in the
    current module (cur_mod_id). Binding to references is best effort
    -- we don't bind references to other modules. Mypyc trusts these
    to be correct in compiled modules, so that it can replace a
    constant expression (or a reference to one) with the statically
    computed value. We don't want to infer constant values based on
    stubs, in particular, as these might not match the implementation
    (due to version skew, for example).

    Return None if unsuccessful.
    """
    if isinstance(expr, IntExpr):
        return expr.value
    if isinstance(expr, StrExpr):
        return expr.value
    if isinstance(expr, FloatExpr):
        return expr.value
    elif isinstance(expr, NameExpr):
        if expr.name == "True":
            return True
        elif expr.name == "False":
            return False
Comment on lines +36 to +46
Member

@AlexWaygood, Dec 11, 2022

Interesting, this is very similar logic to the code @JelleZijlstra added in evalexpr.py in a9c62c5. I wonder if that logic could be reused here?

(Apologies if this comment makes no sense; I don't know much about mypyc internals!)

Collaborator Author

Yeah, there is definitely some overlap. Right now the benefits of sharing code don't seem big enough to me to make it worthwhile to increase coupling, but if we add support for more things, it may make sense to refactor and share parts of the implementations. Note that mypy and mypyc already share some of the constant folding code.

Member

Makes sense!

Collaborator

Should this be a visitor?

Collaborator Author

I'm not using a visitor, since only relatively few AST node types are handled here, and others have shared default behavior. I generally use a visitor when I need to implement logic for many node types. However, I think that in this case both using and not using a visitor would have been reasonable.
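
To make the trade-off in this thread concrete, here is a minimal sketch of both styles on a toy AST. The node classes (`Num`, `Neg`, `Call`) and both folders are hypothetical stand-ins, not mypy's `Expression` classes or its visitor API. With only a couple of handled node types, the `isinstance` chain stays short and all other nodes fall through to one shared default.

```python
from __future__ import annotations

from dataclasses import dataclass


# Toy AST nodes; illustrative stand-ins, not mypy's node classes.
@dataclass
class Num:
    value: int


@dataclass
class Neg:
    operand: object


@dataclass
class Call:
    name: str


def fold_with_isinstance(node: object) -> int | None:
    """Handle the few supported node types; everything else shares one default."""
    if isinstance(node, Num):
        return node.value
    if isinstance(node, Neg):
        inner = fold_with_isinstance(node.operand)
        return -inner if inner is not None else None
    return None  # shared default for all unsupported node types


class FoldVisitor:
    """Visitor variant: one method per node type, dispatched by class name."""

    def visit(self, node: object) -> int | None:
        method = getattr(self, "visit_" + type(node).__name__, self.generic_visit)
        return method(node)

    def generic_visit(self, node: object) -> int | None:
        return None  # default for node types without a visit_* method

    def visit_Num(self, node: Num) -> int | None:
        return node.value

    def visit_Neg(self, node: Neg) -> int | None:
        inner = self.visit(node.operand)
        return -inner if inner is not None else None
```

Both variants give the same results here; the visitor mostly pays off once many node types need their own logic.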

        node = expr.node
        if (
            isinstance(node, Var)
            and node.is_final
            and node.fullname.rsplit(".", 1)[0] == cur_mod_id
        ):
            value = node.final_value
            if isinstance(value, (CONST_TYPES)):
                return value
    elif isinstance(expr, OpExpr):
        left = constant_fold_expr(expr.left, cur_mod_id)
        right = constant_fold_expr(expr.right, cur_mod_id)
        if isinstance(left, int) and isinstance(right, int):
            return constant_fold_binary_int_op(expr.op, left, right)
        elif isinstance(left, str) and isinstance(right, str):
            return constant_fold_binary_str_op(expr.op, left, right)
    elif isinstance(expr, UnaryExpr):
        value = constant_fold_expr(expr.expr, cur_mod_id)
        if isinstance(value, int):
            return constant_fold_unary_int_op(expr.op, value)
    return None


def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
    if op == "+":
        return left + right
    if op == "-":
        return left - right
    elif op == "*":
        return left * right
    elif op == "//":
        if right != 0:
            return left // right
    elif op == "%":
        if right != 0:
            return left % right
    elif op == "&":
        return left & right
    elif op == "|":
        return left | right
    elif op == "^":
        return left ^ right
    elif op == "<<":
        if right >= 0:
            return left << right
    elif op == ">>":
        if right >= 0:
            return left >> right
    elif op == "**":
        if right >= 0:
            ret = left**right
            assert isinstance(ret, int)
            return ret
    return None


def constant_fold_unary_int_op(op: str, value: int) -> int | None:
    if op == "-":
        return -value
    elif op == "~":
        return ~value
    elif op == "+":
        return value
    return None


def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None:
    if op == "+":
        return left + right
    return None
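
The folding rules in this file can be tried out without mypy installed. The sketch below is an approximation that uses Python's standard `ast` module in place of mypy's `IntExpr`/`OpExpr` nodes; `fold`, `fold_binary_int`, and `fold_source` are hypothetical names, only a subset of operators is covered, and references to final names are not handled. It keeps the same guards as the diff: no division or modulo by zero, no negative shift counts, no negative exponents, and nothing but `+` on strings.

```python
from __future__ import annotations

import ast
from typing import Union

ConstantValue = Union[int, bool, float, str]


def fold_binary_int(op: ast.operator, left: int, right: int) -> int | None:
    """Fold an int operation, refusing cases that would raise or leave int."""
    if isinstance(op, ast.Add):
        return left + right
    if isinstance(op, ast.Sub):
        return left - right
    if isinstance(op, ast.Mult):
        return left * right
    if isinstance(op, ast.FloorDiv) and right != 0:
        return left // right
    if isinstance(op, ast.Mod) and right != 0:
        return left % right
    if isinstance(op, ast.LShift) and right >= 0:
        return left << right
    if isinstance(op, ast.Pow) and right >= 0:
        return left**right
    return None


def fold(node: ast.expr) -> ConstantValue | None:
    """Return the constant value of an expression, or None if unsupported."""
    if isinstance(node, ast.Constant):
        # Accept int, bool (a subclass of int), float and str literals only.
        return node.value if isinstance(node.value, (int, float, str)) else None
    if isinstance(node, ast.BinOp):
        left, right = fold(node.left), fold(node.right)
        if isinstance(left, int) and isinstance(right, int):
            return fold_binary_int(node.op, left, right)
        if isinstance(left, str) and isinstance(right, str) and isinstance(node.op, ast.Add):
            return left + right
        return None
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        value = fold(node.operand)
        return -value if isinstance(value, int) else None
    return None


def fold_source(source: str) -> ConstantValue | None:
    """Parse a single expression and try to fold it."""
    return fold(ast.parse(source, mode="eval").body)
```

For instance, `fold_source("3 + 5")` folds to `8`, while `fold_source("1 // 0")` and `fold_source("1 + int()")` return `None`, so the expression would be left for runtime evaluation.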
73 changes: 33 additions & 40 deletions mypy/semanal.py
@@ -55,6 +55,7 @@
from typing_extensions import Final, TypeAlias as _TypeAlias

from mypy import errorcodes as codes, message_registry
from mypy.constant_fold import constant_fold_expr
from mypy.errorcodes import ErrorCode
from mypy.errors import Errors, report_internal_error
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
@@ -91,7 +92,6 @@
    AwaitExpr,
    Block,
    BreakStmt,
    BytesExpr,
    CallExpr,
    CastExpr,
    ClassDef,
@@ -108,7 +108,6 @@
    Expression,
    ExpressionStmt,
    FakeExpression,
    FloatExpr,
    ForStmt,
    FuncBase,
    FuncDef,
@@ -121,7 +120,6 @@
    ImportBase,
    ImportFrom,
    IndexExpr,
    IntExpr,
    LambdaExpr,
    ListComprehension,
    ListExpr,
@@ -250,7 +248,6 @@
    FunctionLike,
    Instance,
    LiteralType,
    LiteralValue,
    NoneType,
    Overloaded,
    Parameters,
@@ -3138,7 +3135,8 @@ def store_final_status(self, s: AssignmentStmt) -> None:
                node = s.lvalues[0].node
                if isinstance(node, Var):
                    node.is_final = True
                    node.final_value = self.unbox_literal(s.rvalue)
                    if s.type:
                        node.final_value = constant_fold_expr(s.rvalue, self.cur_mod_id)
                    if self.is_class_scope() and (
                        isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs
                    ):
@@ -3198,13 +3196,6 @@ def flatten_lvalues(self, lvalues: list[Expression]) -> list[Expression]:
                res.append(lv)
        return res

    def unbox_literal(self, e: Expression) -> int | float | bool | str | None:
        if isinstance(e, (IntExpr, FloatExpr, StrExpr)):
            return e.value
        elif isinstance(e, NameExpr) and e.name in ("True", "False"):
            return True if e.name == "True" else False
        return None

    def process_type_annotation(self, s: AssignmentStmt) -> None:
        """Analyze type annotation or infer simple literal type."""
        if s.type:
@@ -3259,39 +3250,33 @@ def is_annotated_protocol_member(self, s: AssignmentStmt) -> bool:

    def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Type | None:
        """Return builtins.int if rvalue is an int literal, etc.
        If this is a 'Final' context, we return "Literal[...]" instead."""
        if self.options.semantic_analysis_only or self.function_stack:
            # Skip this if we're only doing the semantic analysis pass.
            # This is mostly to avoid breaking unit tests.
            # Also skip inside a function; this is to avoid confusing

        If this is a 'Final' context, we return "Literal[...]" instead.
        """
        if self.function_stack:
            # Skip inside a function; this is to avoid confusing
            # the code that handles dead code due to isinstance()
            # inside type variables with value restrictions (like
            # AnyStr).
            return None
        if isinstance(rvalue, FloatExpr):
            return self.named_type_or_none("builtins.float")

        value: LiteralValue | None = None
        type_name: str | None = None
        if isinstance(rvalue, IntExpr):
            value, type_name = rvalue.value, "builtins.int"
        if isinstance(rvalue, StrExpr):
            value, type_name = rvalue.value, "builtins.str"
        if isinstance(rvalue, BytesExpr):
            value, type_name = rvalue.value, "builtins.bytes"

        if type_name is not None:
            assert value is not None
            typ = self.named_type_or_none(type_name)
            if typ and is_final:
                return typ.copy_modified(
                    last_known_value=LiteralType(
                        value=value, fallback=typ, line=typ.line, column=typ.column
                    )
                )
            return typ

        return None
        value = constant_fold_expr(rvalue, self.cur_mod_id)
        if value is None:
            return None

        if isinstance(value, bool):
            type_name = "builtins.bool"
        elif isinstance(value, int):
            type_name = "builtins.int"
        elif isinstance(value, str):
            type_name = "builtins.str"
        elif isinstance(value, float):
            type_name = "builtins.float"

        typ = self.named_type_or_none(type_name)
        if typ and is_final:
            return typ.copy_modified(last_known_value=LiteralType(value=value, fallback=typ))
        return typ

    def analyze_alias(
        self, name: str, rvalue: Expression, allow_placeholder: bool = False
@@ -3827,6 +3812,14 @@ def store_declared_types(self, lvalue: Lvalue, typ: Type) -> None:
                var = lvalue.node
                var.type = typ
                var.is_ready = True
                typ = get_proper_type(typ)
                if (
                    var.is_final
                    and isinstance(typ, Instance)
                    and typ.last_known_value
                    and (not self.type or not self.type.is_enum)
                ):
                    var.final_value = typ.last_known_value.value
            # If node is not a variable, we'll catch it elsewhere.
        elif isinstance(lvalue, TupleExpr):
            typ = get_proper_type(typ)
5 changes: 4 additions & 1 deletion mypy/types.py
@@ -67,7 +67,10 @@
# Note: Although "Literal[None]" is a valid type, we internally always convert
# such a type directly into "None". So, "None" is not a valid parameter of
# LiteralType and is omitted from this list.
LiteralValue: _TypeAlias = Union[int, str, bool]
#
# Note: Float values are only used internally. They are not accepted within
# Literal[...].
LiteralValue: _TypeAlias = Union[int, str, bool, float]


# If we only import type_visitor in the middle of the file, mypy
59 changes: 10 additions & 49 deletions mypyc/irbuild/constant_fold.py
@@ -1,13 +1,23 @@
"""Constant folding of IR values.

For example, 3 + 5 can be constant folded into 8.

This is mostly like mypy.constant_fold, but we can bind some additional
NameExpr and MemberExpr references here, since we have more knowledge
about which definitions can be trusted -- we constant fold only references
to other compiled modules in the same compilation unit.
"""

from __future__ import annotations

from typing import Union
from typing_extensions import Final

from mypy.constant_fold import (
    constant_fold_binary_int_op,
    constant_fold_binary_str_op,
    constant_fold_unary_int_op,
)
from mypy.nodes import Expression, IntExpr, MemberExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var
from mypyc.irbuild.builder import IRBuilder

@@ -51,52 +61,3 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue |
        if isinstance(value, int):
            return constant_fold_unary_int_op(expr.op, value)
    return None


def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
    if op == "+":
        return left + right
    if op == "-":
        return left - right
    elif op == "*":
        return left * right
    elif op == "//":
        if right != 0:
            return left // right
    elif op == "%":
        if right != 0:
            return left % right
    elif op == "&":
        return left & right
    elif op == "|":
        return left | right
    elif op == "^":
        return left ^ right
    elif op == "<<":
        if right >= 0:
            return left << right
    elif op == ">>":
        if right >= 0:
            return left >> right
    elif op == "**":
        if right >= 0:
            ret = left**right
            assert isinstance(ret, int)
            return ret
    return None


def constant_fold_unary_int_op(op: str, value: int) -> int | None:
    if op == "-":
        return -value
    elif op == "~":
        return ~value
    elif op == "+":
        return value
    return None


def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None:
    if op == "+":
        return left + right
    return None
2 changes: 1 addition & 1 deletion mypyc/test-data/irbuild-basic.test
@@ -3273,7 +3273,7 @@ L2:
[case testFinalStaticInt]
from typing import Final

x: Final = 1 + 1
x: Final = 1 + int()
Collaborator Author

The intent of changes like these is to preserve the old test case behavior where constant folding wasn't performed.
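
A rough way to see why `1 + int()` keeps the old, unfolded behavior: the folder only understands literals, operators, and certain final names, so a call expression makes it give up and return `None`. The sketch below uses Python's standard `ast` module as a stand-in for mypy's nodes; `is_literal_only` is a hypothetical helper and only approximates the real check (it ignores references to final names, for example).

```python
import ast


def is_literal_only(source: str) -> bool:
    """Return True if the expression is built only from literals and operators."""
    allowed = (
        ast.Expression,
        ast.Constant,
        ast.BinOp,
        ast.UnaryOp,
        ast.operator,  # base class of Add, Sub, Mult, ...
        ast.unaryop,  # base class of USub, Invert, ...
    )
    tree = ast.parse(source, mode="eval")
    # ast.walk visits every node, including the Call and Name nodes of `int()`.
    return all(isinstance(node, allowed) for node in ast.walk(tree))
```

Under this approximation, `is_literal_only("1 + 1")` is true and `is_literal_only("1 + int()")` is false, which matches the intent of rewriting the test input.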


def f() -> int:
    return x - 1
14 changes: 2 additions & 12 deletions mypyc/test-data/irbuild-constant-fold.test
@@ -205,23 +205,13 @@ Y: Final = 2 + 4

def f() -> None:
    a = X + 1
    # TODO: Constant fold this as well
    a = Y + 1
[out]
def f():
    a, r0 :: int
    r1 :: bool
    r2 :: int
    a :: int
L0:
    a = 12
    r0 = __main__.Y :: static
    if is_error(r0) goto L1 else goto L2
L1:
    r1 = raise NameError('value for final name "Y" was not set')
    unreachable
L2:
    r2 = CPyTagged_Add(r0, 2)
    a = r2
    a = 14
Collaborator Author

This illustrates how much we can simplify the generated code. The simpler code is also significantly faster, since there is no memory read any more.
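
A note on reading the integer literals in this IR: they do not look like the Python-level values. Assuming mypyc's tagged representation for short integers, where a value `n` is stored as `n` shifted left by one bit, the literal `14` stands for the Python value 7 (`Y + 1` with `Y = 2 + 4`), and the removed `CPyTagged_Add(r0, 2)` added the tagged form of 1. The helpers below are hypothetical and only illustrate that assumed encoding.

```python
def to_tagged(value: int) -> int:
    """Encode a short int under the assumed tagging scheme (value * 2, low bit clear)."""
    return value << 1


def from_tagged(tagged: int) -> int:
    """Decode a tagged short int back to its Python-level value."""
    return tagged >> 1


Y = 2 + 4
folded = to_tagged(Y + 1)  # the constant the compiler can now emit directly
```

Under this assumption, the folded assignment needs no load of `__main__.Y` and no error check, only a constant store.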

    return 1

[case testIntConstantFoldingClassFinal]