This repository was archived by the owner on Aug 7, 2024. It is now read-only.
[2/x]: fix numerics integration test and test delayed vs dynamic #291
Closed
Changes from all commits
@@ -0,0 +1,191 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Tests LLaMa FeedForward numerics with float8

import copy
from typing import Optional

import pytest

import torch
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    LinearType,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error, IS_ROCM

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


torch.manual_seed(0)


# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
class FeedForward(nn.Module):
    """
    FeedForward module

    Args:
        dim (int): Input dimension.
        hidden_dim (int): Hidden dimension of the feedforward layer.
        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.

    Attributes:
        w1 (Linear): Linear transformation for the first layer.
        w2 (Linear): Linear transformation for the second layer.
        w3 (Linear): Linear transformation for the third layer.

    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        for linear in (self.w2, self.w3):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class TestFloat8NumericsIntegrationTest:
    @pytest.mark.parametrize(
        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize(
        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize(
        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize("linear_cls", [Float8Linear, Float8DynamicLinear])
    @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
    def test_encoder_fw_bw(
        self,
        linear_cls,
        scaling_type_x: TensorScalingType,
        scaling_type_w: TensorScalingType,
        scaling_type_dL_dY: TensorScalingType,
    ):
        linear_type = (
            LinearType.DELAYED if linear_cls == Float8Linear else LinearType.DYNAMIC
        )
        if linear_type is LinearType.DYNAMIC:
            # Only test one combination of scaling types, as they are a no-op
            # for Float8DynamicLinear. It would be cleaner to split into two
            # tests, but IMO not worth it since Float8DynamicLinear will be
            # deleted soon
            is_all_dynamic = (
                scaling_type_x is TensorScalingType.DYNAMIC
                and scaling_type_w is TensorScalingType.DYNAMIC
                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
            )
            if not is_all_dynamic:
                pytest.skip()

        # TODO(later): maybe add float16 back if it becomes important
        data_dtype = torch.bfloat16

        # LLaMa 3 70B shapes
        model_ref = (
            FeedForward(
                dim=4096,
                hidden_dim=16384,
                multiple_of=1024,
                ffn_dim_multiplier=1.3,
            )
            .cuda()
            .to(data_dtype)
        )

        # for now just test the encoder to simplify things
        model_fp8 = copy.deepcopy(model_ref)
        swap_linear_with_float8_linear(
            model_fp8,
            linear_cls,
            emulate=False,
            scaling_type_x=scaling_type_x,
            scaling_type_w=scaling_type_w,
            scaling_type_dL_dY=scaling_type_dL_dY,
        )

        lr = 0.01
        optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr)
        optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)

        # Note: you need two different inputs to properly test numerics
        # of delayed scaling, because the first time around the initialization
        # logic of delayed scaling behaves as dynamic scaling
        # TODO(future): also make unit tests do this properly
        shape = (1, 8192, 4096)
        data1 = torch.randn(*shape, device="cuda", dtype=data_dtype)
        data2 = torch.randn(*shape, device="cuda", dtype=data_dtype)

        model_ref(data1).sum().backward()
        # zero out grads without stepping, since we just want to compare grads
        # of the second datum
        optim_ref.zero_grad()
        model_ref_out = model_ref(data2)
        model_ref_out.sum().backward()

        if linear_requires_sync(
            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
        ):
            sync_float8_amax_and_scale_history(model_fp8)
        model_fp8(data1).sum().backward()
        # zero out grads without stepping, since we just want to compare grads
        # of the second datum
        optim_fp8.zero_grad()
        if linear_requires_sync(
            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
        ):
            sync_float8_amax_and_scale_history(model_fp8)
        model_fp8_out = model_fp8(data2)
        model_fp8_out.sum().backward()

        out_sqnr = compute_error(model_ref_out, model_fp8_out)
        assert out_sqnr > 20.0
Inline review comment on the assert above: nit: maybe add a message.
        ref_name_to_grad = {
            name: param.grad for name, param in model_ref.named_parameters()
        }

        grad_sqnr_threshold = 20.0

        for name, param in model_fp8.named_parameters():
            ref_grad = ref_name_to_grad[name]
            cur_grad = param.grad
            sqnr = compute_error(ref_grad, cur_grad)
            assert sqnr > grad_sqnr_threshold


if __name__ == "__main__":
    pytest.main([__file__])
A second file in this diff was deleted (its contents are not shown here).
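For context on the thresholds asserted in the new test: the test names the result of compute_error out_sqnr, i.e. a signal-to-quantization-noise ratio. Assuming it is reported in decibels, as SQNR usually is, a threshold of 20.0 roughly means the error norm must be at least 10x smaller than the reference norm. Below is a minimal sketch of such a metric under that assumption; the helper name sqnr_db is illustrative and the library's compute_error may differ in details.

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # ratio of signal power to error power, expressed in decibels;
    # 20 dB corresponds to an error norm ~10x smaller than the signal norm
    signal = torch.norm(ref)
    noise = torch.norm(ref - approx)
    return 20 * torch.log10(signal / noise)

# example: a small perturbation easily clears a 20 dB threshold
ref = torch.randn(1024)
approx = ref + 0.01 * torch.randn(1024)
assert sqnr_db(ref, approx) > 20.0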
Review comment: love it, this test has been annoying
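As background on the delayed vs dynamic distinction this PR parametrizes over: dynamic scaling derives each float8 scale from the amax of the current tensor, while delayed scaling derives it from an amax history collected on previous iterations, which is why the in-test comment notes that delayed scaling behaves like dynamic scaling on the first input. A rough illustrative sketch under those assumptions follows; DelayedScaler, dynamic_scale, and E4M3_MAX are hypothetical names, not the float8_experimental API.

import torch

E4M3_MAX = 448.0  # max magnitude representable in float8 e4m3

def dynamic_scale(x: torch.Tensor) -> float:
    # dynamic: the scale is recomputed from the amax of the current tensor
    return E4M3_MAX / max(x.abs().max().item(), 1e-12)

class DelayedScaler:
    # delayed: the scale comes from an amax history filled in on previous
    # iterations; with an empty history the current amax is used, so the
    # first call coincides with dynamic scaling
    def __init__(self, history_len: int = 16):
        self.amax_history = []
        self.history_len = history_len

    def scale(self, x: torch.Tensor) -> float:
        amax = max(x.abs().max().item(), 1e-12)
        if self.amax_history:
            s = E4M3_MAX / max(self.amax_history)
        else:
            s = E4M3_MAX / amax  # first iteration: same as dynamic
        self.amax_history = (self.amax_history + [amax])[-self.history_len:]
        return s

scaler = DelayedScaler()
data1, data2 = torch.randn(8), 3 * torch.randn(8)
assert scaler.scale(data1) == dynamic_scale(data1)  # identical on the first input
print(scaler.scale(data2), dynamic_scale(data2))    # can differ on the second input

This is also why the test feeds data1 first and only compares outputs and gradients on data2.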