|
| 1 | +""" |
| 2 | +(beta) Compiling the optimizer with torch.compile |
| 3 | +========================================================================================== |
| 4 | +
|
| 5 | +
|
| 6 | +**Author:** `Michael Lazos <https://github.com/mlazos>`_ |
| 7 | +""" |
| 8 | + |
| 9 | +###################################################################### |
| 10 | +# Summary |
| 11 | +# ~~~~~~~~ |
| 12 | +# |
| 13 | +# In this tutorial, we will apply ``torch.compile`` to the optimizer to observe |
| 14 | +# the GPU performance improvement. |
| 15 | +# |
| 16 | +# .. note:: |
| 17 | +# |
| 18 | +# This tutorial requires PyTorch 2.2.0 or later. |
| 19 | +# |
| 20 | + |
| 21 | + |
| 22 | +###################################################################### |
| 23 | +# Model Setup |
| 24 | +# ~~~~~~~~~~~~~~~~~~~~~ |
| 25 | +# For this example we'll use a simple sequence of linear layers. |
| 26 | +# Since we are only benchmarking the optimizer, choice of model doesn't matter |
| 27 | +# because optimizer performance is a function of the number of parameters. |
| 28 | +# |
| 29 | +# Depending on what machine you are using, your exact results may vary. |
| 30 | + |
| 31 | +import torch |
| 32 | + |
# Build a stack of 10 bias-free linear layers. The choice of model is
# irrelevant here -- optimizer cost is a function of parameter count only.
# Fall back to CPU so the script still runs on machines without a GPU
# (the speedup being demonstrated is GPU-focused, so expect smaller gains).
device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device=device) for _ in range(10)]
)
# Named `inp` rather than `input` to avoid shadowing the builtin.
inp = torch.rand(1024, device=device)
output = model(inp)
# Run backward once so every parameter has a populated .grad, giving
# optimizer.step() real work to do in the benchmark below.
output.sum().backward()
| 39 | + |
| 40 | +############################################################################# |
| 41 | +# Setting up and running the optimizer benchmark |
| 42 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 43 | +# In this example, we'll use the Adam optimizer |
| 44 | +# and create a helper function to wrap the step() |
| 45 | +# in torch.compile() |
| 46 | + |
# The optimizer whose step() we will benchmark in eager vs. compiled mode.
opt = torch.optim.Adam(model.parameters(), lr=0.01)


# Wrap the optimizer step in torch.compile (requires PyTorch 2.2.0+, per
# the note above). `fn` closes over `opt`, so calling fn() runs the
# compiled optimizer step on the model's parameters.
@torch.compile()
def fn():
    opt.step()
| 55 | +# Let's define a helpful benchmarking function: |
| 56 | +import torch.utils.benchmark as benchmark |
| 57 | + |
| 58 | + |
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Return the mean runtime of ``f(*args, **kwargs)`` in microseconds.

    Uses ``torch.utils.benchmark.Timer.blocked_autorange`` to pick an
    iteration count automatically and average over several measurement
    blocks.
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    measurement = timer.blocked_autorange()
    return measurement.mean * 1e6
| 64 | + |
| 65 | + |
| 66 | +# Warmup runs to compile the function |
# Call the compiled function a few times first so compilation happens
# outside the timed region.
for _ in range(5):
    fn()

eager_us = benchmark_torch_function_in_microseconds(opt.step)
compiled_us = benchmark_torch_function_in_microseconds(fn)
print(f"eager runtime: {eager_us}us")
print(f"compiled runtime: {compiled_us}us")
0 commit comments