Skip to content

Commit c26f129

Browse files
authored
Allow inferring +int to be a Literal (#16910)
This makes unary positive on integers preserve the literal value of the integer, allowing `var: Literal[1] = +1` to be accepted. Basically I looked for code handling `__neg__` and added a branch for `__pos__` as well. Fixes #16728.
1 parent b6e91d4 commit c26f129

File tree

6 files changed

+45
-11
lines changed

6 files changed

+45
-11
lines changed

mypy/checkexpr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4437,6 +4437,10 @@ def try_getting_int_literals(self, index: Expression) -> list[int] | None:
44374437
operand = index.expr
44384438
if isinstance(operand, IntExpr):
44394439
return [-1 * operand.value]
4440+
if index.op == "+":
4441+
operand = index.expr
4442+
if isinstance(operand, IntExpr):
4443+
return [operand.value]
44404444
typ = get_proper_type(self.accept(index))
44414445
if isinstance(typ, Instance) and typ.last_known_value is not None:
44424446
typ = typ.last_known_value

mypy/exprtotype.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,12 @@ def expr_to_unanalyzed_type(
183183
elif isinstance(expr, UnaryExpr):
184184
typ = expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax)
185185
if isinstance(typ, RawExpressionType):
186-
if isinstance(typ.literal_value, int) and expr.op == "-":
187-
typ.literal_value *= -1
188-
return typ
186+
if isinstance(typ.literal_value, int):
187+
if expr.op == "-":
188+
typ.literal_value *= -1
189+
return typ
190+
elif expr.op == "+":
191+
return typ
189192
raise TypeTranslationError()
190193
elif isinstance(expr, IntExpr):
191194
return RawExpressionType(expr.value, "builtins.int", line=expr.line, column=expr.column)

mypy/plugins/default.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
100100
return int_pow_callback
101101
elif fullname == "builtins.int.__neg__":
102102
return int_neg_callback
103+
elif fullname == "builtins.int.__pos__":
104+
return int_pos_callback
103105
elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
104106
return tuple_mul_callback
105107
elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}:
@@ -471,32 +473,43 @@ def int_pow_callback(ctx: MethodContext) -> Type:
471473
return ctx.default_return_type
472474

473475

474-
def int_neg_callback(ctx: MethodContext) -> Type:
475-
"""Infer a more precise return type for int.__neg__.
476+
def int_neg_callback(ctx: MethodContext, multiplier: int = -1) -> Type:
477+
"""Infer a more precise return type for int.__neg__ and int.__pos__.
476478
477479
This is mainly used to infer the return type as LiteralType
478-
if the original underlying object is a LiteralType object
480+
if the original underlying object is a LiteralType object.
479481
"""
480482
if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
481483
value = ctx.type.last_known_value.value
482484
fallback = ctx.type.last_known_value.fallback
483485
if isinstance(value, int):
484486
if is_literal_type_like(ctx.api.type_context[-1]):
485-
return LiteralType(value=-value, fallback=fallback)
487+
return LiteralType(value=multiplier * value, fallback=fallback)
486488
else:
487489
return ctx.type.copy_modified(
488490
last_known_value=LiteralType(
489-
value=-value, fallback=ctx.type, line=ctx.type.line, column=ctx.type.column
491+
value=multiplier * value,
492+
fallback=ctx.type,
493+
line=ctx.type.line,
494+
column=ctx.type.column,
490495
)
491496
)
492497
elif isinstance(ctx.type, LiteralType):
493498
value = ctx.type.value
494499
fallback = ctx.type.fallback
495500
if isinstance(value, int):
496-
return LiteralType(value=-value, fallback=fallback)
501+
return LiteralType(value=multiplier * value, fallback=fallback)
497502
return ctx.default_return_type
498503

499504

505+
def int_pos_callback(ctx: MethodContext) -> Type:
506+
"""Infer a more precise return type for int.__pos__.
507+
508+
This is identical to __neg__, except the value is not inverted.
509+
"""
510+
return int_neg_callback(ctx, +1)
511+
512+
500513
def tuple_mul_callback(ctx: MethodContext) -> Type:
501514
"""Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.
502515

test-data/unit/check-literal.test

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,29 +397,36 @@ from typing_extensions import Literal
397397
a1: Literal[4]
398398
b1: Literal[0x2a]
399399
c1: Literal[-300]
400+
d1: Literal[+8]
400401

401402
reveal_type(a1) # N: Revealed type is "Literal[4]"
402403
reveal_type(b1) # N: Revealed type is "Literal[42]"
403404
reveal_type(c1) # N: Revealed type is "Literal[-300]"
405+
reveal_type(d1) # N: Revealed type is "Literal[8]"
404406

405407
a2t = Literal[4]
406408
b2t = Literal[0x2a]
407409
c2t = Literal[-300]
410+
d2t = Literal[+8]
408411
a2: a2t
409412
b2: b2t
410413
c2: c2t
414+
d2: d2t
411415

412416
reveal_type(a2) # N: Revealed type is "Literal[4]"
413417
reveal_type(b2) # N: Revealed type is "Literal[42]"
414418
reveal_type(c2) # N: Revealed type is "Literal[-300]"
419+
reveal_type(d2) # N: Revealed type is "Literal[8]"
415420

416421
def f1(x: Literal[4]) -> Literal[4]: pass
417422
def f2(x: Literal[0x2a]) -> Literal[0x2a]: pass
418423
def f3(x: Literal[-300]) -> Literal[-300]: pass
424+
def f4(x: Literal[+8]) -> Literal[+8]: pass
419425

420426
reveal_type(f1) # N: Revealed type is "def (x: Literal[4]) -> Literal[4]"
421427
reveal_type(f2) # N: Revealed type is "def (x: Literal[42]) -> Literal[42]"
422428
reveal_type(f3) # N: Revealed type is "def (x: Literal[-300]) -> Literal[-300]"
429+
reveal_type(f4) # N: Revealed type is "def (x: Literal[8]) -> Literal[8]"
423430
[builtins fixtures/tuple.pyi]
424431
[out]
425432

@@ -2747,6 +2754,9 @@ d: Literal[1] = 1
27472754
e: Literal[2] = 2
27482755
f: Literal[+1] = 1
27492756
g: Literal[+2] = 2
2757+
h: Literal[1] = +1
2758+
i: Literal[+2] = 2
2759+
j: Literal[+3] = +3
27502760

27512761
x: Literal[+True] = True # E: Invalid type: Literal[...] cannot contain arbitrary expressions
27522762
y: Literal[-True] = -1 # E: Invalid type: Literal[...] cannot contain arbitrary expressions
@@ -2759,14 +2769,15 @@ from typing_extensions import Literal, Final
27592769

27602770
ONE: Final = 1
27612771
x: Literal[-1] = -ONE
2772+
y: Literal[+1] = +ONE
27622773

27632774
TWO: Final = 2
27642775
THREE: Final = 3
27652776

27662777
err_code = -TWO
27672778
if bool():
27682779
err_code = -THREE
2769-
[builtins fixtures/float.pyi]
2780+
[builtins fixtures/ops.pyi]
27702781

27712782
[case testAliasForEnumTypeAsLiteral]
27722783
from typing_extensions import Literal

test-data/unit/check-tuples.test

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,10 +337,12 @@ if int():
337337
b = t1[-1]
338338
if int():
339339
a = t1[(0)]
340+
if int():
341+
b = t1[+1]
340342
if int():
341343
x = t3[0:3] # type (A, B, C)
342344
if int():
343-
y = t3[0:5:2] # type (A, C, E)
345+
y = t3[0:+5:2] # type (A, C, E)
344346
if int():
345347
x = t3[:-2] # type (A, B, C)
346348

test-data/unit/fixtures/tuple.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class classmethod: pass
3232
# We need int and slice for indexing tuples.
3333
class int:
3434
def __neg__(self) -> 'int': pass
35+
def __pos__(self) -> 'int': pass
3536
class float: pass
3637
class slice: pass
3738
class bool(int): pass

0 commit comments

Comments
 (0)