Skip to content

Commit d5ad86e

Browse files
aorenstefacebook-github-bot
authored andcommitted
[BE] typing for decorators - fx/_compatibility (pytorch#134054)
Summary: X-link: ctrl-labs/src2#33884 X-link: pytorch/executorch#4810 X-link: pytorch/torchrec#2322 Pull Request resolved: pytorch#134054 See pytorch#131429 Test Plan: unit tests pass Differential Revision: D61493706
1 parent 7868b65 commit d5ad86e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+195
-217
lines changed

torch/_dynamo/compiled_autograd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def to_proxy(self, t):
445445

446446
def bind_tensors_to_proxies(self, tensors, proxies):
447447
if isinstance(proxies, torch.fx.Proxy):
448-
proxies = [proxies[i] for i in range(len(tensors))]
448+
proxies = [proxies[i] for i in range(len(tensors))] # type: ignore[index]
449449
assert len(tensors) == len(proxies)
450450
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
451451

torch/_dynamo/eval_frame.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,7 +1539,7 @@ def result_capturing_wrapper(*graph_inputs):
15391539
# Running graph with interpreter is needed for propagating the stack_trace
15401540
def graph_with_interpreter(*args):
15411541
with torch.fx.traceback.preserve_node_meta():
1542-
return torch.fx.Interpreter(graph).run(*args)
1542+
return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type]
15431543

15441544
with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
15451545
try:
@@ -1561,9 +1561,9 @@ def graph_with_interpreter(*args):
15611561

15621562
assert graph is not None
15631563
for node in graph.graph.find_nodes(op="get_attr"):
1564-
if isinstance(getattr(graph, node.target), torch.Tensor):
1564+
if isinstance(getattr(graph, node.target), torch.Tensor): # type: ignore[arg-type]
15651565
node.meta["val"] = fake_mode.from_tensor(
1566-
getattr(graph, node.target), static_shapes=True
1566+
getattr(graph, node.target), static_shapes=True # type: ignore[arg-type]
15671567
)
15681568

15691569
if same_signature:

torch/_dynamo/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2151,7 +2151,7 @@ def get_real_value(node, tracer):
21512151
return cache[node]
21522152

21532153
op = node.op
2154-
args, kwargs = torch.fx.node.map_arg(
2154+
args, kwargs = torch.fx.node.map_arg( # type: ignore[misc]
21552155
(node.args, node.kwargs),
21562156
lambda n: get_real_value(n, tracer),
21572157
)

torch/_export/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: allow-untyped-decorators
21
# mypy: allow-untyped-defs
32
import copy
43
import dataclasses

torch/_export/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1343,7 +1343,7 @@ def __init__(
13431343
self.name_to_node,
13441344
# Dummy node.
13451345
torch.fx.Node(
1346-
None,
1346+
None, # type: ignore[arg-type]
13471347
"mock",
13481348
"call_function",
13491349
lambda: None,

torch/_export/pass_base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeG
6868
self.fake_tensor_mode: Optional[FakeTensorMode] = None
6969
self.submodules: Dict[torch.nn.Module, str] = {}
7070

71-
def trace(self) -> None:
71+
def trace(self) -> None: # type: ignore[override]
7272
raise ExportPassBaseError("ExportTracer doesn't support trace().")
7373

7474
def create_arg(self, a: Argument) -> torch.fx.Node:
@@ -160,7 +160,7 @@ def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphMo
160160

161161
def placeholder(
162162
self,
163-
target: str,
163+
target: str, # type: ignore[override]
164164
args: Tuple[Argument, ...],
165165
kwargs: Dict[str, Argument],
166166
) -> ProxyValue:
@@ -218,7 +218,7 @@ def call_function(
218218
raise ExportPassBaseError(f"Unsupported target type: {target}")
219219

220220
def get_attr(
221-
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
221+
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
222222
) -> Argument:
223223
return super().get_attr(target, args, kwargs)
224224

@@ -231,7 +231,7 @@ def call_module(
231231
raise ExportPassBaseError("call_module is not supported.")
232232

233233
def call_method(
234-
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
234+
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
235235
) -> None:
236236
raise ExportPassBaseError("call_method is not supported.")
237237

@@ -394,7 +394,7 @@ def call_submodule(
394394
)
395395
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
396396
interpreter = self.ExportInterpreter(self, graph_module)
397-
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
397+
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment]
398398
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
399399
)
400400
inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)

torch/_export/passes/replace_with_hop_pass_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def set_hoo_node_meta(call_func_node):
6464
# Rename the name of getitem nodes to the actual name of its contents
6565
# for passing verifier and better readability, also propagate metadata
6666
for get_item_node in call_func_node.users.keys():
67-
idx: int = get_item_node.args[1]
67+
idx: int = get_item_node.args[1] # type: ignore[assignment]
6868
output_node = output_args[idx]
6969
get_item_node._rename(output_node.name)
7070
get_item_node.meta = output_node.meta

torch/_functorch/partitioners.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def _extract_graph_with_inputs_outputs(
186186
# joint_graph.nodes).
187187
continue
188188
elif node.op == "placeholder":
189-
env[node] = InvalidNode
189+
env[node] = InvalidNode # type: ignore[assignment]
190190
elif node.op == "call_function":
191191
all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
192192
all_args = [
@@ -195,7 +195,7 @@ def _extract_graph_with_inputs_outputs(
195195
if isinstance(x, fx.Node)
196196
]
197197
if any(all_args):
198-
env[node] = InvalidNode
198+
env[node] = InvalidNode # type: ignore[assignment]
199199
continue
200200
env[node] = new_graph.node_copy(node, lambda x: env[x])
201201
elif node.op == "get_attr":

torch/_higher_order_ops/flex_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def trace_flex_attention_backward(
817817
)
818818
assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
819819
block_mask = block_mask[:-1] + (mask_graph,)
820-
proxy_mode.tracer.root.register_module("fw_graph", fw_graph)
820+
proxy_mode.tracer.root.register_module("fw_graph", fw_graph) # type: ignore[arg-type]
821821
proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
822822
proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
823823
node_args = (

torch/_inductor/constant_folding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
222222
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
223223
self.node_replacements[node] = tensor
224224

225-
def run(self) -> Any:
225+
def run(self) -> Any: # type: ignore[override]
226226
env: Dict[torch.fx.Node, Any] = {}
227227
self.insert_placerholder_values(env)
228228
return super().run(initial_env=env)

torch/_inductor/debug.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def func1(*args: Any) -> int:
144144
kwargs = {}
145145
if hasattr(snode, "get_device"):
146146
kwargs = {"device": snode.get_device()}
147-
fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)
147+
fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) # type: ignore[arg-type]
148148

149149
def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
150150
if isinstance(snode, FusedSchedulerNode):

torch/_inductor/fx_passes/binary_folding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,17 +154,17 @@ def _check_conv_and_broadcast_op(conv_node, other):
154154
return False
155155
if isinstance(other, torch.fx.Node) and other.op == "get_attr":
156156
other_meta_value = other.meta.get("val")
157-
if not other_meta_value.is_floating_point():
157+
if not other_meta_value.is_floating_point(): # type: ignore[union-attr]
158158
return False
159159
if (
160-
torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)
160+
torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr]
161161
!= weight_meta_value.dtype
162162
):
163163
if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
164164
return False
165165

166166
if (
167-
other_meta_value.dtype != torch.float
167+
other_meta_value.dtype != torch.float # type: ignore[union-attr]
168168
and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
169169
):
170170
return False

torch/_inductor/fx_passes/efficient_conv_bn_eval.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
213213
new_node = graph.create_node(
214214
op="call_function",
215215
target=efficient_conv_bn_eval_decomposed,
216-
args=args,
216+
args=args, # type: ignore[arg-type]
217217
name="efficient_conv_bn_eval",
218218
)
219219

@@ -223,7 +223,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
223223
# take care of the deletion order:
224224
# delete bn_node first, and then conv_node
225225
graph.erase_node(bn_node)
226-
graph.erase_node(conv_node)
226+
graph.erase_node(conv_node) # type: ignore[arg-type]
227227

228228
return
229229

@@ -304,7 +304,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
304304
new_node = graph.create_node(
305305
op="call_function",
306306
target=efficient_conv_bn_eval_decomposed,
307-
args=args,
307+
args=args, # type: ignore[arg-type]
308308
name="efficient_conv_bn_eval",
309309
)
310310

@@ -314,7 +314,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
314314
# take care of the deletion order:
315315
# delete bn_node first, and then conv_node
316316
graph.erase_node(bn_node)
317-
graph.erase_node(conv_node)
317+
graph.erase_node(conv_node) # type: ignore[arg-type]
318318

319319
return
320320

@@ -373,7 +373,7 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
373373
# Find a pair of conv and bn computation nodes to optimize.
374374
counters["inductor"]["efficient_conv_bn_eval"] += 1
375375

376-
with graph.inserting_before(conv_node):
376+
with graph.inserting_before(conv_node): # type: ignore[arg-type]
377377
# create `get_attr` node to access modules
378378
# note that we directly call `create_node` to fill the `name`
379379
# argument. `graph.get_attr` and
@@ -403,4 +403,4 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
403403
# take care of the deletion order:
404404
# delete bn_node first, and then conv_node
405405
graph.erase_node(bn_node)
406-
graph.erase_node(conv_node)
406+
graph.erase_node(conv_node) # type: ignore[arg-type]

torch/_inductor/fx_passes/freezing_patterns.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,5 +223,5 @@ def unnecessary_dtype_convert(match: Match, **kwargs):
223223
"""Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
224224
graph = match.graph
225225
node = match.output_node()
226-
node.replace_all_uses_with(node.args[0])
226+
node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type]
227227
graph.erase_node(node)

torch/_inductor/fx_passes/joint_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ def pointless_view(match: Match, arg, size):
511511
node = match.output_node()
512512
arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr]
513513
if size == arg_size:
514-
node.replace_all_uses_with(node.args[0])
514+
node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type]
515515
match.erase_nodes(graph)
516516

517517

torch/_inductor/fx_passes/post_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ def is_index_put_and_requires_h2d_sync_for_gpu_value(node):
10331033
# if the value we are putting is a cpu scalar.
10341034
# Therefore, when inductor sees an index_put_ with byte tensor indices,
10351035
# it should *not* convert the cpu scalar value into a gpu tensor.
1036-
args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)
1036+
args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs) # type: ignore[misc]
10371037
any_byte_bool_indices = False
10381038
indices = args_[1]
10391039
for i in indices:

torch/_inductor/fx_passes/quantization.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,14 +1887,14 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
18871887
graph.erase_node(conv_node)
18881888
# Erase the dequant pattern
18891889
if dtype == torch.bfloat16:
1890-
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined]
1891-
graph.erase_node(dequant_node)
1890+
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type]
1891+
graph.erase_node(dequant_node) # type: ignore[arg-type]
18921892
# Erase the dequant per channel pattern
18931893
if clone_node is not None:
1894-
graph.erase_node(clone_node)
1894+
graph.erase_node(clone_node) # type: ignore[arg-type]
18951895
if dtype == torch.bfloat16:
1896-
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
1897-
graph.erase_node(dequant_per_channel)
1896+
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type]
1897+
graph.erase_node(dequant_per_channel) # type: ignore[arg-type]
18981898
counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
18991899
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
19001900
match.nodes
@@ -2581,8 +2581,8 @@ def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
25812581

25822582
new_args = map_arg(new_quant_node.args, maybe_replace_node)
25832583
new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
2584-
new_quant_node.args = new_args
2585-
new_quant_node.kwargs = new_kwargs
2584+
new_quant_node.args = new_args # type: ignore[assignment]
2585+
new_quant_node.kwargs = new_kwargs # type: ignore[assignment]
25862586
graph_module.graph.erase_node(quant_node)
25872587

25882588
graph_module.graph.lint()

torch/_inductor/fx_passes/reinplace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def can_fuse():
293293
return
294294

295295
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
296-
with graph.inserting_before(src):
296+
with graph.inserting_before(src): # type: ignore[arg-type]
297297
new_node = graph_call_function(
298298
graph,
299299
_generalized_scatter,
@@ -316,7 +316,7 @@ def can_fuse():
316316
handle_views(new_src)
317317
src.replace_all_uses_with(new_src) # type: ignore[union-attr]
318318

319-
graph.erase_node(src)
319+
graph.erase_node(src) # type: ignore[arg-type]
320320

321321
for node in graph.nodes:
322322
if _is_view_op(node.target):

torch/_inductor/fx_passes/split_cat.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def normalize_split_base(
193193
new_split_node = graph.call_function(
194194
torch.split,
195195
args=new_args,
196-
kwargs=new_kwargs,
196+
kwargs=new_kwargs, # type: ignore[arg-type]
197197
)
198198
split_node.replace_all_uses_with(new_split_node)
199199
new_split_node.meta.update(split_node.meta)
@@ -378,7 +378,7 @@ def normalize_stack_default(match: Match, *args, **kwargs):
378378

379379
with graph.inserting_after(node):
380380
new_node = graph.call_function(
381-
node.target,
381+
node.target, # type: ignore[arg-type]
382382
args=(tensors,),
383383
kwargs={"dim": dim},
384384
)
@@ -529,7 +529,7 @@ def merge_splits(
529529

530530
to_remove = []
531531

532-
with graph.inserting_before(first_split):
532+
with graph.inserting_before(first_split): # type: ignore[arg-type]
533533
# Add the new split node
534534
new_split = graph.call_function(
535535
torch.split,
@@ -1535,7 +1535,7 @@ def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
15351535
# case 1: the cat uses all getitems from the split
15361536
if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type]
15371537
# replace the users of the cat node to be the input of the split node
1538-
cat_user.replace_all_uses_with(split_node.args[0])
1538+
cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type]
15391539
# remove the cat node
15401540
graph.erase_node(cat_user)
15411541
counters["inductor"]["mutate_cat_pass"] += 1
@@ -1947,7 +1947,7 @@ def remove_split_unbind_children(graph: torch.fx.Graph, inputs: List[torch.fx.No
19471947
# check the split node to remove if it has no users
19481948
for node in nodes:
19491949
if len(node.users.keys()) == 0: # type: ignore[union-attr]
1950-
graph.erase_node(node)
1950+
graph.erase_node(node) # type: ignore[arg-type]
19511951

19521952

19531953
# ############pattern to be optimized is#########

torch/_inductor/graph.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ def get_numel(self, buffer_name: str) -> Union[int, Expr]:
766766
return self.graph_inputs[buffer_name].get_numel()
767767
raise KeyError(f"could not find {buffer_name}")
768768

769-
def run(self, *args: Any) -> Any:
769+
def run(self, *args: Any) -> Any: # type: ignore[override]
770770
with dynamo_timed("GraphLowering.run"):
771771
return super().run(*args)
772772

@@ -911,9 +911,9 @@ def constant_name(self, name: str, device_override: Optional[torch.device]) -> s
911911
)
912912

913913
def placeholder(
914-
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
914+
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
915915
) -> Union[Expr, TensorBox, None]:
916-
example = super().placeholder(target, args, kwargs)
916+
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
917917
self.graph_input_names.append(target)
918918
if isinstance(example, SymTypes):
919919
expr = example.node.expr
@@ -968,7 +968,7 @@ def placeholder(
968968
self.aligned_inputs.add(target)
969969
return tensor
970970

971-
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg]
971+
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override]
972972
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
973973
return super().call_function(target, args, kwargs)
974974

@@ -1023,7 +1023,7 @@ def can_inline_constant(t: torch.Tensor) -> bool:
10231023
return len(t.shape) == 1 and t.shape[0] <= 8
10241024

10251025
def get_attr(
1026-
self, target: str, args: Tuple[()], kwargs: Dict[str, object]
1026+
self, target: str, args: Tuple[()], kwargs: Dict[str, object] # type: ignore[override]
10271027
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
10281028
# this is a constant
10291029
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
@@ -1063,9 +1063,9 @@ def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
10631063
raise AssertionError
10641064

10651065
def output(
1066-
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
1066+
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
10671067
) -> None:
1068-
result = super().output(target, args, kwargs)
1068+
result = super().output(target, args, kwargs) # type: ignore[arg-type]
10691069
if not isinstance(result, (tuple, list)):
10701070
# nested subgraphs can have singleton outputs
10711071
result = (result,)

0 commit comments

Comments
 (0)