"""
(beta) Running the compiled optimizer with an LR Scheduler
+ ============================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""
#####################################################################
# Setting up and running the compiled optimizer with LR Scheduler
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ #
# In this section, we'll use the Adam optimizer with LinearLR Scheduler
# and create a helper function to wrap the ``step()`` call for each of them
# in ``torch.compile()``.
# ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.
- # exit cleanly if we are on a device that doesn't support torch.compile
+ # exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
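######################################################################
# The model/optimizer setup that the hunk below steps through is not shown in
# this diff. As a rough sketch only (the model, sizes, and names here are
# illustrative assumptions, not necessarily the file's exact code), the
# Adam + LinearLR + ``torch.compile`` wiring could look like this:
#
# .. code-block:: python
#
#    import torch
#
#    # A small CUDA model plus one backward pass so the optimizer has gradients to apply.
#    model = torch.nn.Sequential(
#        *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
#    )
#    inp = torch.rand(1024, device="cuda")
#    model(inp).sum().backward()
#
#    # The LR is assumed to be wrapped in a tensor so that scheduler updates
#    # mutate tensor data instead of a Python constant that the compiled step
#    # guards on (see the non-tensor LR discussion later in this file).
#    opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
#    sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
#
#    # Helper wrapping both ``step()`` calls in ``torch.compile``, as described above.
#    @torch.compile(fullgraph=False)
#    def fn():
#        opt.step()
#        sched.step()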
@@ -70,14 +72,6 @@ def fn():
fn()
print(opt.param_groups[0]["lr"])

- ########################################################################
- # Sample Output:
- #
- # >> tensor(0.0047)
- # >> tensor(0.0060)
- # >> tensor(0.0073)
- # >> tensor(0.0087)
- # >> tensor(0.0100)

######################################################################
# Extension: What happens with a non-tensor LR?
@@ -106,28 +100,30 @@ def fn():
######################################################################
# Sample Output:
+ #
+ # .. code-block:: bash
#
- # >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
- # >> triggered by the following guard failure(s):
- # >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
- # >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
- # >> triggered by the following guard failure(s):
- # >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
- # >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
- # >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
- # >> triggered by the following guard failure(s):
- # >> - L['self'].param_groups[0]['lr'] == 0.006000000000000001
- # >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
- # >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
- # >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
- # >> triggered by the following guard failure(s):
- # >> - L['self'].param_groups[0]['lr'] == 0.007333333333333335
- # >> - L['self'].param_groups[0]['lr'] == 0.006000000000000001
- # >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
- # >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
+ #    >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
+ #    >> triggered by the following guard failure(s):
+ #    >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
+ #    >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
+ #    >> triggered by the following guard failure(s):
+ #    >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
+ #    >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
+ #    >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
+ #    >> triggered by the following guard failure(s):
+ #    >> - L['self'].param_groups[0]['lr'] == 0.006000000000000001
+ #    >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
+ #    >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
+ #    >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
+ #    >> triggered by the following guard failure(s):
+ #    >> - L['self'].param_groups[0]['lr'] == 0.007333333333333335
+ #    >> - L['self'].param_groups[0]['lr'] == 0.006000000000000001
+ #    >> - L['self'].param_groups[0]['lr'] == 0.004666666666666667
+ #    >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
#
- # With this example, we can see that we recompile the optimizer 4 additional
- # due to the guard failure on the 'lr' in param_groups[0]
+ # With this example, we can see that we recompile the optimizer 4 additional times
+ # due to the guard failure on the 'lr' in param_groups[0].
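######################################################################
# The code for this extension section is elided by the hunk above. As a
# minimal sketch of the non-tensor variant (reusing the illustrative ``model``
# from the earlier sketch; names and values are assumptions, not the file's
# exact code), it could look like this:
#
# .. code-block:: python
#
#    # Plain float LR: each scheduler update changes a value that the compiled
#    # step() has guarded on, so every change forces a recompile.
#    opt = torch.optim.Adam(model.parameters(), lr=0.01)
#    sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
#
#    @torch.compile(fullgraph=False)
#    def fn():
#        opt.step()
#        sched.step()
#
#    # One way to surface recompile/guard-failure messages like the sample output above.
#    torch._logging.set_logs(recompiles=True)
#
#    for _ in range(5):
#        fn()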
######################################################################
# Conclusion
@@ -139,5 +135,6 @@ def fn():
# with a LinearLR scheduler to demonstrate the LR changing across iterations.
#
# See also:
- # * tutorial on the compiled optimizer - `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`_
- # * deeper technical details on the compiled optimizer see `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`_
+ #
+ # * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an introduction to the compiled optimizer.
+ # * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.