Skip to content

Commit 7752e90

Browse files
committed
[stubgenc] Support nested classes
1 parent 9d49fd8 commit 7752e90

File tree

3 files changed

+61
-13
lines changed

3 files changed

+61
-13
lines changed

mypy/stubgenc.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def generate_stub_for_c_module(module_name: str,
6969
if name.startswith('__') and name.endswith('__'):
7070
continue
7171
if name not in done and not inspect.ismodule(obj):
72-
type_str = strip_or_import(type(obj).__name__, module, imports)
72+
type_str = strip_or_import(get_type_fullname(type(obj)), module, imports)
7373
variables.append('%s: %s' % (name, type_str))
7474
output = []
7575
for line in sorted(set(imports)):
@@ -286,6 +286,7 @@ def generate_c_type_stub(module: ModuleType,
286286
obj_dict = getattr(obj, '__dict__') # type: Mapping[str, Any] # noqa
287287
items = sorted(obj_dict.items(), key=lambda x: method_name_sort_key(x[0]))
288288
methods = [] # type: List[str]
289+
types = [] # type: List[str]
289290
properties = [] # type: List[str]
290291
done = set() # type: Set[str]
291292
for attr, value in items:
@@ -312,14 +313,18 @@ def generate_c_type_stub(module: ModuleType,
312313
done.add(attr)
313314
generate_c_property_stub(attr, value, properties, is_c_property_readonly(value),
314315
module=module, imports=imports)
316+
elif is_c_type(value):
317+
generate_c_type_stub(module, attr, value, types, imports=imports, sigs=sigs,
318+
class_sigs=class_sigs)
319+
done.add(attr)
315320

316321
variables = []
317322
for attr, value in items:
318323
if is_skipped_attribute(attr):
319324
continue
320325
if attr not in done:
321326
variables.append('%s: %s = ...' % (
322-
attr, strip_or_import(type(value).__name__, module, imports)))
327+
attr, strip_or_import(get_type_fullname(type(value)), module, imports)))
323328
all_bases = obj.mro()
324329
if all_bases[-1] is object:
325330
# TODO: Is this always object?
@@ -345,10 +350,15 @@ def generate_c_type_stub(module: ModuleType,
345350
)
346351
else:
347352
bases_str = ''
348-
if not methods and not variables and not properties:
353+
if not methods and not variables and not properties and not types:
349354
output.append('class %s%s: ...' % (class_name, bases_str))
350355
else:
351356
output.append('class %s%s:' % (class_name, bases_str))
357+
for line in types:
358+
if output and output[-1] and \
359+
not output[-1].startswith('class') and line.startswith('class'):
360+
output.append('')
361+
output.append(' ' + line)
352362
for variable in variables:
353363
output.append(' %s' % variable)
354364
for method in methods:
@@ -358,7 +368,7 @@ def generate_c_type_stub(module: ModuleType,
358368

359369

360370
def get_type_fullname(typ: type) -> str:
361-
return '%s.%s' % (typ.__module__, typ.__name__)
371+
return '%s.%s' % (typ.__module__, getattr(typ, '__qualname__', typ.__name__))
362372

363373

364374
def method_name_sort_key(name: str) -> Tuple[int, str]:

mypy/test/teststubgen.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,14 @@ def add_file(self, path: str, result: List[str], header: bool) -> None:
625625
self_arg = ArgSig(name='self')
626626

627627

628+
class TestBaseClass:
629+
pass
630+
631+
632+
class TestClass(TestBaseClass):
633+
pass
634+
635+
628636
class StubgencSuite(unittest.TestCase):
629637
"""Unit tests for stub generation from C modules using introspection.
630638
@@ -668,7 +676,7 @@ class TestClassVariableCls:
668676
mod = ModuleType('module', '') # any module is fine
669677
generate_c_type_stub(mod, 'C', TestClassVariableCls, output, imports)
670678
assert_equal(imports, [])
671-
assert_equal(output, ['class C:', ' x: Any = ...'])
679+
assert_equal(output, ['class C:', ' x: int = ...'])
672680

673681
def test_generate_c_type_inheritance(self) -> None:
674682
class TestClass(KeyError):
@@ -682,12 +690,6 @@ class TestClass(KeyError):
682690
assert_equal(imports, [])
683691

684692
def test_generate_c_type_inheritance_same_module(self) -> None:
685-
class TestBaseClass:
686-
pass
687-
688-
class TestClass(TestBaseClass):
689-
pass
690-
691693
output = [] # type: List[str]
692694
imports = [] # type: List[str]
693695
mod = ModuleType(TestBaseClass.__module__, '')

test-data/stubgen/pybind11_mypy_demo/basics.pyi

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,44 @@ from typing import overload
44
PI: float
55

66
class Point:
7-
AngleUnit: pybind11_type = ...
8-
LengthUnit: pybind11_type = ...
7+
class AngleUnit:
8+
__entries: dict = ...
9+
degree: Point.AngleUnit = ...
10+
radian: Point.AngleUnit = ...
11+
def __init__(self, value: int) -> None: ...
12+
def __eq__(self, other: object) -> bool: ...
13+
def __getstate__(self) -> int: ...
14+
def __hash__(self) -> int: ...
15+
def __index__(self) -> int: ...
16+
def __int__(self) -> int: ...
17+
def __ne__(self, other: object) -> bool: ...
18+
def __setstate__(self, state: int) -> None: ...
19+
@property
20+
def name(self) -> Any: ...
21+
@property
22+
def __doc__(self) -> Any: ...
23+
@property
24+
def __members__(self) -> Any: ...
25+
26+
class LengthUnit:
27+
__entries: dict = ...
28+
inch: Point.LengthUnit = ...
29+
mm: Point.LengthUnit = ...
30+
pixel: Point.LengthUnit = ...
31+
def __init__(self, value: int) -> None: ...
32+
def __eq__(self, other: object) -> bool: ...
33+
def __getstate__(self) -> int: ...
34+
def __hash__(self) -> int: ...
35+
def __index__(self) -> int: ...
36+
def __int__(self) -> int: ...
37+
def __ne__(self, other: object) -> bool: ...
38+
def __setstate__(self, state: int) -> None: ...
39+
@property
40+
def name(self) -> Any: ...
41+
@property
42+
def __doc__(self) -> Any: ...
43+
@property
44+
def __members__(self) -> Any: ...
945
origin: Point = ...
1046
@overload
1147
def __init__(self) -> None: ...

0 commit comments

Comments
 (0)