This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit f32a4a4

drisspg, alugorey (authored and committed)
format
1 parent d5dc16e commit f32a4a4

File tree

5 files changed: +29, -13 lines

float8_experimental/float8_dynamic_linear.py

Lines changed: 2 additions & 4 deletions

@@ -22,7 +22,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
 from torch._prims_common import suggest_memory_format


@@ -106,9 +106,7 @@ def cast_to_float8_e4m3fn(
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(
-        inpt_tensor, scale, e4m3_dtype, mm_config=mm_config
-    )
+    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


 def cast_to_float8_e5m2_bw(
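
For readers unfamiliar with the call that was collapsed onto one line above, the rough idea of a dynamic e4m3 cast is: derive a per-tensor scale from the tensor's absolute maximum, then saturate-cast the scaled values into float8. The sketch below is a hypothetical, self-contained illustration, not the library's tensor_to_scale / Float8Tensor.to_float8 implementation; the helper name and the 1e-12 clamp are invented for the example.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def dynamic_cast_to_e4m3_sketch(x: torch.Tensor):
    """Return (fp8_tensor, scale); a conceptual stand-in, not the library code."""
    # The per-tensor absolute maximum drives the dynamic scale.
    amax = x.abs().max().to(torch.float32)
    # The scale maps the observed range onto the e4m3 representable range;
    # clamp amax away from zero so the scale stays finite.
    scale = E4M3_MAX / torch.clamp(amax, min=1e-12)
    # Saturate before converting so out-of-range values clip to +/-448
    # instead of overflowing in float8.
    x_scaled = (x.to(torch.float32) * scale).clamp(-E4M3_MAX, E4M3_MAX)
    return x_scaled.to(torch.float8_e4m3fn), scale

x = torch.randn(16, 32)
x_fp8, scale = dynamic_cast_to_e4m3_sketch(x)
x_approx = x_fp8.to(torch.float32) / scale  # dequantized approximation of x
```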

float8_experimental/float8_linear.py

Lines changed: 6 additions & 1 deletion

@@ -21,7 +21,12 @@
     to_fp8_no_autograd,
 )

-from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import (
+    amax_history_to_scale,
+    e4m3_dtype,
+    e5m2_dtype,
+    tensor_to_amax,
+)


 def _maybe_initialize_amaxes_scales_for_float8_cast(
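
The reordered import brings in amax_history_to_scale, which belongs to the delayed-scaling path: instead of deriving the scale from the current tensor, a rolling history of observed amaxes is reduced and converted into a scale. The snippet below is only a conceptual sketch under that assumption; it does not reproduce the signature or behavior of the real amax_history_to_scale.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def amax_history_to_scale_sketch(amax_history: torch.Tensor) -> torch.Tensor:
    # Reduce the rolling history of observed amaxes (max here) and turn
    # the result into a scale; clamp so an all-zero history stays finite.
    amax = torch.clamp(amax_history.max(), min=1e-12)
    return E4M3_MAX / amax

history = torch.tensor([0.5, 2.0, 1.25])       # amaxes from recent iterations
scale = amax_history_to_scale_sketch(history)  # 448.0 / 2.0 = 224.0
```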

float8_experimental/float8_linear_utils.py

Lines changed: 5 additions & 1 deletion

@@ -14,7 +14,11 @@
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear

-from float8_experimental.float8_utils import amax_history_to_scale_stack, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import (
+    amax_history_to_scale_stack,
+    e4m3_dtype,
+    e5m2_dtype,
+)
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

 log = logging.getLogger(__name__)

float8_experimental/float8_tensor.py

Lines changed: 5 additions & 1 deletion

@@ -9,7 +9,11 @@
 import torch

 import torch.distributed._functional_collectives as funcol
-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated, e4m3_dtype
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    tensor_to_amax,
+    to_fp8_saturated,
+)
 from torch.distributed._tensor import DTensor

 aten = torch.ops.aten

test/test_base.py

Lines changed: 11 additions & 6 deletions

@@ -30,11 +30,11 @@
 )
 from float8_experimental.float8_utils import (
     compute_error,
+    e4m3_dtype,
+    e5m2_dtype,
     fp8_tensor_statistics,
     FP8_TYPES,
     tensor_to_scale,
-    e4m3_dtype,
-    e5m2_dtype,
 )

 random.seed(0)

@@ -389,10 +389,15 @@ def test_merge_configs(self):


 class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn,
-                                              torch.float8_e5m2,
-                                              torch.float8_e4m3fnuz,
-                                              torch.float8_e5m2fnuz])
+    @pytest.mark.parametrize(
+        "float8_dtype",
+        [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2fnuz,
+        ],
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_small_amax_float16(self, float8_dtype):
         # If we calculate scale naively with FP8_MAX_POS / amax,
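
The reformatted parametrize decorator covers the corner case the inline comment points at: a naive scale of FP8_MAX_POS / amax overflows float16 when amax is very small. A short self-contained illustration follows; the clamp-to-FP16_MAX guard is one possible mitigation, not necessarily what the library does.

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FP16_MAX = torch.finfo(torch.float16).max            # 65504.0

amax = torch.tensor(1e-6, dtype=torch.float16)

# Naive scale: 448 / 1e-6 ~ 4.5e8 exceeds the float16 range and becomes inf.
naive_scale = FP8_E4M3_MAX / amax
# One possible guard: compute in float32 and clamp to the float16 maximum.
safe_scale = torch.clamp(FP8_E4M3_MAX / amax.float(), max=FP16_MAX).half()

print(naive_scale)  # tensor(inf, dtype=torch.float16)
print(safe_scale)   # tensor(65504., dtype=torch.float16)
```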
