Skip to content

Commit 8cf285d

Browse files
aorenstefacebook-github-bot
authored andcommitted
[BE] typing for decorators - fx/_compatibility (pytorch#134054)
Summary: X-link: ctrl-labs/src2#33884 X-link: pytorch/executorch#4810 X-link: pytorch/torchrec#2322 Pull Request resolved: pytorch#134054 See pytorch#131429 Test Plan: unit tests pass Differential Revision: D61493706
1 parent e52e93e commit 8cf285d

24 files changed

+10
-29
lines changed

torch/_export/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import copy
43
import dataclasses

torch/fx/_compatibility.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
# mypy: allow-untyped-defs
2-
from typing import Any, Dict
1+
from typing import Any, Dict, Callable, TypeVar
32
import textwrap
43

54
_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
65
_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}
76

8-
def compatibility(is_backward_compatible : bool):
7+
_T = TypeVar("_T")
8+
9+
def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
910
if is_backward_compatible:
1011

11-
def mark_back_compat(fn):
12+
def mark_back_compat(fn: _T) -> _T:
1213
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
1314
docstring += """
1415
.. note::
@@ -22,7 +23,7 @@ def mark_back_compat(fn):
2223
return mark_back_compat
2324
else:
2425

25-
def mark_not_back_compat(fn):
26+
def mark_not_back_compat(fn: _T) -> _T:
2627
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
2728
docstring += """
2829
.. warning::

torch/fx/_lazy_graph_module.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
from contextlib import contextmanager
43

torch/fx/_symbolic_trace.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import builtins
43
import copy

torch/fx/experimental/proxy_tensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from typing_extensions import Concatenate, ParamSpec, Self
4040
from weakref import WeakKeyDictionary
41+
from collections.abc import MutableMapping
4142

4243
import torch
4344
import torch._ops
@@ -201,6 +202,7 @@ def set_proxy_slot(
201202
if isinstance(obj, Tensor):
202203
# We DO want to clobber proxies whenever we run an inplace operation
203204
# on a tensor, and it affects the metadata on the proxy.
205+
assert isinstance(proxy, _ProxyTensor)
204206
tracer.tensor_tracker[obj] = proxy
205207
elif isinstance(obj, (_AnyScriptObject)):
206208
# We DO want to clobber proxies, with a similar rationale as for tensors.
@@ -991,7 +993,7 @@ class PythonKeyTracer(Tracer):
991993

992994
def __init__(self) -> None:
993995
super().__init__(autowrap_modules=()) # type: ignore[arg-type]
994-
self.tensor_tracker = WeakTensorKeyDictionary()
996+
self.tensor_tracker: MutableMapping[Tensor, _ProxyTensor] = WeakTensorKeyDictionary()
995997
self.symnode_tracker = _SymNodeDict()
996998
self.script_object_tracker = WeakIdKeyDictionary(
997999
dict=None, ref_type=_WeakHashRef
@@ -1365,7 +1367,7 @@ def __sym_dispatch__(
13651367
class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
13661368
script_object_tracker: WeakKeyDictionary
13671369
symnode_tracker: WeakKeyDictionary
1368-
tensor_tracker: WeakTensorKeyDictionary
1370+
tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
13691371
sympy_expr_tracker: Dict[sympy.Symbol, object]
13701372
torch_fn_metadata: Optional[OpOverload]
13711373
torch_fn_counts: Dict[OpOverload, int]

torch/fx/graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
from collections import defaultdict
43
from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name

torch/fx/graph_module.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import contextlib
43
import copy

torch/fx/interpreter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
from .graph_module import GraphModule
43
from ._lazy_graph_module import _make_graph_module

torch/fx/node.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# Nodes represent a definition of a value in our graph of operators.
32
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
43
from ._compatibility import compatibility

torch/fx/operator_schemas.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import torch
43
import inspect

torch/fx/passes/graph_manipulation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
from typing import Any, Dict, List, NamedTuple, Optional
43

torch/fx/passes/infra/pass_manager.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import inspect
43
import logging

torch/fx/passes/operator_support.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import abc
43
import typing as t

torch/fx/passes/param_fetch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
from torch.fx.graph_module import GraphModule
32
from typing import Any, Callable, Dict, List, Tuple, Type
43
import torch

torch/fx/passes/runtime_assert.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import logging
43
import operator

torch/fx/passes/split_module.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import inspect
43
from typing import Any, Callable, Dict, List, Optional, Set

torch/fx/passes/split_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import copy
43
from dataclasses import dataclass, field

torch/fx/passes/splitter_base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import argparse
43
import copy

torch/fx/passes/tools_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
43
import collections

torch/fx/passes/utils/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
from typing import Dict, Tuple
43

torch/fx/passes/utils/fuser_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import copy
43
from queue import SimpleQueue

torch/fx/subgraph_rewriter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
from .graph_module import GraphModule
32
from .graph import Graph
43
from .node import Node

torch/fx/traceback.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import traceback
43
from contextlib import contextmanager

torch/onnx/_internal/onnxruntime.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import dataclasses
43
import importlib

0 commit comments

Comments
 (0)