Skip to content

Commit 5bb681a

Browse files
authored
[mypyc] Recognize Literal types in __match_args__ (#18636)
Fixes #18614
1 parent a8c2345 commit 5bb681a

File tree

2 files changed

+101
-18
lines changed

2 files changed

+101
-18
lines changed

mypyc/irbuild/match.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
ValuePattern,
1717
)
1818
from mypy.traverser import TraverserVisitor
19-
from mypy.types import Instance, TupleType, get_proper_type
19+
from mypy.types import Instance, LiteralType, TupleType, get_proper_type
2020
from mypyc.ir.ops import BasicBlock, Value
2121
from mypyc.ir.rtypes import object_rprimitive
2222
from mypyc.irbuild.builder import IRBuilder
@@ -152,23 +152,7 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None:
152152

153153
node = pattern.class_ref.node
154154
assert isinstance(node, TypeInfo)
155-
156-
ty = node.names.get("__match_args__")
157-
assert ty
158-
159-
match_args_type = get_proper_type(ty.type)
160-
assert isinstance(match_args_type, TupleType)
161-
162-
match_args: list[str] = []
163-
164-
for item in match_args_type.items:
165-
proper_item = get_proper_type(item)
166-
assert isinstance(proper_item, Instance) and proper_item.last_known_value
167-
168-
match_arg = proper_item.last_known_value.value
169-
assert isinstance(match_arg, str)
170-
171-
match_args.append(match_arg)
155+
match_args = extract_dunder_match_args_names(node)
172156

173157
for i, expr in enumerate(pattern.positionals):
174158
self.builder.activate_block(self.code_block)
@@ -355,3 +339,24 @@ def prep_sequence_pattern(
355339
patterns.append(pattern)
356340

357341
return star_index, capture, patterns
342+
343+
344+
def extract_dunder_match_args_names(info: TypeInfo) -> list[str]:
345+
ty = info.names.get("__match_args__")
346+
assert ty
347+
match_args_type = get_proper_type(ty.type)
348+
assert isinstance(match_args_type, TupleType)
349+
350+
match_args: list[str] = []
351+
for item in match_args_type.items:
352+
proper_item = get_proper_type(item)
353+
354+
match_arg = None
355+
if isinstance(proper_item, Instance) and proper_item.last_known_value:
356+
match_arg = proper_item.last_known_value.value
357+
elif isinstance(proper_item, LiteralType):
358+
match_arg = proper_item.value
359+
assert isinstance(match_arg, str), f"Unrecognized __match_args__ item: {item}"
360+
361+
match_args.append(match_arg)
362+
return match_args

mypyc/test-data/irbuild-match.test

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,3 +1727,81 @@ L4:
17271727
L5:
17281728
L6:
17291729
unreachable
1730+
1731+
[case testMatchLiteralMatchArgs_python3_10]
1732+
from typing_extensions import Literal
1733+
1734+
class Foo:
1735+
__match_args__: tuple[Literal["foo"]] = ("foo",)
1736+
foo: str
1737+
1738+
def f(x: Foo) -> None:
1739+
match x:
1740+
case Foo(foo):
1741+
print("foo")
1742+
case _:
1743+
assert False, "Unreachable"
1744+
[out]
1745+
def Foo.__mypyc_defaults_setup(__mypyc_self__):
1746+
__mypyc_self__ :: __main__.Foo
1747+
r0 :: str
1748+
r1 :: tuple[str]
1749+
L0:
1750+
r0 = 'foo'
1751+
r1 = (r0)
1752+
__mypyc_self__.__match_args__ = r1
1753+
return 1
1754+
def f(x):
1755+
x :: __main__.Foo
1756+
r0 :: object
1757+
r1 :: i32
1758+
r2 :: bit
1759+
r3 :: bool
1760+
r4 :: str
1761+
r5 :: object
1762+
r6, foo, r7 :: str
1763+
r8 :: object
1764+
r9 :: str
1765+
r10 :: object
1766+
r11 :: object[1]
1767+
r12 :: object_ptr
1768+
r13, r14 :: object
1769+
r15 :: i32
1770+
r16 :: bit
1771+
r17, r18 :: bool
1772+
L0:
1773+
r0 = __main__.Foo :: type
1774+
r1 = PyObject_IsInstance(x, r0)
1775+
r2 = r1 >= 0 :: signed
1776+
r3 = truncate r1: i32 to builtins.bool
1777+
if r3 goto L1 else goto L3 :: bool
1778+
L1:
1779+
r4 = 'foo'
1780+
r5 = CPyObject_GetAttr(x, r4)
1781+
r6 = cast(str, r5)
1782+
foo = r6
1783+
L2:
1784+
r7 = 'foo'
1785+
r8 = builtins :: module
1786+
r9 = 'print'
1787+
r10 = CPyObject_GetAttr(r8, r9)
1788+
r11 = [r7]
1789+
r12 = load_address r11
1790+
r13 = PyObject_Vectorcall(r10, r12, 1, 0)
1791+
keep_alive r7
1792+
goto L8
1793+
L3:
1794+
L4:
1795+
r14 = box(bool, 0)
1796+
r15 = PyObject_IsTrue(r14)
1797+
r16 = r15 >= 0 :: signed
1798+
r17 = truncate r15: i32 to builtins.bool
1799+
if r17 goto L6 else goto L5 :: bool
1800+
L5:
1801+
r18 = raise AssertionError('Unreachable')
1802+
unreachable
1803+
L6:
1804+
goto L8
1805+
L7:
1806+
L8:
1807+
return 1

0 commit comments

Comments
 (0)