Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 7c53229

Browse files
y-sq authored and facebook-github-bot committed
minor changes in float8_experimental
Reviewed By: vkuzo
Differential Revision: D56867907
fbshipit-source-id: 27e61d7c14d2d406c19bfe728693ef988befa2e8
1 parent ac065d0 commit 7c53229

File tree

3 files changed

+13
-17
lines changed

3 files changed

+13
-17
lines changed

float8_experimental/float8_linear.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ def __init__(self, *args, **kwargs):
183183

184184
# pre_forward and post_forward are currently broken with FSDP
185185
# and torch.compile, this option can disable them
186+
# Note that when using `config.enable_pre_and_post_forward = False`,
187+
# it's recommended to also set `config.enable_amax_init = False`.
188+
# Otherwise, the amax buffer would never be marked as initialized and
189+
# would be initialized in every iteration.
186190
self.enable_pre_and_post_forward = config.enable_pre_and_post_forward
187191

188192
def register_always_float32_buffer(

float8_experimental/float8_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,9 @@ def fp8_tensor_statistics(
119119
else: # e5m2
120120
FP8_MAX = E5M2_MAX_POS
121121
tensor_orig_type = tensor._data.to(dtype=tensor._orig_dtype)
122-
num_overflows = (tensor_orig_type == FP8_MAX).sum().item()
123-
num_underflows = (tensor_orig_type == 0).sum().item()
124-
return (num_underflows, num_overflows)
122+
num_max = (torch.abs(tensor_orig_type) == FP8_MAX).sum().item()
123+
num_zero = (tensor_orig_type == 0).sum().item()
124+
return (num_zero, num_max)
125125

126126

127127
def is_row_major(stride):

test/test_base.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -499,30 +499,22 @@ def test_fp8_tensor_statistics(self):
499499
# Overflow caused by a too large scaling factor
500500
s_overflow = torch.tensor(1e9)
501501
fp8_overflow = Float8Tensor.to_float8(x1_hp, s_overflow, lp_dtype)
502-
(underflow_cnt, fp8_overflow_cnt) = fp8_tensor_statistics(
503-
fp8_overflow, lp_dtype
504-
)
505-
self.assertEqual((underflow_cnt, fp8_overflow_cnt), (0, tensor_len))
502+
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
503+
self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))
506504

507505
# Underflow caused by a too small scaling factor
508506
s_underflow = torch.tensor(1e-9)
509507
fp8_underflow = Float8Tensor.to_float8(x1_hp, s_underflow, lp_dtype)
510-
(underflow_cnt, fp8_overflow_cnt) = fp8_tensor_statistics(
511-
fp8_underflow, lp_dtype
512-
)
513-
self.assertEqual((underflow_cnt, fp8_overflow_cnt), (tensor_len, 0))
508+
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
509+
self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))
514510

515511
# Both overflow and underflow
516512
x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
517513
fp8_over_underflow = Float8Tensor.to_float8(
518514
x2_hp, torch.tensor(1.0), lp_dtype
519515
)
520-
(underflow_cnt, fp8_overflow_cnt) = fp8_tensor_statistics(
521-
fp8_over_underflow, lp_dtype
522-
)
523-
self.assertEqual(
524-
(underflow_cnt, fp8_overflow_cnt), (tensor_len, tensor_len)
525-
)
516+
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
517+
self.assertEqual((zero_cnt, max_cnt), (tensor_len, tensor_len))
526518

527519

528520
if __name__ == "__main__":

0 commit comments

Comments (0)