Commit 3c8c7e2

jansel authored and pytorchmergebot committed
[dynamo] Tweak naming for module hook bw_state (pytorch#121609)
Some minor changes not related to the other PRs in the stack.

Pull Request resolved: pytorch#121609
Approved by: https://github.com/yanboliang
1 parent 7a68e0a commit 3c8c7e2

3 files changed: +12, -6 lines

torch/_dynamo/output_graph.py
Lines changed: 2 additions & 2 deletions

@@ -433,8 +433,8 @@ def __init__(
         self.backward_state_proxy: Optional[torch.fx.Proxy] = None
         self.backward_state_var: Optional[str] = None

-    def add_backward_state_hook(self, hook: VariableTracker):
-        name = f"hook{len(self.backward_state)}"
+    def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"):
+        name = f"{prefix}{len(self.backward_state)}"
         assert name not in self.backward_state
         self.backward_state[name] = hook
         return name, self.get_backward_state_proxy()
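
For orientation, a minimal standalone sketch of the naming scheme this hunk introduces. The bare dict and the lambda hooks below are illustrative only; in the PR this logic is a method on torch._dynamo.output_graph.OutputGraph that also returns a backward-state proxy. The "mod" prefix shown here is the one passed at the call site in the next file.

# Sketch only, not the real OutputGraph method.
backward_state = {}

def add_backward_state_hook(hook, prefix="hook"):
    # Key is the prefix plus the number of entries registered so far.
    name = f"{prefix}{len(backward_state)}"
    assert name not in backward_state
    backward_state[name] = hook
    return name

print(add_backward_state_hook(lambda grad: grad, "mod"))  # mod0  (module entry)
print(add_backward_state_hook(lambda grad: grad))         # hook1 (user pre-hooks)
print(add_backward_state_hook(lambda grad: grad))         # hook2 (user hooks)

With the configurable prefix, module entries get "mod*" keys while plain user hooks keep the old "hook*" naming, which appears to be the point of the tweak.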

torch/_dynamo/variables/distributed.py
Lines changed: 1 addition & 1 deletion

@@ -334,7 +334,7 @@ def _in_graph_bw_hooks(bw_state: BackwardState):
                ),
            )

-        module_name, bw_state_proxy = tx.output.add_backward_state_hook(module)
+        module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod")
         user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
         user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)
         proxy = tx.output.create_proxy(

torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
Lines changed: 9 additions & 3 deletions

@@ -10,7 +10,7 @@
 import logging
 from contextlib import nullcontext
 from functools import wraps
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Sequence

 import torch
 import torch.utils.dlpack
@@ -175,6 +175,12 @@ def rng_functionalization_wrapper(args):
     return compiled_fn


+def _output_node(gm: torch.fx.GraphModule) -> torch.fx.Node:
+    """Return the output node of a graph"""
+    # reversed() since we expect output at end of graph
+    return next(n for n in reversed(gm.graph.nodes) if n.op == "output")
+
+
 def aot_dispatch_autograd(
     flat_fn,
     flat_args: List[Any],
@@ -295,8 +301,8 @@ def aot_dispatch_autograd(
     # and we will end up with a zero grad at x.
     # If we later backprop through the second output, this will also require backprop'ing through x.
     # Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time.
-    _indices_of_inps_to_detach = []
-    bw_outs = next(n for n in bw_module.graph.nodes if n.op == "output").args[0]
+    _indices_of_inps_to_detach: List[int] = []
+    bw_outs: Sequence[torch.fx.Node] = _output_node(bw_module).args[0]  # type: ignore[assignment]

     # TODO: we should apply the below "detach inputs if their gradients are statically known to be None"
     # optimization even if we have subclass inputs/outputs (we do not handle this today).
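
A usage sketch of the new _output_node helper; the toy traced function below is illustrative and not part of the PR. Tracing with torch.fx produces a GraphModule whose last node is the graph output, which the helper finds with the same reversed() scan used in the diff.

import torch
import torch.fx


def _output_node(gm: torch.fx.GraphModule) -> torch.fx.Node:
    """Return the output node of a graph"""
    # reversed() since we expect output at end of graph
    return next(n for n in reversed(gm.graph.nodes) if n.op == "output")


def f(x):
    return x + 1


gm = torch.fx.symbolic_trace(f)
out = _output_node(gm)
print(out.op)       # "output"
print(out.args[0])  # whatever the graph returns; here the single add node

In the last hunk, bw_outs takes args[0] of the backward module's output node, i.e. the collection of backward outputs, hence the Sequence[torch.fx.Node] annotation and the type: ignore for fx's looser typing of args.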
