Skip to content

Commit 859f44a

Browse files
committed
Simplify name adding API
1 parent 6534bf3 commit 859f44a

File tree

3 files changed

+27
-46
lines changed

3 files changed

+27
-46
lines changed

mypy/stubgen.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -516,21 +516,20 @@ def _get_func_return(self, o: FuncDef, ctx: FunctionContext) -> str | None:
516516
if retname is not None:
517517
return retname
518518
if has_yield_expression(o) or has_yield_from_expression(o):
519-
self.add_typing_import("Generator")
520519
yield_name = "None"
521520
send_name = "None"
522521
return_name = "None"
523522
if has_yield_from_expression(o):
524-
yield_name = send_name = self.add_incomplete()
523+
yield_name = send_name = self.add_name("_typeshed.Incomplete")
525524
else:
526525
for expr, in_assignment in all_yield_expressions(o):
527526
if expr.expr is not None and not is_none_expr(expr.expr):
528-
yield_name = self.add_incomplete()
527+
yield_name = self.add_name("_typeshed.Incomplete")
529528
if in_assignment:
530-
send_name = self.add_incomplete()
529+
send_name = self.add_name("_typeshed.Incomplete")
531530
if has_return_statement(o):
532-
return_name = self.add_incomplete()
533-
generator_name = self.add_typing_import("Generator")
531+
return_name = self.add_name("_typeshed.Incomplete")
532+
generator_name = self.add_name("collections.abc.Generator")
534533
return f"{generator_name}[{yield_name}, {send_name}, {return_name}]"
535534
if not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT:
536535
return "None"
@@ -731,18 +730,17 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
731730
typename = base.args[0].value
732731
if nt_fields is not None:
733732
fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields)
734-
namedtuple_name = self.add_typing_import("NamedTuple")
733+
namedtuple_name = self.add_name("typing.NamedTuple")
735734
base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])")
736-
self.add_typing_import("NamedTuple")
737735
else:
738736
# Invalid namedtuple() call, cannot determine fields
739-
base_types.append(self.add_incomplete())
737+
base_types.append(self.add_name("_typeshed.Incomplete"))
740738
elif self.is_typed_namedtuple(base):
741739
base_types.append(base.accept(p))
742740
else:
743741
# At this point, we don't know what the base class is, so we
744742
# just use Incomplete as the base class.
745-
base_types.append(self.add_incomplete())
743+
base_types.append(self.add_name("_typeshed.Incomplete"))
746744
for name, value in cdef.keywords.items():
747745
if name == "metaclass":
748746
continue # handled separately
@@ -823,7 +821,7 @@ def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None
823821
else:
824822
return None # Invalid namedtuple fields type
825823
if field_names:
826-
incomplete = self.add_incomplete()
824+
incomplete = self.add_name("_typeshed.Incomplete")
827825
return [(field_name, incomplete) for field_name in field_names]
828826
else:
829827
return []
@@ -857,7 +855,7 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
857855
if fields is None:
858856
self.annotate_as_incomplete(lvalue)
859857
return
860-
bases = self.add_typing_import("NamedTuple")
858+
bases = self.add_name("typing.NamedTuple")
861859
# TODO: Add support for generic NamedTuples. Requires `Generic` as base class.
862860
class_def = f"{self._indent}class {lvalue.name}({bases}):"
863861
if len(fields) == 0:
@@ -907,14 +905,13 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
907905
total = arg
908906
else:
909907
items.append((arg_name, arg))
910-
self.add_typing_import("TypedDict")
911908
p = AliasPrinter(self)
912909
if any(not key.isidentifier() or keyword.iskeyword(key) for key, _ in items):
913910
# Keep the call syntax if there are non-identifier or reserved keyword keys.
914911
self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n")
915912
self._state = VAR
916913
else:
917-
bases = self.add_typing_import("TypedDict")
914+
bases = self.add_name("typing_extensions.TypedDict")
918915
# TODO: Add support for generic TypedDicts. Requires `Generic` as base class.
919916
if total is not None:
920917
bases += f", total={total.accept(p)}"
@@ -931,7 +928,7 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
931928
self._state = CLASS
932929

933930
def annotate_as_incomplete(self, lvalue: NameExpr) -> None:
934-
incomplete = self.add_incomplete()
931+
incomplete = self.add_name("_typeshed.Incomplete")
935932
self.add(f"{self._indent}{lvalue.name}: {incomplete}\n")
936933
self._state = VAR
937934

@@ -1134,10 +1131,10 @@ def get_str_type_of_node(
11341131
if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"):
11351132
return "bool"
11361133
if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None":
1137-
incomplete = self.add_incomplete()
1134+
incomplete = self.add_name("_typeshed.Incomplete")
11381135
return f"{incomplete} | None"
11391136
if can_be_any:
1140-
return self.add_incomplete()
1137+
return self.add_name("_typeshed.Incomplete")
11411138
else:
11421139
return ""
11431140

mypy/stubgenc.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def get_annotation(key: str) -> str | None:
288288
if argtype == "None":
289289
# None is not a useful annotation, but we can infer that the arg
290290
# is optional
291-
incomplete = self.add_incomplete()
291+
incomplete = self.add_name("_typeshed.Incomplete")
292292
argtype = f"{incomplete} | None"
293293
arglist.append(ArgSig(arg, argtype, default=True))
294294
else:
@@ -459,9 +459,9 @@ def get_type_annotation(self, obj: object) -> str:
459459
elif inspect.isclass(obj):
460460
return "type[{}]".format(self.get_type_fullname(obj))
461461
elif isinstance(obj, FunctionType):
462-
return self.add_typing_import("Callable")
462+
return self.add_name("typing.Callable")
463463
elif isinstance(obj, ModuleType):
464-
return self.add_obj_import("types", "ModuleType")
464+
return self.add_name("types.ModuleType", require=False)
465465
else:
466466
return self.get_type_fullname(type(obj))
467467

@@ -556,10 +556,10 @@ def generate_function_stub(
556556

557557
decorators = []
558558
if len(inferred) > 1:
559-
decorators.append("@{}".format(self.add_typing_import("overload")))
559+
decorators.append("@{}".format(self.add_name("typing.overload")))
560560

561561
if ctx.is_abstract:
562-
decorators.append("@{}".format(self.add_abc_import("abstractmethod")))
562+
decorators.append("@{}".format(self.add_name("collections.abc.abstractmethod")))
563563

564564
if class_info is not None:
565565
if self.is_staticmethod(class_info, name, obj):
@@ -646,10 +646,10 @@ def generate_property_stub(
646646
inferred_type = self.strip_or_import(inferred_type)
647647

648648
if static:
649-
classvar = self.add_typing_import("ClassVar")
649+
classvar = self.add_name("typing.ClassVar")
650650
trailing_comment = " # read-only" if readonly else ""
651651
if inferred_type is None:
652-
inferred_type = self.add_incomplete()
652+
inferred_type = self.add_name("_typeshed.Incomplete")
653653

654654
static_properties.append(
655655
f"{self._indent}{name}: {classvar}[{inferred_type}] = ...{trailing_comment}"
@@ -661,7 +661,7 @@ def generate_property_stub(
661661
ro_properties.append(self._indent + sig.format_sig())
662662
else:
663663
if inferred_type is None:
664-
inferred_type = self.add_incomplete()
664+
inferred_type = self.add_name("_typeshed.Incomplete")
665665

666666
rw_properties.append(f"{self._indent}{name}: {inferred_type}")
667667

@@ -755,7 +755,7 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
755755
# special case for __hash__
756756
continue
757757
prop_type_name = self.strip_or_import(self.get_type_annotation(value))
758-
classvar = self.add_typing_import("ClassVar")
758+
classvar = self.add_name("typing.ClassVar")
759759
static_properties.append(f"{self._indent}{attr}: {classvar}[{prop_type_name}] = ...")
760760

761761
self.dedent()

mypy/stubutil.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -552,32 +552,16 @@ def refers_to_fullname(self, name: str, fullname: str | tuple[str, ...]) -> bool
552552
name == short or self.import_tracker.reverse_alias.get(name) == short
553553
)
554554

555-
def add_obj_import(self, module: str, name: str, require: bool = False) -> str:
556-
"""Add a name to be imported.
555+
def add_name(self, fullname: str, require: bool = True) -> str:
556+
"""Add a name to be imported and return the name reference.
557557
558558
The import will be internal to the stub (i.e don't reexport).
559559
"""
560+
module, name = fullname.rsplit(".", 1)
560561
alias = "_" + name if name in self.defined_names else None
561562
self.import_tracker.add_import_from(module, [(name, alias)], require=require)
562563
return alias or name
563564

564-
def add_typing_import(self, name: str, require: bool = True) -> str:
565-
"""Add a name to be imported from typing, unless it's imported already.
566-
567-
The import will be internal to the stub (i.e don't reexport).
568-
"""
569-
return self.add_obj_import("typing", name, require=require)
570-
571-
def add_incomplete(self, require: bool = True) -> str:
572-
return self.add_obj_import("_typeshed", "Incomplete", require=require)
573-
574-
def add_abc_import(self, name: str, require: bool = True) -> str:
575-
"""Add a name to be imported from collections.abc, unless it's imported already.
576-
577-
The import will be internal to the stub.
578-
"""
579-
return self.add_obj_import("collections.abc", name, require=require)
580-
581565
def add_import_line(self, line: str) -> None:
582566
"""Add a line of text to the import section, unless it's already there."""
583567
if line not in self._import_lines:
@@ -647,7 +631,7 @@ def set_defined_names(self, defined_names: set[str]) -> None:
647631
for t in imports:
648632
# require=False means that the import won't be added unless require_name() is called
649633
# for the object during generation.
650-
self.add_obj_import(pkg, t, require=False)
634+
self.add_name(f"{pkg}.{t}", require=False)
651635

652636
def get_signatures(
653637
self,

0 commit comments

Comments
 (0)