|
2 | 2 | from typing import Callable, Optional, List
|
3 | 3 |
|
4 | 4 | from mypy import message_registry
|
5 |
| -from mypy.nodes import StrExpr, IntExpr, DictExpr, UnaryExpr |
| 5 | +from mypy.nodes import Expression, StrExpr, IntExpr, DictExpr, UnaryExpr |
6 | 6 | from mypy.plugin import (
|
7 |
| - Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext |
| 7 | + Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext, |
| 8 | + CheckerPluginInterface, |
8 | 9 | )
|
9 | 10 | from mypy.plugins.common import try_getting_str_literals
|
10 | 11 | from mypy.types import (
|
@@ -66,6 +67,8 @@ def get_method_hook(self, fullname: str
|
66 | 67 | return ctypes.array_getitem_callback
|
67 | 68 | elif fullname == 'ctypes.Array.__iter__':
|
68 | 69 | return ctypes.array_iter_callback
|
| 70 | + elif fullname == 'pathlib.Path.open': |
| 71 | + return path_open_callback |
69 | 72 | return None
|
70 | 73 |
|
71 | 74 | def get_attribute_hook(self, fullname: str
|
@@ -101,23 +104,55 @@ def get_class_decorator_hook(self, fullname: str
|
101 | 104 |
|
102 | 105 |
|
103 | 106 | def open_callback(ctx: FunctionContext) -> Type:
|
104 |
| - """Infer a better return type for 'open'. |
105 |
| -
|
106 |
| - Infer TextIO or BinaryIO as the return value if the mode argument is not |
107 |
| - given or is a literal. |
| 107 | + """Infer a better return type for 'open'.""" |
| 108 | + return _analyze_open_signature( |
| 109 | + arg_types=ctx.arg_types, |
| 110 | + args=ctx.args, |
| 111 | + mode_arg_index=1, |
| 112 | + default_return_type=ctx.default_return_type, |
| 113 | + api=ctx.api, |
| 114 | + ) |
| 115 | + |
| 116 | + |
| 117 | +def path_open_callback(ctx: MethodContext) -> Type: |
| 118 | + """Infer a better return type for 'pathlib.Path.open'.""" |
| 119 | + return _analyze_open_signature( |
| 120 | + arg_types=ctx.arg_types, |
| 121 | + args=ctx.args, |
| 122 | + mode_arg_index=0, |
| 123 | + default_return_type=ctx.default_return_type, |
| 124 | + api=ctx.api, |
| 125 | + ) |
| 126 | + |
| 127 | + |
| 128 | +def _analyze_open_signature(arg_types: List[List[Type]], |
| 129 | + args: List[List[Expression]], |
| 130 | + mode_arg_index: int, |
| 131 | + default_return_type: Type, |
| 132 | + api: CheckerPluginInterface, |
| 133 | + ) -> Type: |
| 134 | + """A helper for analyzing any function that has approximately |
| 135 | + the same signature as the builtin 'open(...)' function. |
| 136 | +
|
| 137 | + Currently, the only thing the caller can customize is the index |
| 138 | + of the 'mode' argument. If the mode argument is omitted or is a |
| 139 | + string literal, we refine the return type to either 'TextIO' or |
| 140 | + 'BinaryIO' as appropriate. |
108 | 141 | """
|
109 | 142 | mode = None
|
110 |
| - if not ctx.arg_types or len(ctx.arg_types[1]) != 1: |
| 143 | + if not arg_types or len(arg_types[mode_arg_index]) != 1: |
111 | 144 | mode = 'r'
|
112 |
| - elif isinstance(ctx.args[1][0], StrExpr): |
113 |
| - mode = ctx.args[1][0].value |
| 145 | + else: |
| 146 | + mode_expr = args[mode_arg_index][0] |
| 147 | + if isinstance(mode_expr, StrExpr): |
| 148 | + mode = mode_expr.value |
114 | 149 | if mode is not None:
|
115 |
| - assert isinstance(ctx.default_return_type, Instance) # type: ignore |
| 150 | + assert isinstance(default_return_type, Instance) # type: ignore |
116 | 151 | if 'b' in mode:
|
117 |
| - return ctx.api.named_generic_type('typing.BinaryIO', []) |
| 152 | + return api.named_generic_type('typing.BinaryIO', []) |
118 | 153 | else:
|
119 |
| - return ctx.api.named_generic_type('typing.TextIO', []) |
120 |
| - return ctx.default_return_type |
| 154 | + return api.named_generic_type('typing.TextIO', []) |
| 155 | + return default_return_type |
121 | 156 |
|
122 | 157 |
|
123 | 158 | def contextmanager_callback(ctx: FunctionContext) -> Type:
|
|
0 commit comments