This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 006f89a

vkuzo authored and facebook-github-bot committed
simplify FSDP1 test and add coverage for dynamic scaling (#293)
Summary:
Pull Request resolved: #293

1. Simplify the FSDP test: instead of comparing 1 GPU vs N GPUs, hold the number of GPUs constant and compare bf16 vs float8. Remove various technical debt that had accumulated in this test.
2. Add testing for dynamic scaling of weights.

Reviewed By: drisspg

Differential Revision: D59305791

fbshipit-source-id: 03c29364a3d8cbfc2e514694a26ff03ae7c4795c
1 parent 1e71def commit 006f89a
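At a high level, the new test builds a bf16 reference model, deep-copies it, swaps the copy's nn.Linear modules for Float8Linear (with either dynamic or delayed weight scaling), runs both under FSDP, and asserts the results stay close. A minimal single-GPU sketch of that comparison pattern, assuming the float8_experimental APIs shown in the diff below (illustrative only, not part of the commit):

import copy

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error

# bf16 reference model, and a float8 copy of it
model = nn.Sequential(
    nn.Linear(32, 32, dtype=torch.bfloat16),
    nn.ReLU(),
    nn.Linear(32, 32, dtype=torch.bfloat16),
).cuda()
model_fp8 = copy.deepcopy(model)
swap_linear_with_float8_linear(
    model_fp8,
    Float8Linear,
    emulate=False,  # set True on pre-H100 hardware
    scaling_type_w=TensorScalingType.DYNAMIC,  # or DELAYED
)

x = torch.randn(8, 8, 32, device="cuda", dtype=torch.bfloat16)
y_ref = model(x)
y_fp8 = model_fp8(x)
sync_float8_amax_and_scale_history(model_fp8)  # updates amax/scale for delayed-scaled tensors
sqnr = compute_error(y_ref, y_fp8)  # SQNR in dB; higher means closer
assert sqnr > 15.0, f"SQNR of {sqnr} is too low"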

File tree

2 files changed: +101, -177 lines changed


test/test_fsdp.py

Lines changed: 95 additions & 156 deletions
@@ -4,15 +4,14 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 """
-Test numerics of single GPU vs FSDP of toy model. At a high level:
-1. start with reference input and state dict for a single GPU model
-2. run fw+bw+optim on single GPU, save the results
-3. run fw+bw+optim with FSDP, save the results
-4. verify that the outputs and state dicts after optim update match
-
-later 1-4 can be repeated for fp16, various combinations of fp8, etc.
+Test numerics of bf16 versus float8 with FSDP on. At a high level:
+1. start with a reference model, with FSDP on
+2. run forward + backward + optim for 2 iterations
+3. repeat 2 with float8 enabled (2 iterations needed for delayed scaling)
+4. compare outputs and state dict between (2) and (3), should be close
 """

+import copy
 import os
 import warnings

@@ -22,11 +21,12 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
+from float8_experimental.float8_utils import compute_error
 from torch.distributed.fsdp import (
     FullStateDictConfig,
     FullyShardedDataParallel as FSDP,
@@ -35,18 +35,9 @@

 torch.manual_seed(0)

-# assumes user is running the script from /data/users/{user}/float8_experimental
-data_dir = os.path.join(os.path.dirname(__file__), "tmp")
-input_fname = os.path.join(data_dir, "input.pt")
-sd_in_fname = os.path.join(data_dir, "sd_in.pt")
-sd_out_single_gpu_fname = os.path.join(data_dir, "sd_out_single_gpu.pt")
-sd_out_fsdp_fname = os.path.join(data_dir, "sd_out_fsdp.pt")
-output_single_gpu_fname = os.path.join(data_dir, "output_single_gpu.pt")
-output_fsdp_fname = os.path.join(data_dir, "output_fsdp.pt")
-
 B, M, K, N = 8, 8, 32, 32
 lr = 0.01
-N_ITER = 5
+N_ITER = 2


 def setup(rank, world_size):
@@ -61,15 +52,13 @@ def cleanup():
     dist.destroy_process_group()


-def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
+def get_model(K, N, base_dtype=torch.float32):
     m = nn.Sequential(
         nn.Linear(K, N, dtype=base_dtype),
         nn.ReLU(),
         nn.Linear(N, N, dtype=base_dtype),
         nn.ReLU(),
     )
-    if is_fp8:
-        swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
     return m


@@ -79,52 +68,84 @@ def fsdp_main(rank, world_size, args):
     setup(rank, world_size)
     torch.cuda.set_device(rank)

-    # TODO: We set fullgraph as an option. However, it currently doesn't work for fullgraph compile.
-    # We can investigate and fix it later.
-    is_fp8, emulate, base_dtype, compile, fullgraph = args
-    model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to(
-        rank
+    emulate, base_dtype, compile, use_weight_dynamic_scaling = args
+    model = get_model(K, N, base_dtype=base_dtype).to(rank)
+    model_fp8 = copy.deepcopy(model)
+
+    scaling_type_w = (
+        TensorScalingType.DYNAMIC
+        if use_weight_dynamic_scaling
+        else TensorScalingType.DELAYED
+    )
+
+    # Note: we only iterate over `scaling_type_w` because FSDP only interacts
+    # with weights.
+    swap_linear_with_float8_linear(
+        model_fp8,
+        Float8Linear,
+        emulate=False,
+        scaling_type_w=scaling_type_w,
     )
-    model.load_state_dict(torch.load(sd_in_fname, weights_only=True))
+
     # To compile FSDP, we need use_orig_params to True
     model = FSDP(model, use_orig_params=True)
+    model_fp8 = FSDP(model_fp8, use_orig_params=True)
     # TODO: The following line doesn't work. We should fix it.
     # model = FSDP(torch.compile(model), use_orig_params=True)

-    # Note: we need to multiply by world_size here to match single GPU
-    # optimizer update
-    optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    optimizer_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)

-    ref_input_global = torch.load(input_fname, weights_only=True).to(base_dtype)
+    # Note: we need two different inputs to properly measure the impact of
+    # delayed scaling, before the first input uses dynamic scaling to
+    # populate the buffers
+    ref_input_global = [
+        torch.randn(B, M, K).cuda().to(base_dtype),
+        torch.randn(B, M, K).cuda().to(base_dtype),
+    ]
+    ref_grad_global = [
+        torch.randn(B, M, N).cuda().to(base_dtype),
+        torch.randn(B, M, N).cuda().to(base_dtype),
+    ]
+    ref_input_local = []
+    ref_grad_local = []

     # basic distributed data sampling
     assert B % world_size == 0
     bsz_local_start = int(rank / world_size * B)
     bsz_local_end = int((rank + 1) / world_size * B)
-    ref_input_local = ref_input_global[bsz_local_start:bsz_local_end].to(rank)
+    for idx in range(N_ITER):
+        ref_input_local.append(
+            ref_input_global[idx][bsz_local_start:bsz_local_end].to(rank)
+        )
+        ref_grad_local.append(
+            ref_grad_global[idx][bsz_local_start:bsz_local_end].to(rank)
+        )

     sync_float8_func = sync_float8_amax_and_scale_history
     if compile:
-        sync_float8_func = torch.compile(
-            sync_float8_amax_and_scale_history, fullgraph=fullgraph
-        )
-
-    def forward_backward(model):
-        optimizer.zero_grad()
-        y_local = model(ref_input_local)
-        y_local.sum().backward()
-        sync_float8_func(model)
-        optimizer.step()
+        sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)
+
+    def forward_backward(model, optim, is_fp8, i):
+        optim.zero_grad()
+        y_local = model(ref_input_local[i])
+        y_local.backward(ref_grad_local[i])
+        if is_fp8:
+            sync_float8_func(model)
+        optim.step()
         return y_local

-    for iter in range(N_ITER):
+    for i in range(N_ITER):
         # We first run one iteration without compile, as a workaround to compile float8 layer.
         # In the first iter, float8 layers go to the branches of "self.is_amax_initialized == False"
         # After that, float8 layers go the the branches of "self.is_amax_initialized == True"
         # TODO: Need to fix compile to run wihtout this workaround.
-        if iter == 1 and compile:
-            model = torch.compile(model, fullgraph=fullgraph)
-        y_local = forward_backward(model)
+        if i == 1 and compile:
+            model = torch.compile(model)
+            model_fp8 = torch.compile(model_fp8)
+        y_local = forward_backward(model, optimizer, is_fp8=False, i=i)
+        y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i)
+        local_sqnr = compute_error(y_local, y_local_fp8)

     # get global y
     y_global = [
@@ -133,132 +154,50 @@ def forward_backward(model):
     ]
     dist.all_gather(y_global, y_local)
     y_global = torch.cat(y_global, dim=0)
+    y_global_fp8 = [
+        torch.zeros(*y_local_fp8.shape, dtype=base_dtype).to(rank)
+        for r in range(world_size)
+    ]
+    dist.all_gather(y_global_fp8, y_local_fp8)
+    y_global_fp8 = torch.cat(y_global_fp8, dim=0)
     if rank == 0:
-        torch.save(y_global, output_fsdp_fname)
+        sqnr = compute_error(y_global, y_global_fp8)
+        assert sqnr > 15.0, f"SQNR of {sqnr} is too low"

     # get global state dict
     # https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html
     dist.barrier()
     save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
     with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
         cpu_state = model.state_dict()
+    with FSDP.state_dict_type(model_fp8, StateDictType.FULL_STATE_DICT, save_policy):
+        cpu_state_fp8 = model_fp8.state_dict()
     if rank == 0:
-        torch.save(cpu_state, sd_out_fsdp_fname)
+        for k, v1 in cpu_state.items():
+            v2 = cpu_state_fp8[k]
+            v1, v2 = v1.cpu(), v2.cpu()
+            sqnr = compute_error(v1, v2)
+            assert sqnr > 15.0, f"SQNR of {sqnr} is too low, k: {k}, v1: {v1}, v2: {v2}"

     cleanup()


-def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False):
-    print(f"Mode: {mode}".center(100, "-"))
+def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False):
     base_dtype = torch.bfloat16
-    if not os.path.exists(data_dir):
-        os.makedirs(data_dir)

     emulate = False
-    if is_fp8:
-        if not torch.cuda.is_available():
-            warnings.warn("CUDA not available, running in emulation_mode")
-            emulate = True
-        elif torch.cuda.get_device_capability() < (9, 0):
-            warnings.warn(
-                f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
-            )
-            emulate = True
-
-    if mode == "generate":
-        # generate reference input
-        ref_input = torch.randn(B, M, K).cuda().to(base_dtype)
-        model = get_model(
-            K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
-        ).cuda()
-        torch.save(ref_input, input_fname)
-        torch.save(model.state_dict(), sd_in_fname)
-
-    elif mode == "single_gpu":
-        ref_input = torch.load(input_fname, weights_only=True).to(base_dtype)
-        model = get_model(
-            K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
-        ).cuda()
-        model.load_state_dict(torch.load(sd_in_fname, weights_only=True))
-        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
-
-        def forward_backward():
-            optimizer.zero_grad()
-            y = model(ref_input)
-            y.sum().backward()
-            sync_float8_amax_and_scale_history(model)
-            optimizer.step()
-            return y
-
-        for _ in range(N_ITER):
-            y = forward_backward()
-
-        torch.save(y, output_single_gpu_fname)
-        torch.save(model.state_dict(), sd_out_single_gpu_fname)
-
-    elif mode == "fsdp":
-        WORLD_SIZE = torch.cuda.device_count()
-        # We only compile for fsdp, and compare the numerics with signle-gpu no-compile
-        args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph)
-        mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
-
-    elif mode == "analyze":
-        y_single_gpu = torch.load(output_single_gpu_fname, weights_only=True).cpu()
-        y_fsdp = torch.load(output_fsdp_fname, weights_only=True).cpu()
-        if is_fp8 and not emulate:
-            atol, rtol = 2e-2, 2e-2
-        else:
-            atol, rtol = None, None
-        torch.testing.assert_close(y_single_gpu, y_fsdp, atol=atol, rtol=rtol)
-        print("output testing single_gpu vs FSDP success")
-
-        sd_out_single_gpu = torch.load(sd_out_single_gpu_fname, weights_only=True)
-        sd_out_fsdp = torch.load(sd_out_fsdp_fname, weights_only=True)
-        for k, v1 in sd_out_single_gpu.items():
-            if compile_fsdp:
-                # The state-dict for compiled fsdp has a `_orig_mod` prefix
-                k = f"_orig_mod.{k}"
-            v2 = sd_out_fsdp[k]
-            v1, v2 = v1.cpu(), v2.cpu()
-            if is_fp8 and "noop" in k:
-                # Note: for fp8 single-node vs FSDP, we are not expected
-                # to match the scale of the gradients which follow the following
-                # pattern:
-                #
-                # `op(g_prev, out_scale) -> g_fp8 -> cast -> g_fp16 -> reduce`.
-                #
-                # Reasoning is the order of operations of calculating the above:
-                # a. single node:
-                #    1. calculate dL_dValue and s_dL_dValue
-                #    2. you're done
-                # b. FSDP:
-                #    1. calculate dL_dValue and s_dL_dValue of each slice
-                #    2. reduce using summation
-                #
-                # a and b cannot always match because calculating the scale
-                # involves taking max(dL_dW), FSDP reduces the gradients, and
-                # max(abs(a), abs(b)) != max(abs(a + b))
-                #
-                # In today's codebase, we do not hit this yet. We expect to hit
-                # this if we implement TP with activation gradients that both need
-                # reductions and need fp8 distributed comms. Solution - TBD.
-
-                # noop buffers are unused, so ok for them to not match
-                pass
-            else:
-                try:
-                    if v1.dtype == torch.bfloat16 and not emulate:
-                        atol, rtol = 2e-2, 2e-2
-                    else:
-                        if k == "1.fp8_amax_history_x" and not emulate:
-                            atol, rtol = 2e-2, 6e-3
-                        else:
-                            atol, rtol = None, None
-                    torch.testing.assert_close(v1, v2, atol=atol, rtol=rtol)
-                except Exception as e:
-                    print("debug:", k, v1, v2)
-                    raise e
-        print("state dict testing single_gpu vs FSDP success")
+    if not torch.cuda.is_available():
+        warnings.warn("CUDA not available, running in emulation_mode")
+        emulate = True
+    elif torch.cuda.get_device_capability() < (9, 0):
+        warnings.warn(
+            f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
+        )
+        emulate = True
+
+    WORLD_SIZE = torch.cuda.device_count()
+    args = (emulate, base_dtype, compile_fsdp, use_weight_dynamic_scaling)
+    mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)


 if __name__ == "__main__":
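The rewritten test replaces the old atol/rtol comparisons with an SQNR (signal-to-quantization-noise ratio) gate: outputs and state dict entries must score above 15 dB against the bf16 reference. compute_error is imported from float8_experimental.float8_utils; if it follows the standard SQNR definition, it behaves like this sketch (a hypothetical stand-in, for illustration only):

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # SQNR in decibels: 20 * log10(||ref|| / ||ref - approx||)
    return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - approx))

ref = torch.randn(8, 8, 32)
approx = ref + 0.01 * torch.randn_like(ref)  # small perturbation -> high SQNR
print(sqnr_db(ref, approx))  # roughly 40 dB, well above the test's 15 dB threshold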

test/test_fsdp.sh

Lines changed: 6 additions & 21 deletions
@@ -4,27 +4,12 @@
 set -e

 launch() {
-  echo "launching IS_FP8 $IS_FP8, compile_fsdp $COMPILE, fullgraph $FULLGRAPH"
+  echo "launching compile_fsdp $COMPILE, use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING"

-  # generate the test data
-  python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
-  echo "Success: ✅"
-
-  # generate single GPU model output and updated state dict
-  python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
-  echo "Success: ✅"
-
-  # generate FSDP model output and updated state dict
   # the NCCL_DEBUG setting is to avoid log spew
   # the CUDA_VISIBLE_DEVICES setting is for easy debugging
-  # the NCCL_NET setting is to work around transient issues on a
-  # specific host (`devgpu001.nha2`)
-  NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 NCCL_NET=SOCKET python test/test_fsdp.py \
-    --mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
-
-  # compare the outputs and state dicts and verify equivalence
-  python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
-  echo "Success: ✅"
+  NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/test_fsdp.py \
+    --compile_fsdp $COMPILE --use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING

   echo "✅ All Tests Passed ✅"
 }
@@ -34,10 +19,10 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False";
   exit
 fi

-# IS_FP8, COMPILE, FULLGRAPH
-for i in False,False,False True,False,False True,True,False
+# COMPILE, USE_WEIGHT_DYNAMIC_SCALING
+for i in False,False False,True True,False True,True
 do
   IFS=","; set -- $i;
-  IS_FP8=$1; COMPILE=$2; FULLGRAPH=$3
+  COMPILE=$1; USE_WEIGHT_DYNAMIC_SCALING=$2
   launch
 done
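The updated loop sweeps all four combinations of COMPILE and USE_WEIGHT_DYNAMIC_SCALING. The same sweep could also be driven from Python through the run() entry point defined in test/test_fsdp.py; a sketch, assuming a machine with at least 2 CUDA devices and test/ on sys.path:

from itertools import product

from test_fsdp import run  # hypothetical import path; assumes test/ is on sys.path

for compile_fsdp, use_weight_dynamic_scaling in product([False, True], repeat=2):
    run(compile_fsdp=compile_fsdp, use_weight_dynamic_scaling=use_weight_dynamic_scaling)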
