Skip to content

Commit 7a826c9

Browse files
ref: future-proof using signature modification (#53347)
the signature of a few of these changes in newer `django-stubs` -- this makes the plugin a little more future-proof
1 parent 4c71260 commit 7a826c9

File tree

1 file changed

+29
-48
lines changed

1 file changed

+29
-48
lines changed

tools/mypy_helpers/plugin.py

Lines changed: 29 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,71 +7,52 @@
77
from mypy.types import CallableType, FunctionLike, Instance
88

99

10-
def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
11-
signature = ctx.default_signature
10+
def _make_using_required_str(ctx: FunctionSigContext) -> CallableType:
11+
sig = ctx.default_signature
1212

13-
using_arg = signature.argument_by_name("using")
14-
if not using_arg:
15-
# No using arg in the signature, bail
16-
return signature
13+
using_arg = sig.argument_by_name("using")
14+
if using_arg is None or using_arg.pos is None:
15+
ctx.api.fail("The using parameter is required", ctx.context)
16+
return sig
1717

18-
# We care about context managers.
19-
ret_type = signature.ret_type
20-
if not isinstance(ret_type, Instance):
21-
return signature
18+
for kind in sig.arg_kinds[: using_arg.pos]:
19+
if kind != ARG_POS:
20+
ctx.api.fail("Expected using to be the first optional", ctx.context)
21+
return sig
2222

23-
# Replace the type and remove the default value of using.
2423
str_type = ctx.api.named_generic_type("builtins.str", [])
24+
arg_kinds = [*sig.arg_kinds[: using_arg.pos], ARG_POS, *sig.arg_kinds[using_arg.pos + 1 :]]
25+
arg_types = [*sig.arg_types[: using_arg.pos], str_type, *sig.arg_types[using_arg.pos + 1 :]]
26+
return sig.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types)
2527

26-
arg_types = signature.arg_types[1:]
27-
arg_kinds = signature.arg_kinds[1:]
28-
29-
return signature.copy_modified(
30-
arg_kinds=[ARG_POS, *arg_kinds],
31-
arg_types=[str_type, *arg_types],
32-
)
33-
34-
35-
def replace_get_connection_sig_callback(ctx: FunctionSigContext) -> CallableType:
36-
signature = ctx.default_signature
37-
using_arg = signature.argument_by_name("using")
38-
if not using_arg:
39-
ctx.api.fail("The using parameter is required", ctx.context)
4028

41-
str_type = ctx.api.named_generic_type("builtins.str", [])
29+
def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType:
30+
sig = ctx.default_signature
4231

43-
return signature.copy_modified(arg_kinds=[ARG_POS], arg_types=[str_type])
32+
if not sig.argument_by_name("using"):
33+
# No using arg in the signature, bail
34+
return sig
4435

36+
# We care about context managers.
37+
if not isinstance(sig.ret_type, Instance):
38+
return sig
4539

46-
def replace_trailing_using_sig_callback(ctx: FunctionSigContext) -> CallableType:
47-
signature = ctx.default_signature
48-
using_arg = signature.argument_by_name("using")
49-
if not using_arg:
50-
ctx.api.fail("The using parameter is required", ctx.context)
40+
return _make_using_required_str(ctx)
5141

52-
# Update the parameter type to be required and str
53-
str_type = ctx.api.named_generic_type("builtins.str", [])
54-
arg_kinds = signature.arg_kinds[0:-1]
55-
arg_types = signature.arg_types[0:-1]
5642

57-
return signature.copy_modified(
58-
arg_kinds=[*arg_kinds, ARG_POS], arg_types=[*arg_types, str_type]
59-
)
43+
_FUNCTION_SIGNATURE_HOOKS = {
44+
"django.db.transaction.atomic": replace_transaction_atomic_sig_callback,
45+
"django.db.transaction.get_connection": _make_using_required_str,
46+
"django.db.transaction.on_commit": _make_using_required_str,
47+
"django.db.transaction.set_rollback": _make_using_required_str,
48+
}
6049

6150

6251
class SentryMypyPlugin(Plugin):
6352
def get_function_signature_hook(
6453
self, fullname: str
6554
) -> Callable[[FunctionSigContext], FunctionLike] | None:
66-
if fullname == "django.db.transaction.atomic":
67-
return replace_transaction_atomic_sig_callback
68-
if fullname == "django.db.transaction.get_connection":
69-
return replace_get_connection_sig_callback
70-
if fullname == "django.db.transaction.on_commit":
71-
return replace_trailing_using_sig_callback
72-
if fullname == "django.db.transaction.set_rollback":
73-
return replace_trailing_using_sig_callback
74-
return None
55+
return _FUNCTION_SIGNATURE_HOOKS.get(fullname)
7556

7657

7758
def plugin(version: str) -> type[SentryMypyPlugin]:

0 commit comments

Comments
 (0)