This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit fb3d4ce

formatting
1 parent 4a27a27 commit fb3d4ce

File tree

5 files changed: +51 -27 lines changed


float8_experimental/float8_dynamic_linear.py

Lines changed: 15 additions & 11 deletions
@@ -6,11 +6,12 @@
 """
 A wrapper around a `torch.nn.Linear` module which does fp8 compute.
 """
-import torch
 from typing import Optional
 
+import torch
+
 from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
-from float8_experimental.float8_utils import IS_AMD, tensor_to_scale, FP8Dtypes
+from float8_experimental.float8_utils import FP8Dtypes, tensor_to_scale
 
 
 @torch._dynamo.allow_in_graph
@@ -21,20 +22,17 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(
-        ctx,
-        tensor,
-        emulate: bool,
-        fp8_dtype_bw: torch.dtype
-    ):
+    def forward(ctx, tensor, emulate: bool, fp8_dtype_bw: torch.dtype):
         ctx.emulate = emulate
         ctx.fp8_dtype_bw = fp8_dtype_bw
         return tensor
 
     @staticmethod
     def backward(ctx, gradY):
         gradY_scale = tensor_to_scale(gradY, ctx.fp8_dtype_bw)
-        fp8_tensor = to_fp8_no_autograd(gradY, gradY_scale, ctx.fp8_dtype_bw, ctx.emulate)
+        fp8_tensor = to_fp8_no_autograd(
+            gradY, gradY_scale, ctx.fp8_dtype_bw, ctx.emulate
+        )
         return fp8_tensor, None, None
 
 
@@ -63,7 +61,9 @@ class Float8DynamicLinear(torch.nn.Linear):
     conversion to fp8 of the input and weight tensors.
     """
 
-    def __init__(self, use_activation_hooks: bool, fp8_dtype: FP8Dtypes, **super_kwargs):
+    def __init__(
+        self, use_activation_hooks: bool, fp8_dtype: FP8Dtypes, **super_kwargs
+    ):
         """
         Args:
             use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
@@ -120,7 +120,11 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
 
     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, use_activation_hooks: bool = False, fp8_dtypes: Optional[FP8Dtypes] = None
+        cls,
+        mod,
+        emulate: bool = False,
+        use_activation_hooks: bool = False,
+        fp8_dtypes: Optional[FP8Dtypes] = None,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
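
For context, a minimal usage sketch of the reformatted from_float signature above; the layer shape and device are illustrative and not part of the diff:

# Minimal sketch, assuming only the from_float signature shown above;
# the layer shape/device are illustrative.
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_utils import FP8Dtypes

m_ref = nn.Linear(16, 32, bias=False, device="cuda")
m_fp8 = Float8DynamicLinear.from_float(
    m_ref,
    emulate=False,
    use_activation_hooks=False,
    fp8_dtypes=FP8Dtypes(),  # defaults: e4m3 forward, e5m2 backward (see float8_utils.py)
)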

float8_experimental/float8_linear.py

Lines changed: 9 additions & 3 deletions
@@ -14,7 +14,7 @@
 
 import dataclasses
 
-from typing import Optional, Literal
+from typing import Literal, Optional
 
 import float8_experimental.config as config
 
@@ -26,8 +26,8 @@
     amax_history_to_scale,
     E4M3_MAX_POS,
     E5M2_MAX_POS,
+    FP8Dtypes,
     tensor_to_amax,
-    FP8Dtypes
 )
 
 
@@ -316,7 +316,13 @@ def forward(self, input):
         return y
 
     @classmethod
-    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False, fp8_dtypes: Optional[FP8Dtypes] = None):
+    def from_float(
+        cls,
+        mod,
+        emulate: bool = False,
+        use_activation_hooks: bool = False,
+        fp8_dtypes: Optional[FP8Dtypes] = None,
+    ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
float8_experimental/float8_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -230,7 +230,7 @@ def to_float8(
         float8_dtype: torch.dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         emulate: bool = False,
-    )-> "Float8Tensor":
+    ) -> "Float8Tensor":
         """Converts a higher precision tensor to float8 in a differentiable way.
 
         Args:
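
For context, a hedged sketch of calling to_float8. Only the trailing parameters appear in this hunk; the classmethod form and the leading arguments (the high-precision tensor and its scale) are assumptions, with the scale computed via tensor_to_scale as elsewhere in this commit:

# Hedged sketch: the classmethod form and the leading arguments are assumptions;
# only float8_dtype / amax_buffer / emulate are visible in this hunk.
import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(16, 16)
x_scale = tensor_to_scale(x, torch.float8_e4m3fn)
x_fp8 = Float8Tensor.to_float8(
    x,                    # assumed leading argument: source tensor
    x_scale,              # assumed leading argument: scale
    torch.float8_e4m3fn,  # float8_dtype (shown in the hunk)
    amax_buffer=None,     # shown in the hunk
    emulate=True,         # shown in the hunk; emulation avoids needing fp8 matmul hardware
)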

float8_experimental/float8_utils.py

Lines changed: 4 additions & 2 deletions
@@ -4,8 +4,8 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Literal
 from dataclasses import dataclass
+from typing import Literal
 
 import torch
 import torch.distributed as dist
@@ -30,10 +30,12 @@
 
 @dataclass(frozen=True)
 class FP8Dtypes:
-    """ Defines the fp8 dtypes to be used in forward and backward computations"""
+    """Defines the fp8 dtypes to be used in forward and backward computations"""
+
     fp8_dtype_fw: torch.dtype = torch.float8_e4m3fn
     fp8_dtype_bw: torch.dtype = torch.float8_e5m2
 
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
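
For context, a small sketch of the frozen FP8Dtypes dataclass above; the fnuz override mirrors the AMD branch used in test/test_base.py below:

# Sketch of the FP8Dtypes dataclass defined above.
import torch

from float8_experimental.float8_utils import FP8Dtypes

nvidia_dtypes = FP8Dtypes()  # defaults: fp8_dtype_fw=torch.float8_e4m3fn, fp8_dtype_bw=torch.float8_e5m2
amd_dtypes = FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)  # fnuz variants, as in the AMD test branch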

test/test_base.py

Lines changed: 22 additions & 10 deletions
@@ -8,6 +8,7 @@
 import unittest
 import warnings
 from typing import Optional
+
 import pytest
 
 import torch
@@ -24,17 +25,16 @@
 from float8_experimental.float8_python_api import mm_float8
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import (
-    E5M2_FNUZ_MAX_POS,
     amax_to_scale,
     compute_error,
-    E4M3_MAX_POS,
     E4M3_FNUZ_MAX_POS,
-    E5M2_MAX_POS,
+    E4M3_MAX_POS,
     E5M2_FNUZ_MAX_POS,
+    E5M2_MAX_POS,
     FP16_MAX_POS,
-    tensor_to_scale,
-    IS_AMD,
     FP8Dtypes,
+    IS_AMD,
+    tensor_to_scale,
 )
 
 random.seed(0)
@@ -65,9 +65,10 @@ def _test_linear_impl(
         emulate: bool,
         use_activation_hooks: bool,
         fp8_dtypes: Optional[FP8Dtypes] = None,
-
     ):
-        m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks, fp8_dtypes)
+        m_fp8 = get_float8_linear(
+            linear_type, m_ref, emulate, use_activation_hooks, fp8_dtypes
+        )
         for _ in range(2):
             if linear_requires_sync(linear_type):
                 sync_float8_amax_and_scale_history(m_fp8)
@@ -95,7 +96,12 @@ def _test_linear_impl(
         ]
         for buffer_name in amax_buffer_names:
            buffer_value = getattr(m_fp8, buffer_name)
-            for init_val in (E4M3_MAX_POS, E5M2_MAX_POS, E4M3_FNUZ_MAX_POS, E5M2_FNUZ_MAX_POS):
+            for init_val in (
+                E4M3_MAX_POS,
+                E5M2_MAX_POS,
+                E4M3_FNUZ_MAX_POS,
+                E5M2_FNUZ_MAX_POS,
+            ):
                 assert torch.ne(
                     buffer_value, torch.tensor(init_val)
                 ), f"{buffer_name} not filled, current value {buffer_value}"
@@ -147,10 +153,16 @@ def test_linear_nobias(
                 f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
             )
             pytest.skip()
-        fp8_dtypes = FP8Dtypes() if not IS_AMD else FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
+        fp8_dtypes = (
+            FP8Dtypes()
+            if not IS_AMD
+            else FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
+        )
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks, fp8_dtypes)
+        self._test_linear_impl(
+            x, m_ref, linear_type, emulate, use_activation_hooks, fp8_dtypes
+        )
 
     @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
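
For context, a hedged sketch of the convert-then-sync loop that _test_linear_impl exercises above; these helpers are called but not imported in the hunks shown, so the import path and the LinearType value are assumptions:

# Hedged sketch; the module path and LinearType value are assumptions based on
# how the helpers are called in _test_linear_impl above.
import torch
import torch.nn as nn

from float8_experimental.float8_linear_utils import (  # assumed module path
    LinearType,
    get_float8_linear,
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import FP8Dtypes

m_ref = nn.Linear(16, 32, bias=False, device="cuda")
linear_type = LinearType.DELAYED  # assumed enum value
m_fp8 = get_float8_linear(linear_type, m_ref, True, False, FP8Dtypes())  # emulate, use_activation_hooks, fp8_dtypes
x = torch.randn(16, 16, device="cuda")
for _ in range(2):
    if linear_requires_sync(linear_type):
        sync_float8_amax_and_scale_history(m_fp8)  # refresh scales from the amax history
    y = m_fp8(x)
    y.sum().backward()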

0 commit comments
