|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# Tests LLaMa FeedForward numerics with float8 |
| 8 | + |
| 9 | +import copy |
| 10 | +from typing import Optional |
| 11 | + |
| 12 | +import pytest |
| 13 | + |
| 14 | +import torch |
| 15 | +import torch.nn as nn |
| 16 | +import torch.nn.functional as F |
| 17 | +from float8_experimental.float8_dynamic_linear import Float8DynamicLinear |
| 18 | +from float8_experimental.float8_linear import Float8Linear, TensorScalingType |
| 19 | +from float8_experimental.float8_linear_utils import ( |
| 20 | + linear_requires_sync, |
| 21 | + LinearType, |
| 22 | + swap_linear_with_float8_linear, |
| 23 | + sync_float8_amax_and_scale_history, |
| 24 | +) |
| 25 | +from float8_experimental.float8_utils import compute_error, IS_ROCM |
| 26 | + |
| 27 | +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) |
| 28 | + |
| 29 | + |
| 30 | +torch.manual_seed(0) |
| 31 | + |
| 32 | + |
| 33 | +# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py |
| 34 | +class FeedForward(nn.Module): |
| 35 | + """ |
| 36 | + FeedForward module |
| 37 | +
|
| 38 | + Args: |
| 39 | + dim (int): Input dimension. |
| 40 | + hidden_dim (int): Hidden dimension of the feedforward layer. |
| 41 | + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. |
| 42 | + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. |
| 43 | +
|
| 44 | + Attributes: |
| 45 | + w1 (Linear): Linear transformation for the first layer. |
| 46 | + w2 (Linear): Linear transformation for the second layer. |
| 47 | + w3 (Linear): Linear transformation for the third layer. |
| 48 | +
|
| 49 | + """ |
| 50 | + |
| 51 | + def __init__( |
| 52 | + self, |
| 53 | + dim: int, |
| 54 | + hidden_dim: int, |
| 55 | + multiple_of: int, |
| 56 | + ffn_dim_multiplier: Optional[float], |
| 57 | + ): |
| 58 | + super().__init__() |
| 59 | + hidden_dim = int(2 * hidden_dim / 3) |
| 60 | + # custom dim factor multiplier |
| 61 | + if ffn_dim_multiplier is not None: |
| 62 | + hidden_dim = int(ffn_dim_multiplier * hidden_dim) |
| 63 | + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) |
| 64 | + |
| 65 | + self.w1 = nn.Linear(dim, hidden_dim, bias=False) |
| 66 | + self.w2 = nn.Linear(hidden_dim, dim, bias=False) |
| 67 | + self.w3 = nn.Linear(dim, hidden_dim, bias=False) |
| 68 | + |
| 69 | + def forward(self, x): |
| 70 | + return self.w2(F.silu(self.w1(x)) * self.w3(x)) |
| 71 | + |
| 72 | + def init_weights(self, init_std: float): |
| 73 | + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) |
| 74 | + for linear in (self.w2, self.w3): |
| 75 | + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) |
| 76 | + |
| 77 | + |
| 78 | +class TestFloat8NumericsIntegrationTest: |
| 79 | + @pytest.mark.parametrize( |
| 80 | + "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] |
| 81 | + ) |
| 82 | + @pytest.mark.parametrize( |
| 83 | + "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] |
| 84 | + ) |
| 85 | + @pytest.mark.parametrize( |
| 86 | + "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] |
| 87 | + ) |
| 88 | + @pytest.mark.parametrize("linear_cls", [Float8Linear, Float8DynamicLinear]) |
| 89 | + @pytest.mark.skipif(not is_H100, reason="requires H100 GPU") |
| 90 | + @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") |
| 91 | + def test_encoder_fw_bw( |
| 92 | + self, |
| 93 | + linear_cls, |
| 94 | + scaling_type_x: TensorScalingType, |
| 95 | + scaling_type_w: TensorScalingType, |
| 96 | + scaling_type_dL_dY: TensorScalingType, |
| 97 | + ): |
| 98 | + linear_type = ( |
| 99 | + LinearType.DELAYED if linear_cls == Float8Linear else LinearType.DYNAMIC |
| 100 | + ) |
| 101 | + if linear_type is LinearType.DYNAMIC: |
| 102 | + # Only test one combination of scaling types, as they are a no-op |
| 103 | + # for Float8DynamicLinear. It would be cleaner to split into two |
| 104 | + # tests, but IMO not worth it since Float8DynamicLinear will be |
| 105 | + # deleted soon |
| 106 | + is_all_dynamic = ( |
| 107 | + scaling_type_x is TensorScalingType.DYNAMIC |
| 108 | + and scaling_type_w is TensorScalingType.DYNAMIC |
| 109 | + and scaling_type_dL_dY is TensorScalingType.DYNAMIC |
| 110 | + ) |
| 111 | + if not is_all_dynamic: |
| 112 | + pytest.skip() |
| 113 | + |
| 114 | + # TODO(later): maybe add float16 back if it becomes important |
| 115 | + data_dtype = torch.bfloat16 |
| 116 | + |
| 117 | + # LLaMa 3 70B shapes |
| 118 | + model_ref = ( |
| 119 | + FeedForward( |
| 120 | + dim=4096, |
| 121 | + hidden_dim=16384, |
| 122 | + multiple_of=1024, |
| 123 | + ffn_dim_multiplier=1.3, |
| 124 | + ) |
| 125 | + .cuda() |
| 126 | + .to(data_dtype) |
| 127 | + ) |
| 128 | + |
| 129 | + # for now just test the encoder to simplify things |
| 130 | + model_fp8 = copy.deepcopy(model_ref) |
| 131 | + swap_linear_with_float8_linear( |
| 132 | + model_fp8, |
| 133 | + linear_cls, |
| 134 | + emulate=False, |
| 135 | + scaling_type_x=scaling_type_x, |
| 136 | + scaling_type_w=scaling_type_w, |
| 137 | + scaling_type_dL_dY=scaling_type_dL_dY, |
| 138 | + ) |
| 139 | + |
| 140 | + lr = 0.01 |
| 141 | + optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr) |
| 142 | + optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr) |
| 143 | + |
| 144 | + # Note: you need two different inputs to properly test numerics |
| 145 | + # of delayed scaling, because the first time around the initialization |
| 146 | + # logic of delayed scaling behaves as dynamic scaling |
| 147 | + # TODO(future): also make unit tests do this properly |
| 148 | + shape = (1, 8192, 4096) |
| 149 | + data1 = torch.randn(*shape, device="cuda", dtype=data_dtype) |
| 150 | + data2 = torch.randn(*shape, device="cuda", dtype=data_dtype) |
| 151 | + |
| 152 | + model_ref(data1).sum().backward() |
| 153 | + # zero out grads without stepping, since we just want to compare grads |
| 154 | + # of the second datum |
| 155 | + optim_ref.zero_grad() |
| 156 | + model_ref_out = model_ref(data2) |
| 157 | + model_ref_out.sum().backward() |
| 158 | + |
| 159 | + if linear_requires_sync( |
| 160 | + linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY |
| 161 | + ): |
| 162 | + sync_float8_amax_and_scale_history(model_fp8) |
| 163 | + model_fp8(data1).sum().backward() |
| 164 | + # zero out grads without stepping, since we just want to compare grads |
| 165 | + # of the second datum |
| 166 | + optim_fp8.zero_grad() |
| 167 | + if linear_requires_sync( |
| 168 | + linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY |
| 169 | + ): |
| 170 | + sync_float8_amax_and_scale_history(model_fp8) |
| 171 | + model_fp8_out = model_fp8(data2) |
| 172 | + model_fp8_out.sum().backward() |
| 173 | + |
| 174 | + out_sqnr = compute_error(model_ref_out, model_fp8_out) |
| 175 | + assert out_sqnr > 20.0 |
| 176 | + |
| 177 | + ref_name_to_grad = { |
| 178 | + name: param.grad for name, param in model_ref.named_parameters() |
| 179 | + } |
| 180 | + |
| 181 | + grad_sqnr_threshold = 20.0 |
| 182 | + |
| 183 | + for name, param in model_fp8.named_parameters(): |
| 184 | + ref_grad = ref_name_to_grad[name] |
| 185 | + cur_grad = param.grad |
| 186 | + sqnr = compute_error(ref_grad, cur_grad) |
| 187 | + assert sqnr > grad_sqnr_threshold |
| 188 | + |
| 189 | + |
| 190 | +if __name__ == "__main__": |
| 191 | + pytest.main([__file__]) |
0 commit comments