This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit de10c25

Authored and committed by Andrew Gu
Added compile tests for flat-parameter FSDP
ghstack-source-id: 39228ff
Pull Request resolved: #215
Parent: 34e30d1

File tree: 3 files changed, +346 −0 lines changed

float8_experimental/float8_linear.py

Lines changed: 2 additions & 0 deletions
@@ -266,6 +266,8 @@ def float8_pre_forward(self, x):
             self.is_amax_initialized
             and (not self.amax_and_scale_synced)
             and torch.is_grad_enabled()
+            # Skip if running in backward from activation checkpointing
+            and torch._C._current_graph_task_id() == -1
         ):
             raise AssertionError(
                 "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"

test/test_everything.sh

Lines changed: 1 addition & 0 deletions
@@ -9,5 +9,6 @@ pytest test/test_compile.py
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_tp.sh
+pytest test/test_fsdp/test_flat_param_fsdp_compile.py

 echo "all tests successful"
test/test_fsdp/test_flat_param_fsdp_compile.py

Lines changed: 343 additions & 0 deletions

@@ -0,0 +1,343 @@
import contextlib
import copy
from typing import List, Type

import torch
import torch._dynamo.testing
import torch.distributed as dist
from float8_experimental import config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
    TransformerBlock,
)
from torch.testing._internal.distributed.fake_pg import FakeStore


class TestFloat8CompileCommon:
    def _init_transformer_with_fp8(
        self, module_cls: Type, checkpoint_activations: bool = False
    ):
        torch.manual_seed(42)
        args = ModelArgs(
            n_layers=3,
            dim=768,
            n_heads=12,
            dropout_p=0.0,
            weight_tying=False,
            checkpoint_activations=checkpoint_activations,
        )
        module = Transformer(args)
        # Only dynamic linear supports activation hooks
        use_hooks = module_cls is Float8DynamicLinear
        return swap_linear_with_float8_linear(
            module, module_cls, emulate=True, use_activation_hooks=use_hooks
        )

    @contextlib.contextmanager
    def enable_amax_init(self, enable: bool):
        prev_value = config.enable_amax_init
        config.enable_amax_init = enable
        try:
            yield
        finally:
            config.enable_amax_init = prev_value

    @contextlib.contextmanager
    def enable_pre_and_post_forward(self, enable: bool):
        prev_value = config.enable_pre_and_post_forward
        config.enable_pre_and_post_forward = enable
        try:
            yield
        finally:
            config.enable_pre_and_post_forward = prev_value


class TestFloat8CompileFakePG(
    TestFloat8CompileCommon, torch._dynamo.test_case.TestCase
):
    def setUp(self):
        super().setUp()
        fake_store = FakeStore()
        dist.init_process_group(
            "fake", store=fake_store, rank=0, world_size=self.world_size
        )

    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    def test_compile_submodule_dynamic(self):
        module = self._init_transformer_with_fp8(Float8DynamicLinear)

        # Compile each transformer block separately
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        num_compiled_fns = 0
        for submodule in module.modules():
            if isinstance(submodule, TransformerBlock):
                submodule.forward = torch.compile(submodule.forward, backend=cnt)
                num_compiled_fns += 1
        module = FSDP(
            module,
            auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
            use_orig_params=True,
            device_id=dist.get_rank(),
        )
        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        out = module(local_inp)
        out.sum().backward()
        self.assertEqual(cnt.frame_count, num_compiled_fns)

        # Compile the output projection
        module.output.forward = torch.compile(module.output.forward, backend=cnt)
        # in float8_mm
        # assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
        with self.assertRaises(RuntimeError):
            module(local_inp)

    @skip_if_lt_x_gpu(2)
    def test_compile_root_dynamic(self):
        module = self._init_transformer_with_fp8(Float8DynamicLinear)

        # Compile the root module
        module = FSDP(
            module,
            auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
            use_orig_params=True,
            device_id=dist.get_rank(),
        )
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        module = torch.compile(module, backend=cnt)
        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        # in forward
        # h = layer(h)
        # in float8_mm
        # assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
        with self.assertRaises(RuntimeError):
            module(local_inp)

    @skip_if_lt_x_gpu(2)
    def test_compile_submodule_delayed(self):
        module = self._init_transformer_with_fp8(Float8Linear)

        # Compile each transformer block separately
        torch._dynamo.config.cache_size_limit = 16
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        for submodule in module.modules():
            if isinstance(submodule, TransformerBlock):
                submodule.forward = torch.compile(submodule.forward, backend=cnt)
        module = FSDP(
            module,
            auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
            use_orig_params=True,
            device_id=dist.get_rank(),
        )
        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        # Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'
        # with self.assertRaises(RuntimeError):
        with self.assertRaises(RuntimeError):
            module(local_inp)

        # Compile each transformer block separately with amax init disabled
        with self.enable_amax_init(False):
            module = self._init_transformer_with_fp8(Float8Linear)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            num_compiled_fns = 0
            for submodule in module.modules():
                if isinstance(submodule, TransformerBlock):
                    submodule.forward = torch.compile(submodule.forward, backend=cnt)
                    num_compiled_fns += 1
            module = FSDP(
                module,
                auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
                use_orig_params=True,
                device_id=dist.get_rank(),
            )
            module(local_inp).sum().backward()
            self.assertEqual(cnt.frame_count, 18)  # TODO!

        # Compile each transformer block separately with amax init disabled and
        # pre/post-forward disabled
        with self.enable_amax_init(False), self.enable_pre_and_post_forward(False):
            module = self._init_transformer_with_fp8(Float8Linear)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            num_compiled_fns = 0
            for submodule in module.modules():
                if isinstance(submodule, TransformerBlock):
                    submodule.forward = torch.compile(submodule.forward, backend=cnt)
                    num_compiled_fns += 1
            module = FSDP(
                module,
                auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
                use_orig_params=True,
                device_id=dist.get_rank(),
            )
            module(local_inp).sum().backward()
            self.assertEqual(cnt.frame_count, num_compiled_fns)

    @skip_if_lt_x_gpu(2)
    def test_compile_root_delayed(self):
        with self.enable_amax_init(False):
            module = self._init_transformer_with_fp8(Float8Linear)
            module = FSDP(
                module,
                auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
                use_orig_params=True,
                device_id=dist.get_rank(),
            )
            cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            module = torch.compile(module, backend=cnt)
            local_inp = torch.randint(0, 16, (1, 4), device="cuda")
            out = module(local_inp)
            out.sum().backward()
            self.assertEqual(cnt.frame_count, 19)  # TODO!

        with self.enable_amax_init(False), self.enable_pre_and_post_forward(False):
            module = self._init_transformer_with_fp8(Float8Linear)
            module = FSDP(
                module,
                auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
                use_orig_params=True,
                device_id=dist.get_rank(),
            )
            cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            module = torch.compile(module, backend=cnt)
            local_inp = torch.randint(0, 16, (1, 4), device="cuda")
            out = module(local_inp)
            out.sum().backward()
            self.assertEqual(cnt.frame_count, 19)  # TODO!


class TestFloat8CompileNCCLPG(TestFloat8CompileCommon, FSDPTest):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    def test_transformer_parity_no_mp(self):
        """
        Test numeric parity against manual data parallelism without using
        FSDP's mixed precision.
        """
        self.run_subtests(
            {
                "module_cls": [Float8Linear, Float8DynamicLinear],
                "checkpoint_activations": [False, True],
            },
            self._test_transformer_parity_no_mp,
        )

    def _test_transformer_parity_no_mp(
        self, module_cls: Type, checkpoint_activations: bool
    ):
        model = self._init_transformer_with_fp8(module_cls, checkpoint_activations)
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        fsdp_model = FSDP(
            model,
            auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
            use_orig_params=True,
            device_id=self.rank,
        )
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        with self.enable_amax_init(False), self.enable_pre_and_post_forward(
            False
        ) if module_cls is Float8Linear else contextlib.nullcontext():
            for iter_idx in range(10):
                losses: List[torch.Tensor] = []
                for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):
                    optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                    losses.append(model(local_inp).sum())
                    losses[-1].backward()
                    if model is ref_model:
                        for param in model.parameters():
                            dist.all_reduce(param.grad)
                            param.grad.div_(self.world_size)
                    if module_cls is Float8Linear:
                        sync_float8_amax_and_scale_history(model)
                    optim.step()
                self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_transformer_parity_bf16_mp(self):
        """
        Test numeric parity against manual data parallelism using FSDP's bf16
        mixed precision.
        """
        self.run_subtests(
            {
                "module_cls": [Float8Linear, Float8DynamicLinear],
                "checkpoint_activations": [False, True],
            },
            self._test_transformer_parity_bf16_mp,
        )

    def _test_transformer_parity_bf16_mp(
        self, module_cls: Type, checkpoint_activations: bool
    ):
        model = self._init_transformer_with_fp8(module_cls, checkpoint_activations)
        ref_model = copy.deepcopy(model).cuda()  # used for optimizer
        ref_model_bf16 = copy.deepcopy(ref_model).to(
            torch.bfloat16
        )  # used for forward/backward
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        fsdp_model = FSDP(
            model,
            auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
            use_orig_params=True,
            device_id=self.rank,
            mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
        )
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

        local_inp = torch.randint(0, 16, (1, 4), device="cuda")
        with self.enable_amax_init(False), self.enable_pre_and_post_forward(
            False
        ) if module_cls is Float8Linear else contextlib.nullcontext():
            for iter_idx in range(10):
                losses: List[torch.Tensor] = []
                for model, optim in (
                    (ref_model_bf16, ref_optim),
                    (fsdp_model, fsdp_optim),
                ):
                    optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                    losses.append(model(local_inp).sum())
                    losses[-1].backward()
                    if model is ref_model_bf16:
                        for param_bf16, param_fp32 in zip(
                            ref_model_bf16.parameters(), ref_model.parameters()
                        ):
                            dist.all_reduce(param_bf16.grad)
                            param_bf16.grad.div_(self.world_size)
                            param_fp32.grad = param_bf16.grad.float()
                            param_bf16.grad = None
                    if module_cls is Float8Linear:
                        sync_float8_amax_and_scale_history(model)
                    optim.step()
                    for param_fp32, param_bf16 in zip(
                        ref_model.parameters(), ref_model_bf16.parameters()
                    ):
                        param_bf16.detach().copy_(param_fp32)
                self.assertEqual(losses[0], losses[1])


if __name__ == "__main__":
    run_tests()
