|
7 | 7 | from mypy.types import CallableType, FunctionLike, Instance
|
8 | 8 |
|
9 | 9 |
|
10 |
| -def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType: |
11 |
| - signature = ctx.default_signature |
| 10 | +def _make_using_required_str(ctx: FunctionSigContext) -> CallableType: |
| 11 | + sig = ctx.default_signature |
12 | 12 |
|
13 |
| - using_arg = signature.argument_by_name("using") |
14 |
| - if not using_arg: |
15 |
| - # No using arg in the signature, bail |
16 |
| - return signature |
| 13 | + using_arg = sig.argument_by_name("using") |
| 14 | + if using_arg is None or using_arg.pos is None: |
| 15 | + ctx.api.fail("The using parameter is required", ctx.context) |
| 16 | + return sig |
17 | 17 |
|
18 |
| - # We care about context managers. |
19 |
| - ret_type = signature.ret_type |
20 |
| - if not isinstance(ret_type, Instance): |
21 |
| - return signature |
| 18 | + for kind in sig.arg_kinds[: using_arg.pos]: |
| 19 | + if kind != ARG_POS: |
| 20 | + ctx.api.fail("Expected using to be the first optional", ctx.context) |
| 21 | + return sig |
22 | 22 |
|
23 |
| - # Replace the type and remove the default value of using. |
24 | 23 | str_type = ctx.api.named_generic_type("builtins.str", [])
|
| 24 | + arg_kinds = [*sig.arg_kinds[: using_arg.pos], ARG_POS, *sig.arg_kinds[using_arg.pos + 1 :]] |
| 25 | + arg_types = [*sig.arg_types[: using_arg.pos], str_type, *sig.arg_types[using_arg.pos + 1 :]] |
| 26 | + return sig.copy_modified(arg_kinds=arg_kinds, arg_types=arg_types) |
25 | 27 |
|
26 |
| - arg_types = signature.arg_types[1:] |
27 |
| - arg_kinds = signature.arg_kinds[1:] |
28 |
| - |
29 |
| - return signature.copy_modified( |
30 |
| - arg_kinds=[ARG_POS, *arg_kinds], |
31 |
| - arg_types=[str_type, *arg_types], |
32 |
| - ) |
33 |
| - |
34 |
| - |
35 |
| -def replace_get_connection_sig_callback(ctx: FunctionSigContext) -> CallableType: |
36 |
| - signature = ctx.default_signature |
37 |
| - using_arg = signature.argument_by_name("using") |
38 |
| - if not using_arg: |
39 |
| - ctx.api.fail("The using parameter is required", ctx.context) |
40 | 28 |
|
41 |
| - str_type = ctx.api.named_generic_type("builtins.str", []) |
| 29 | +def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> CallableType: |
| 30 | + sig = ctx.default_signature |
42 | 31 |
|
43 |
| - return signature.copy_modified(arg_kinds=[ARG_POS], arg_types=[str_type]) |
| 32 | + if not sig.argument_by_name("using"): |
| 33 | + # No using arg in the signature, bail |
| 34 | + return sig |
44 | 35 |
|
| 36 | + # We care about context managers. |
| 37 | + if not isinstance(sig.ret_type, Instance): |
| 38 | + return sig |
45 | 39 |
|
46 |
| -def replace_trailing_using_sig_callback(ctx: FunctionSigContext) -> CallableType: |
47 |
| - signature = ctx.default_signature |
48 |
| - using_arg = signature.argument_by_name("using") |
49 |
| - if not using_arg: |
50 |
| - ctx.api.fail("The using parameter is required", ctx.context) |
| 40 | + return _make_using_required_str(ctx) |
51 | 41 |
|
52 |
| - # Update the parameter type to be required and str |
53 |
| - str_type = ctx.api.named_generic_type("builtins.str", []) |
54 |
| - arg_kinds = signature.arg_kinds[0:-1] |
55 |
| - arg_types = signature.arg_types[0:-1] |
56 | 42 |
|
57 |
| - return signature.copy_modified( |
58 |
| - arg_kinds=[*arg_kinds, ARG_POS], arg_types=[*arg_types, str_type] |
59 |
| - ) |
| 43 | +_FUNCTION_SIGNATURE_HOOKS = { |
| 44 | + "django.db.transaction.atomic": replace_transaction_atomic_sig_callback, |
| 45 | + "django.db.transaction.get_connection": _make_using_required_str, |
| 46 | + "django.db.transaction.on_commit": _make_using_required_str, |
| 47 | + "django.db.transaction.set_rollback": _make_using_required_str, |
| 48 | +} |
60 | 49 |
|
61 | 50 |
|
62 | 51 | class SentryMypyPlugin(Plugin):
|
63 | 52 | def get_function_signature_hook(
|
64 | 53 | self, fullname: str
|
65 | 54 | ) -> Callable[[FunctionSigContext], FunctionLike] | None:
|
66 |
| - if fullname == "django.db.transaction.atomic": |
67 |
| - return replace_transaction_atomic_sig_callback |
68 |
| - if fullname == "django.db.transaction.get_connection": |
69 |
| - return replace_get_connection_sig_callback |
70 |
| - if fullname == "django.db.transaction.on_commit": |
71 |
| - return replace_trailing_using_sig_callback |
72 |
| - if fullname == "django.db.transaction.set_rollback": |
73 |
| - return replace_trailing_using_sig_callback |
74 |
| - return None |
| 55 | + return _FUNCTION_SIGNATURE_HOOKS.get(fullname) |
75 | 56 |
|
76 | 57 |
|
77 | 58 | def plugin(version: str) -> type[SentryMypyPlugin]:
|
|
0 commit comments