This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit eb23050

[2/x]: fix numerics integration test and test delayed vs dynamic
Summary:

1. The SAM test wasn't easy to use because it had real weights and hence required real data for useful testing, which is not convenient for an integration test. Switched to a LLaMa FFN with random weights, and made all the thresholds tight to actually check that the numerics are close.
2. Extended the numerics test to check all combinations of delayed vs dynamic scaling.
3. To enable (2), extended the module swap utility to configure delayed vs dynamic scaling at the model level, for now without an option to customize further.

Test Plan:

```
pytest test/test_numerics_integration.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 954ce82
Pull Request resolved: #291
1 parent 1194661 commit eb23050
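For orientation, "all combinations of delayed vs dynamic" means the 2^3 = 8 ways to assign a scaling type to each of `x`, `w`, and `dL_dY`. A minimal illustration of that sweep (not part of this commit; it only assumes the `TensorScalingType` enum imported by the new test below):

```python
# Illustration only: enumerate the 8 delayed/dynamic combinations that the
# new numerics test parametrizes over for (x, w, dL_dY).
import itertools

from float8_experimental.float8_linear import TensorScalingType

scaling_types = [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
for scaling_type_x, scaling_type_w, scaling_type_dL_dY in itertools.product(
    scaling_types, repeat=3
):
    print(scaling_type_x, scaling_type_w, scaling_type_dL_dY)
```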

5 files changed: 210 additions, 81 deletions

float8_experimental/float8_linear_utils.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -114,13 +114,18 @@ def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear],
     )


+# TODO(future PR): probably create a per-linear config which contains
+# all of the options (emulate, scaling, etc)
 def swap_linear_with_float8_linear(
     module: nn.Module,
     module_cls: Type[nn.Module],
     *,
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    scaling_type_x: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_w: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED,
 ) -> nn.Module:
     """
     Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -134,6 +139,9 @@ def swap_linear_with_float8_linear(
         emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
         linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
             that pass the filter function will be swapped.
+        scaling_type_x (TensorScalingType): scaling type for `x`
+        scaling_type_w (TensorScalingType): scaling type for `w`
+        scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
     """
     module_names_to_skip = set(skip_fqn_list or [])
     if isinstance(module, nn.Linear) and (
@@ -167,7 +175,16 @@ def post_order_traversal(
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
-            float8linear_module = module_cls.from_float(module, emulate=emulate)
+            if module_cls is Float8DynamicLinear:
+                float8linear_module = module_cls.from_float(module, emulate=emulate)
+            else:
+                float8linear_module = module_cls.from_float(
+                    module,
+                    emulate=emulate,
+                    scaling_type_x=scaling_type_x,
+                    scaling_type_w=scaling_type_w,
+                    scaling_type_dL_dY=scaling_type_dL_dY,
+                )
             setattr(parent_module, module_name, float8linear_module)

     post_order_traversal(root_module, "", None)
```
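With these changes, the swap utility can configure delayed vs dynamic scaling for a whole model in one call. A rough usage sketch based on the signature above (not code from this commit; the toy model is a placeholder):

```python
# Sketch only: swap every nn.Linear in a toy model for Float8Linear, using
# dynamic scaling for the input and delayed scaling for the weight and
# gradient of the output.
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
swap_linear_with_float8_linear(
    model,
    Float8Linear,
    emulate=False,  # real fp8 matmuls; requires H100-class hardware
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)
```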

pyproject.toml

Lines changed: 0 additions & 1 deletion
```diff
@@ -19,7 +19,6 @@ dependencies = [

 [project.optional-dependencies]
 test = [
-    "transformers==4.38.2",
     "pandas >= 2.0",
     "tqdm==4.66.2",
    "fire==0.5.0",
```

test/test_everything.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,8 +5,8 @@ set -e
 IS_ROCM=$(rocm-smi --version || true)

 pytest test/test_base.py
-pytest test/test_sam.py
 pytest test/test_compile.py
+pytest test/test_numerics_integration.py

 # These tests do not work on ROCm yet
 if [ -z "$IS_ROCM" ]
```

test/test_numerics_integration.py

Lines changed: 191 additions & 0 deletions
New file contents:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Tests LLaMa FeedForward numerics with float8

import copy
from typing import Optional

import pytest

import torch
import torch.nn as nn
import torch.nn.functional as F

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    LinearType,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error, IS_ROCM

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


torch.manual_seed(0)


# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
class FeedForward(nn.Module):
    """
    FeedForward module

    Args:
        dim (int): Input dimension.
        hidden_dim (int): Hidden dimension of the feedforward layer.
        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.

    Attributes:
        w1 (Linear): Linear transformation for the first layer.
        w2 (Linear): Linear transformation for the second layer.
        w3 (Linear): Linear transformation for the third layer.

    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        for linear in (self.w2, self.w3):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class TestFloat8NumericsIntegrationTest:
    @pytest.mark.parametrize(
        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize(
        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize(
        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize("linear_cls", [Float8Linear, Float8DynamicLinear])
    @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
    def test_encoder_fw_bw(
        self,
        linear_cls,
        scaling_type_x: TensorScalingType,
        scaling_type_w: TensorScalingType,
        scaling_type_dL_dY: TensorScalingType,
    ):
        linear_type = (
            LinearType.DELAYED if linear_cls == Float8Linear else LinearType.DYNAMIC
        )
        if linear_type is LinearType.DYNAMIC:
            # Only test one combination of scaling types, as they are a no-op
            # for Float8DynamicLinear. It would be cleaner to split into two
            # tests, but IMO not worth it since Float8DynamicLinear will be
            # deleted soon
            is_all_dynamic = (
                scaling_type_x is TensorScalingType.DYNAMIC
                and scaling_type_w is TensorScalingType.DYNAMIC
                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
            )
            if not is_all_dynamic:
                pytest.skip()

        # TODO(later): maybe add float16 back if it becomes important
        data_dtype = torch.bfloat16

        # LLaMa 3 70B shapes
        model_ref = (
            FeedForward(
                dim=4096,
                hidden_dim=16384,
                multiple_of=1024,
                ffn_dim_multiplier=1.3,
            )
            .cuda()
            .to(data_dtype)
        )

        # for now just test the encoder to simplify things
        model_fp8 = copy.deepcopy(model_ref)
        swap_linear_with_float8_linear(
            model_fp8,
            linear_cls,
            emulate=False,
            scaling_type_x=scaling_type_x,
            scaling_type_w=scaling_type_w,
            scaling_type_dL_dY=scaling_type_dL_dY,
        )

        lr = 0.01
        optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr)
        optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)

        # Note: you need two different inputs to properly test numerics
        # of delayed scaling, because the first time around the initialization
        # logic of delayed scaling behaves as dynamic scaling
        # TODO(future): also make unit tests do this properly
        shape = (1, 8192, 4096)
        data1 = torch.randn(*shape, device="cuda", dtype=data_dtype)
        data2 = torch.randn(*shape, device="cuda", dtype=data_dtype)

        model_ref(data1).sum().backward()
        # zero out grads without stepping, since we just want to compare grads
        # of the second datum
        optim_ref.zero_grad()
        model_ref_out = model_ref(data2)
        model_ref_out.sum().backward()

        if linear_requires_sync(
            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
        ):
            sync_float8_amax_and_scale_history(model_fp8)
        model_fp8(data1).sum().backward()
        # zero out grads without stepping, since we just want to compare grads
        # of the second datum
        optim_fp8.zero_grad()
        if linear_requires_sync(
            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
        ):
            sync_float8_amax_and_scale_history(model_fp8)
        model_fp8_out = model_fp8(data2)
        model_fp8_out.sum().backward()

        out_sqnr = compute_error(model_ref_out, model_fp8_out)
        assert out_sqnr > 20.0

        ref_name_to_grad = {
            name: param.grad for name, param in model_ref.named_parameters()
        }

        grad_sqnr_threshold = 20.0

        for name, param in model_fp8.named_parameters():
            ref_grad = ref_name_to_grad[name]
            cur_grad = param.grad
            sqnr = compute_error(ref_grad, cur_grad)
            assert sqnr > grad_sqnr_threshold


if __name__ == "__main__":
    pytest.main([__file__])
```
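The assertions above gate on SQNR values returned by `compute_error`. As a point of reference, an SQNR helper in the same spirit (an assumed definition of the metric, not necessarily the exact implementation in `float8_experimental.float8_utils`) could look like:

```python
import torch


def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels; a > 20 dB threshold, as
    # used in the test above, roughly means the error norm is under 10% of
    # the reference norm.
    return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - approx))
```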

test/test_sam.py

Lines changed: 0 additions & 78 deletions
This file was deleted.
