
Commit 47facc8

Mark fp8 buffers as static (#225)

Authored by drisspg, committed by facebook-github-bot

Summary: Thank you eellison, based off of this repro: #119 (comment). Marking the individual buffers allows for CUDA graphs to be used.

Pull Request resolved: #225
Reviewed By: awgu
Differential Revision: D54178086
Pulled By: drisspg
fbshipit-source-id: 8797045b2a88825601b0fe7c8cadc03f557af96e

1 parent: b508920
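
For context, the change follows the pattern sketched below. This is a minimal, self-contained illustration, not code from this repository: the `ScaledLinear` module and `sync` function are hypothetical stand-ins for `Float8Linear` and the amax/scale sync step, and a CUDA device is assumed. The idea is that inductor's "reduce-overhead" mode skips CUDA graph capture when a compiled function mutates its inputs; marking module buffers as static addresses lets those in-place updates be treated as writes to stable storage.

import torch
import torch.nn as nn


class ScaledLinear(nn.Module):
    """Hypothetical stand-in for Float8Linear: holds a buffer that a
    compiled function updates in place (like the fp8 amax/scale buffers)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)
        self.register_buffer("amax", torch.zeros(1))

    def forward(self, x):
        return self.linear(x)


model = ScaledLinear().to("cuda")

# The pattern from this commit: while running eagerly (not during tracing),
# pin each buffer's address so the cudagraphs backend can treat in-place
# buffer updates as writes to static inputs instead of skipping capture.
if not torch._dynamo.is_compiling():
    for buf in model.buffers():
        torch._dynamo.mark_static_address(buf, guard=True)


@torch.compile(mode="reduce-overhead", fullgraph=True)
def sync(m, x):
    # In-place mutation of a module buffer, mirroring the amax/scale sync.
    m.amax.copy_(x.abs().amax())


sync(model, torch.randn(16, 16, device="cuda"))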

File tree

2 files changed: +47 −4 lines

float8_experimental/float8_linear_utils.py
test/test_compile.py

float8_experimental/float8_linear_utils.py

Lines changed: 6 additions & 4 deletions

@@ -148,7 +148,10 @@ def get_float8_layers(model: torch.nn.Module):
 
     # Get all fp8 layers and tensors
     fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)]
-
+    if not torch._dynamo.is_compiling():
+        for layer in fp8_layers:
+            for buf in layer.buffers():
+                torch._dynamo.mark_static_address(buf, guard=True)
     return fp8_layers
 
 
@@ -290,7 +293,7 @@ def inner_func():
             fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
         )
 
-        # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
+        # Iterate through the layers and update the scales
        for idx, child in enumerate(fp8_layers):
            child.fp8_scale_x.copy_(new_x_scales[idx])
            child.fp8_scale_w.copy_(new_w_scales[idx])
@@ -301,6 +304,5 @@ def inner_func():
     inner_func()
 
     for child in fp8_layers:
-        # 4. set a flag to signal amaxes/scales are ready
-        # We only update the flag if we know it will be checked by the modules
+        # Set a flag to signal amaxes/scales are ready
         child.amax_and_scale_synced = True
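
Because of the `torch._dynamo.is_compiling()` guard, the buffers are pinned only when `get_float8_layers` runs eagerly. The intended call pattern, which the new test below exercises, is roughly the following sketch (assumes a CUDA device and the helpers shown in these diffs):

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    get_float8_layers,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).to("cuda")
swap_linear_with_float8_linear(model, Float8Linear)

# Eager call: each fp8 buffer is marked as a static address here.
fp8_layers = get_float8_layers(model)

# Compile only the sync step; its buffer copies now mutate static inputs.
sync_func = torch.compile(
    sync_float8_amax_and_scale_history, mode="reduce-overhead", fullgraph=True
)

inpt = torch.randn(16, 16, device="cuda")
model(inpt)                   # forward pass records the amaxes
sync_func(model, fp8_layers)  # cudagraph-friendly in-place scale updates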

test/test_compile.py

Lines changed: 41 additions & 0 deletions

@@ -5,14 +5,17 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import random
+import sys
 import unittest
+from io import StringIO
 
 import pytest
 
 import torch
 import torch.nn as nn
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
+    get_float8_layers,
     get_float8_linear,
     LinearType,
     swap_linear_with_float8_linear,
@@ -218,5 +221,43 @@ def test_sync_amax_func():
     assert cnts.frame_count == 1, "Compiled graph should have 1 frame!"
 
 
+class capture_stderr(list):
+    """
+    Replace sys.stderr with a temporary StringIO
+    """
+
+    def __enter__(self):
+        self.sys_stderr = sys.stderr
+        self.stringio = StringIO()
+        sys.stderr = self.stringio
+        return self
+
+    def __exit__(self, *args):
+        self.append(str(self.stringio.getvalue()))
+        del self.stringio
+        sys.stderr = self.sys_stderr
+
+
+@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available")
+def test_sync_amax_func_cuda_graph_success():
+    torch._dynamo.reset()
+    with capture_stderr() as stderr:
+        my_module = nn.Sequential(
+            nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
+        ).to("cuda")
+        swap_linear_with_float8_linear(my_module, Float8Linear)
+        inpt = torch.randn(
+            16, 16, device="cuda", dtype=torch.float32, requires_grad=True
+        )
+        sync_func = torch.compile(
+            sync_float8_amax_and_scale_history, mode="reduce-overhead", fullgraph=True
+        )
+        fp8_layers = get_float8_layers(my_module)
+        my_module(inpt)
+        sync_func(my_module, fp8_layers)
+
+    assert "skipping cudagraphs due to mutaton on input" not in stderr[0]
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
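
As a standalone illustration (not part of the commit), the `capture_stderr` helper added above can be used like this; the class body is copied verbatim from test/test_compile.py:

import sys
from io import StringIO


class capture_stderr(list):
    """Replace sys.stderr with a temporary StringIO."""

    def __enter__(self):
        self.sys_stderr = sys.stderr
        self.stringio = StringIO()
        sys.stderr = self.stringio
        return self

    def __exit__(self, *args):
        self.append(str(self.stringio.getvalue()))
        del self.stringio
        sys.stderr = self.sys_stderr


with capture_stderr() as captured:
    # Anything written to stderr inside the block lands in the StringIO;
    # the text is appended to the list on exit.
    print("skipping cudagraphs due to ...", file=sys.stderr)

assert "skipping cudagraphs" in captured[0]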
