
Commit dc62ced

Update compiling_optimizer_lr_scheduler.py
1 parent d759583 commit dc62ced


recipes_source/compiling_optimizer_lr_scheduler.py

Lines changed: 28 additions & 31 deletions
@@ -1,5 +1,6 @@
 """
 (beta) Running the compiled optimizer with an LR Scheduler
+============================================================
 
 **Author:** `Michael Lazos <https://github.com/mlazos>`_
 """
@@ -37,6 +38,7 @@
 #####################################################################
 # Setting up and running the compiled optimizer with LR Scheduler
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
 # In this section, we'll use the Adam optimizer with LinearLR Scheduler
 # and create a helper function to wrap the ``step()`` call for each of them
 # in ``torch.compile()``.
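
For context, the pattern this hunk's prose describes (Adam plus a LinearLR scheduler, with both ``step()`` calls wrapped in a single helper compiled with ``torch.compile``) looks roughly like the sketch below. This is only an illustration of the recipe's idea; the toy model, its size, and ``total_iters=5`` are assumptions, not necessarily the file's exact code.

.. code-block:: python

    import torch

    # Toy CUDA model so the optimizer has parameters with gradients.
    model = torch.nn.Linear(1024, 1024, bias=False, device="cuda")
    model(torch.rand(1024, device="cuda")).sum().backward()

    # Wrapping the LR in a tensor lets the compiled step reuse the same graph
    # as the scheduler updates it.
    opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
    sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

    # Helper that wraps both step() calls in torch.compile().
    @torch.compile(fullgraph=False)
    def fn():
        opt.step()
        sched.step()

    # Warmup runs compile the function; the LR changes across iterations.
    for _ in range(5):
        fn()
        print(opt.param_groups[0]["lr"])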
@@ -46,7 +48,7 @@
 # ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.
 
 
-# exit cleanly if we are on a device that doesn't support torch.compile
+# exit cleanly if we are on a device that doesn't support ``torch.compile``
 if torch.cuda.get_device_capability() < (7, 0):
     print("Exiting because torch.compile is not supported on this device.")
     import sys
@@ -70,14 +72,6 @@ def fn():
     fn()
     print(opt.param_groups[0]["lr"])
 
-########################################################################
-# Sample Output:
-#
-# >> tensor(0.0047)
-# >> tensor(0.0060)
-# >> tensor(0.0073)
-# >> tensor(0.0087)
-# >> tensor(0.0100)
 
 ######################################################################
 # Extension: What happens with a non-tensor LR?
@@ -106,28 +100,30 @@ def fn():
 
 ######################################################################
 # Sample Output:
+#
+# .. code-block:: bash
 #
-# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
-# >> triggered by the following guard failure(s):
-# >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
-# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
-# >> triggered by the following guard failure(s):
-# >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
-# >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
-# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
-# >> triggered by the following guard failure(s):
-# >> - L['self'].param_groups[0]['lr'] == 0.006000000000000001
-# >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
-# >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
-# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
-# >> triggered by the following guard failure(s):
-# >> - L['self'].param_groups[0]['lr'] == 0.007333333333333335
-# >> - L['self'].param_groups[0]['lr'] == 0.006000000000000001
-# >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
-# >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
+# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
+# >> triggered by the following guard failure(s):
+# >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
+# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
+# >> triggered by the following guard failure(s):
+# >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
+# >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
+# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
+# >> triggered by the following guard failure(s):
+# >> - L['self'].param_groups[0]['lr'] == 0.006000000000000001
+# >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
+# >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
+# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
+# >> triggered by the following guard failure(s):
+# >> - L['self'].param_groups[0]['lr'] == 0.007333333333333335
+# >> - L['self'].param_groups[0]['lr'] == 0.006000000000000001
+# >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
+# >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
 #
-# With this example, we can see that we recompile the optimizer 4 additional
-# due to the guard failure on the 'lr' in param_groups[0]
+# With this example, we can see that we recompile the optimizer 4 additional times
+# due to the guard failure on the 'lr' in param_groups[0].
 
 ######################################################################
 # Conclusion
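
The guard failures in the sample output above come from the scheduler mutating a plain Python float ``lr``, which ``torch.compile`` specializes on as a constant, so each new value forces a recompile. A rough sketch of how that situation can be reproduced follows; the model and helper name are illustrative assumptions, and ``torch._logging.set_logs(recompiles=True)`` is one way to surface such guard-failure messages.

.. code-block:: python

    import torch

    model = torch.nn.Linear(1024, 1024, bias=False, device="cuda")
    model(torch.rand(1024, device="cuda")).sum().backward()

    # Print recompile reasons so guard failures like the ones above are visible.
    torch._logging.set_logs(recompiles=True)

    # A plain float LR is baked into the compiled optimizer step as a constant,
    # so every scheduler update to it invalidates a guard and triggers a recompile.
    opt = torch.optim.Adam(model.parameters(), lr=0.01)
    sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

    @torch.compile(fullgraph=False)
    def fn():
        opt.step()
        sched.step()

    for _ in range(5):
        fn()  # expect several recompilations as the LR value changes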
@@ -139,5 +135,6 @@ def fn():
 # with a LinearLR scheduler to demonstrate the LR changing across iterations.
 #
 # See also:
-# * tutorial on the compiled optimizer - `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`_
-# * deeper technical details on the compiled optimizer see `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`_
+#
+# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an intro into the compiled optimizer.
+# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.
