
Commit 1e71def

vkuzo authored and facebook-github-bot committed
fix numerics integration test and test delayed vs dynamic (#291)

Summary:
Pull Request resolved: #291

1. The SAM test was hard to use because it had real weights and hence required real data for useful testing, which is not convenient for an integration test. Switched to a LLaMa FFN with random weights, and tightened all thresholds so the test actually checks that numerics are close.
2. Extended the numerics test to check all combinations of delayed vs dynamic scaling.
3. To enable (2), extended the module swap utility to configure delayed vs dynamic scaling at the model level, for now without an option to customize further.

Reviewed By: drisspg

Differential Revision: D59305796

fbshipit-source-id: 4b1cd097ff82ce81a774cab535b0c890d47a2ae8
1 parent 3cb42e1 commit 1e71def

File tree

5 files changed: +224 −81 lines changed


float8_experimental/float8_linear_utils.py

Lines changed: 32 additions & 1 deletion
@@ -191,10 +191,41 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    scaling_type_x: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_w: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED,
 ) -> Optional[nn.Module]:
+    """
+    Swaps `torch.nn.Linear` in `module` with `Float8Linear` or `Float8DynamicLinear`.
+
+    Args:
+        module: Module to modify.
+        module_cls: `Float8Linear` or `Float8DynamicLinear`.
+        from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
+        skip_fqn_list: If specified, a list of module FQNs to skip.
+        emulate: If True, emulation is used instead of hardware accelerated gemm
+        linear_layer_filter: If specified, only the linear layers
+            that pass the filter function will be swapped.
+        scaling_type_x (TensorScalingType): scaling type for `x`
+        scaling_type_w (TensorScalingType): scaling type for `w`
+        scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
+
+    Returns:
+        nn.Module: The modified module with swapped linear layers.
+    """
+    if module_cls is Float8DynamicLinear:
+        from_float = lambda m: module_cls.from_float(m, emulate=emulate)
+    else:
+        from_float = lambda m: module_cls.from_float(
+            m,
+            emulate=emulate,
+            scaling_type_x=scaling_type_x,
+            scaling_type_w=scaling_type_w,
+            scaling_type_dL_dY=scaling_type_dL_dY,
+        )
     return swap_linear_layers(
         module,
-        lambda m: module_cls.from_float(m, emulate=emulate),
+        from_float,
         skip_fqn_list=skip_fqn_list,
         linear_layer_filter=linear_layer_filter,
     )
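
Note (not part of the diff): a minimal usage sketch of the extended swap utility, assuming the signature shown above; the toy model and the particular mix of scaling types are illustrative only.

import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# hypothetical toy model; any nn.Module containing nn.Linear layers should work
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# swap each nn.Linear for Float8Linear, choosing delayed vs dynamic scaling
# per tensor at the model level (illustrative mix); emulate=True is used here
# so the sketch does not require hardware with native float8 support
swap_linear_with_float8_linear(
    model,
    Float8Linear,
    emulate=True,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)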

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@ dependencies = [
 
 [project.optional-dependencies]
 test = [
-    "transformers==4.38.2",
     "pandas >= 2.0",
     "tqdm==4.66.2",
     "fire==0.5.0",

test/test_everything.sh

Lines changed: 1 addition & 1 deletion
@@ -5,9 +5,9 @@ set -e
 IS_ROCM=$(rocm-smi --version || true)
 
 pytest test/test_base.py
-pytest test/test_sam.py
 pytest test/test_compile.py
 pytest test/test_inference_flows.py
+pytest test/test_numerics_integration.py
 
 # These tests do not work on ROCm yet
 if [ -z "$IS_ROCM" ]

test/test_numerics_integration.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Tests LLaMa FeedForward numerics with float8
+
+import copy
+from typing import Optional
+
+import pytest
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    LinearType,
+    swap_linear_with_float8_linear,
+    sync_float8_amax_and_scale_history,
+)
+from float8_experimental.float8_utils import compute_error, IS_ROCM
+
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+torch.manual_seed(0)
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class TestFloat8NumericsIntegrationTest:
+    @pytest.mark.parametrize(
+        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    )
+    @pytest.mark.parametrize(
+        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    )
+    @pytest.mark.parametrize(
+        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    )
+    @pytest.mark.parametrize("linear_cls", [Float8Linear, Float8DynamicLinear])
+    @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
+    def test_encoder_fw_bw(
+        self,
+        linear_cls,
+        scaling_type_x: TensorScalingType,
+        scaling_type_w: TensorScalingType,
+        scaling_type_dL_dY: TensorScalingType,
+    ):
+        linear_type = (
+            LinearType.DELAYED if linear_cls == Float8Linear else LinearType.DYNAMIC
+        )
+        if linear_type is LinearType.DYNAMIC:
+            # Only test one combination of scaling types, as they are a no-op
+            # for Float8DynamicLinear. It would be cleaner to split into two
+            # tests, but IMO not worth it since Float8DynamicLinear will be
+            # deleted soon
+            is_all_dynamic = (
+                scaling_type_x is TensorScalingType.DYNAMIC
+                and scaling_type_w is TensorScalingType.DYNAMIC
+                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
+            )
+            if not is_all_dynamic:
+                pytest.skip()
+
+        # TODO(later): maybe add float16 back if it becomes important
+        data_dtype = torch.bfloat16
+
+        # LLaMa 3 70B shapes
+        model_ref = (
+            FeedForward(
+                dim=4096,
+                hidden_dim=16384,
+                multiple_of=1024,
+                ffn_dim_multiplier=1.3,
+            )
+            .cuda()
+            .to(data_dtype)
+        )
+
+        # for now just test the encoder to simplify things
+        model_fp8 = copy.deepcopy(model_ref)
+        swap_linear_with_float8_linear(
+            model_fp8,
+            linear_cls,
+            emulate=False,
+            scaling_type_x=scaling_type_x,
+            scaling_type_w=scaling_type_w,
+            scaling_type_dL_dY=scaling_type_dL_dY,
+        )
+
+        lr = 0.01
+        optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr)
+        optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)
+
+        # Note: you need two different inputs to properly test numerics
+        # of delayed scaling, because the first time around the initialization
+        # logic of delayed scaling behaves as dynamic scaling
+        # TODO(future): also make unit tests do this properly
+        shape = (1, 8192, 4096)
+        data1 = torch.randn(*shape, device="cuda", dtype=data_dtype)
+        data2 = torch.randn(*shape, device="cuda", dtype=data_dtype)
+
+        model_ref(data1).sum().backward()
+        # zero out grads without stepping, since we just want to compare grads
+        # of the second datum
+        optim_ref.zero_grad()
+        model_ref_out = model_ref(data2)
+        model_ref_out.sum().backward()
+
+        if linear_requires_sync(
+            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+        ):
+            sync_float8_amax_and_scale_history(model_fp8)
+        model_fp8(data1).sum().backward()
+        # zero out grads without stepping, since we just want to compare grads
+        # of the second datum
+        optim_fp8.zero_grad()
+        if linear_requires_sync(
+            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+        ):
+            sync_float8_amax_and_scale_history(model_fp8)
+        model_fp8_out = model_fp8(data2)
+        model_fp8_out.sum().backward()
+
+        out_sqnr = compute_error(model_ref_out, model_fp8_out)
+        assert out_sqnr > 20.0
+
+        ref_name_to_grad = {
+            name: param.grad for name, param in model_ref.named_parameters()
+        }
+
+        grad_sqnr_threshold = 20.0
+
+        for name, param in model_fp8.named_parameters():
+            ref_grad = ref_name_to_grad[name]
+            cur_grad = param.grad
+            sqnr = compute_error(ref_grad, cur_grad)
+            assert sqnr > grad_sqnr_threshold
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
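
Aside (not part of the diff): the 20 dB thresholds above are SQNR checks via compute_error. A rough sketch of that check, under the assumption that compute_error returns roughly 20 * log10(||ref|| / ||ref - candidate||) in dB (see float8_experimental.float8_utils for the actual implementation):

import torch

def sqnr_db(ref: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for float8_experimental.float8_utils.compute_error
    return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - candidate))

ref = torch.randn(1024)
noisy = ref + 1e-2 * torch.randn(1024)  # ~1% relative error is roughly 40 dB
assert sqnr_db(ref, noisy) > 20.0  # same tolerance the test applies to outputs and grads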

test/test_sam.py

Lines changed: 0 additions & 78 deletions
This file was deleted.
