This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 1ef4605

Author: Andrew Gu

Added compile tests for flat-parameter FSDP

ghstack-source-id: b87ba69
Pull Request resolved: #215

1 parent: c9e9fef

File tree: 3 files changed, +359 −0 lines

float8_experimental/float8_linear.py (2 additions, 0 deletions)

@@ -267,6 +267,8 @@ def float8_pre_forward(self, x):
             self.is_amax_initialized
             and (not self.amax_and_scale_synced)
             and torch.is_grad_enabled()
+            # Skip if running in backward from activation checkpointing
+            and torch._C._current_graph_task_id() == -1
         ):
             raise AssertionError(
                 "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"

test/test_everything.sh (1 addition, 0 deletions)

@@ -9,5 +9,6 @@ pytest test/test_compile.py
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_tp.sh
+pytest test/test_fsdp/test_flat_param_fsdp_compile.py

 echo "all tests successful"
test/test_fsdp/test_flat_param_fsdp_compile.py (new file, 356 additions, 0 deletions)

import contextlib
import functools
from typing import List, Optional, Type

import torch
import torch._dynamo.testing
import torch.distributed as dist
from float8_experimental import config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
    TransformerBlock,
)
from torch.testing._internal.distributed.fake_pg import FakeStore

# Syntactic sugar: require `use_orig_params=True` for compile
FSDP = functools.partial(FSDP, use_orig_params=True)
# Increase cache size limit for running all unit tests together
torch._dynamo.config.cache_size_limit = 16


class TestFloat8CompileCommon:
    def _init_transformer_with_fp8(
        self,
        module_cls: Type,
        *,
        checkpoint_activations: bool = False,
        use_activation_hooks: Optional[bool] = None,
    ):
        torch.manual_seed(42)
        args = ModelArgs(
            n_layers=3,
            dim=768,
            n_heads=12,
            dropout_p=0.0,
            weight_tying=False,
            checkpoint_activations=checkpoint_activations,
        )
        module = Transformer(args)
        # Only dynamic linear supports activation hooks
        use_hooks = use_activation_hooks or (module_cls is Float8DynamicLinear)
        return swap_linear_with_float8_linear(
            module, module_cls, emulate=True, use_activation_hooks=use_hooks
        )

    @contextlib.contextmanager
    def enable_amax_init(self, enable: bool):
        prev_value = config.enable_amax_init
        config.enable_amax_init = enable
        try:
            yield
        finally:
            config.enable_amax_init = prev_value

    @contextlib.contextmanager
    def enable_pre_and_post_forward(self, enable: bool):
        prev_value = config.enable_pre_and_post_forward
        config.enable_pre_and_post_forward = enable
        try:
            yield
        finally:
            config.enable_pre_and_post_forward = prev_value

    def apply_fsdp(self, transformer: Transformer):
        return FSDP(
            transformer,
            auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
            device_id=dist.get_rank(),
        )


class TestFloat8CompileFakePG(
    TestFloat8CompileCommon, torch._dynamo.test_case.TestCase
):
    def setUp(self):
        super().setUp()
        fake_store = FakeStore()
        dist.init_process_group(
            "fake", store=fake_store, rank=0, world_size=self.world_size
        )

    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    def test_compile_submodule_dynamic(self):
        local_inp = torch.randint(0, 16, (1, 4), device="cuda")

        # Compile each transformer block forward
        module = self._init_transformer_with_fp8(Float8DynamicLinear)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        num_compiled_fns = 0
        for submodule in module.modules():
            if isinstance(submodule, TransformerBlock):
                submodule.forward = torch.compile(submodule.forward, backend=cnt)
                num_compiled_fns += 1
        module = self.apply_fsdp(module)
        out = module(local_inp)
        out.sum().backward()
        self.assertEqual(cnt.frame_count, num_compiled_fns)

        # Compile the output projection
        module.output.forward = torch.compile(module.output.forward, backend=cnt)
        # in float8_mm
        # assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
        with self.assertRaises(RuntimeError):
            module(local_inp)

        # Compile each transformer block module
        module = self._init_transformer_with_fp8(Float8DynamicLinear)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        num_compiled_fns = 0
        for submodule in module.modules():
            if isinstance(submodule, TransformerBlock):
                submodule.compile(backend=cnt)
                num_compiled_fns += 1
        module = self.apply_fsdp(module)
        # in float8_mm
        # assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
        with self.assertRaises(RuntimeError):
            module(local_inp)

    @skip_if_lt_x_gpu(2)
    def test_compile_root_dynamic(self):
        # Compile the root module
        module = self._init_transformer_with_fp8(Float8DynamicLinear)
        module = self.apply_fsdp(module)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        module = torch.compile(module, backend=cnt)
        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        # in forward
        # h = layer(h)
        # in float8_mm
        # assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
        with self.assertRaises(RuntimeError):
            module(local_inp)

    @skip_if_lt_x_gpu(2)
    def test_compile_submodule_delayed(self):
        local_inp = torch.randint(0, 16, (1, 4), device="cuda")

        # Compile each transformer block forward
        module = self._init_transformer_with_fp8(Float8Linear)

        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        for submodule in module.modules():
            if isinstance(submodule, TransformerBlock):
                submodule.forward = torch.compile(submodule.forward, backend=cnt)
        module = self.apply_fsdp(module)
        module(local_inp).sum().backward()
        num_float8_linears = sum(
            1 for m in module.modules() if isinstance(m, Float8Linear)
        )
        # TODO: We get one graph per `Float8Linear` in a transformer block
        # (-1 because output projection is not compiled).
        self.assertEqual(cnt.frame_count, num_float8_linears - 1)

        # Compile each transformer block forward with amax init disabled
        with self.enable_amax_init(False):
            module = self._init_transformer_with_fp8(Float8Linear)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        num_compiled_fns = 0
        for submodule in module.modules():
            if isinstance(submodule, TransformerBlock):
                submodule.forward = torch.compile(submodule.forward, backend=cnt)
                num_compiled_fns += 1
        module = self.apply_fsdp(module)
        module(local_inp).sum().backward()
        num_float8_linears = sum(
            1 for m in module.modules() if isinstance(m, Float8Linear)
        )
        # TODO: We get one graph per `Float8Linear` in a transformer block
        # (-1 because output projection is not compiled).
        self.assertEqual(cnt.frame_count, num_float8_linears - 1)

        # Compile each transformer block forward with amax init disabled and
        # pre/post-forward disabled
        with self.enable_amax_init(False), self.enable_pre_and_post_forward(False):
            module = self._init_transformer_with_fp8(Float8Linear)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        num_compiled_fns = 0
        for submodule in module.modules():
            if isinstance(submodule, TransformerBlock):
                submodule.forward = torch.compile(submodule.forward, backend=cnt)
                num_compiled_fns += 1
        module = self.apply_fsdp(module)
        module(local_inp).sum().backward()
        self.assertEqual(cnt.frame_count, num_compiled_fns)

        # Compile each transformer block module with amax init disabled and
        # pre/post-forward disabled
        with self.enable_amax_init(False), self.enable_pre_and_post_forward(False):
            module = self._init_transformer_with_fp8(Float8Linear)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        for submodule in module.modules():
            if isinstance(submodule, TransformerBlock):
                submodule.compile(backend=cnt)
        module = self.apply_fsdp(module)
        module(local_inp).sum().backward()
        num_float8_linears = sum(
            1 for m in module.modules() if isinstance(m, Float8Linear)
        )
        # TODO: We get one graph per `Float8Linear` in a transformer block
        # (-1 because output projection is not compiled).
        self.assertEqual(cnt.frame_count, num_float8_linears - 1)

    @skip_if_lt_x_gpu(2)
    def test_compile_root_delayed(self):
        with self.enable_amax_init(False):
            module = self._init_transformer_with_fp8(Float8Linear)
        module = self.apply_fsdp(module)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        module = torch.compile(module, backend=cnt)
        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        out = module(local_inp)
        out.sum().backward()
        num_float8_linears = sum(
            1 for m in module.modules() if isinstance(m, Float8Linear)
        )
        self.assertEqual(cnt.frame_count, num_float8_linears)  # TODO!

        with self.enable_amax_init(False), self.enable_pre_and_post_forward(False):
            module = self._init_transformer_with_fp8(Float8Linear)
        module = self.apply_fsdp(module)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        module = torch.compile(module, backend=cnt)
        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        out = module(local_inp)
        out.sum().backward()
        num_float8_linears = sum(
            1 for m in module.modules() if isinstance(m, Float8Linear)
        )
        self.assertEqual(cnt.frame_count, num_float8_linears)  # TODO!


class TestFloat8CompileNCCLPG(TestFloat8CompileCommon, FSDPTest):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 2)

    def _test_parity(
        self,
        ref_model: torch.nn.Module,
        ref_optim: torch.optim.Optimizer,
        fsdp_model: torch.nn.Module,
        fsdp_optim: torch.optim.Optimizer,
        local_inp: torch.Tensor,
        module_cls: Type,
    ):
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):
                optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(model(local_inp).sum())
                losses[-1].backward()
                if model is ref_model:  # manual data parallelism
                    for param in model.parameters():
                        dist.all_reduce(param.grad)
                        param.grad.div_(self.world_size)
                if module_cls is Float8Linear:
                    sync_float8_amax_and_scale_history(model)
                optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_transformer_parity_delayed_no_mp(self):
        module_cls, backend = Float8Linear, "inductor"
        with self.enable_amax_init(False), self.enable_pre_and_post_forward(False):
            model = self._init_transformer_with_fp8(module_cls)
        with self.enable_amax_init(False):
            ref_model = self._init_transformer_with_fp8(module_cls).cuda()
        # NOTE: We compile the ref model in the same way as the FSDP model for
        # numeric parity. Compiling the full ref model or running the full ref
        # model in eager both show differences 5+ iterations in.
        for module in ref_model.modules():
            if isinstance(module, TransformerBlock):
                module.forward = torch.compile(module.forward, backend=backend)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
        num_compiled_fns = 0
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                # TODO: For `Float8Linear`, compiling the module gives one
                # graph per `Float8Linear` instead of per `TransformerBlock`.
                module.forward = torch.compile(module.forward, backend=cnt)
                num_compiled_fns += 1
        fsdp_model = self.apply_fsdp(model)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        self._test_parity(
            ref_model, ref_optim, fsdp_model, fsdp_optim, local_inp, module_cls
        )
        self.assertEqual(cnt.frame_count, num_compiled_fns)

    @skip_if_lt_x_gpu(2)
    def test_transformer_parity_dynamic_no_mp(self):
        self.run_subtests(
            {"use_activation_hooks": [False, True]},
            self._test_transformer_parity_dynamic_no_mp,
        )

    def _test_transformer_parity_dynamic_no_mp(self, use_activation_hooks: bool):
        module_cls, backend = Float8DynamicLinear, "inductor"
        model = self._init_transformer_with_fp8(
            module_cls, use_activation_hooks=use_activation_hooks
        )
        ref_model = self._init_transformer_with_fp8(
            module_cls, use_activation_hooks=use_activation_hooks
        ).cuda()
        # NOTE: We compile the ref model in the same way as the FSDP model for
        # numeric parity. Compiling the full ref model or running the full ref
        # model in eager both show differences 5+ iterations in.
        for module in ref_model.modules():
            if isinstance(module, TransformerBlock):
                module.forward = torch.compile(module.forward, backend=backend)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
        num_compiled_fns = 0
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                # TODO: For `Float8DynamicLinear`, compiling the module errors
                # for both using and not using activation hooks.
                # in float8_mm
                # assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
                # module.compile(backend=cnt)
                module.forward = torch.compile(module.forward, backend=cnt)
                num_compiled_fns += 1
        fsdp_model = self.apply_fsdp(model)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        self._test_parity(
            ref_model, ref_optim, fsdp_model, fsdp_optim, local_inp, module_cls
        )
        self.assertEqual(cnt.frame_count, num_compiled_fns)


if __name__ == "__main__":
    run_tests()
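The frame-count assertions in these tests rely on torch._dynamo.testing.CompileCounterWithBackend, which wraps a real backend ("aot_eager" or "inductor" above) and counts how many graphs Dynamo hands to it, so a graph break shows up as an extra frame. A minimal standalone sketch (assuming only stock PyTorch, not code from this commit):

import torch
import torch._dynamo.testing

# Wrap the "eager" backend with a counter; the tests above do the same with
# "aot_eager"/"inductor" and compare cnt.frame_count against the number of
# functions they compiled to detect graph breaks.
cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")

@torch.compile(backend=cnt)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(4))
print(cnt.frame_count)  # 1 if f compiles into a single graph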
