This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 0bd374d

alugorey authored and facebook-github-bot committed
Changes on top of upstream to get rid of type errors (#248)
Summary: Fixes the class of failing unit tests on ROCm in test_base.py that hit the internal assertion `Cannot convert ScalarType Float8_e4m3fn to hipDataType.` Note: We are aware of the outstanding numerical issues and are looking into them internally.

Pull Request resolved: #248

Reviewed By: vkuzo

Differential Revision: D58652172

Pulled By: drisspg

fbshipit-source-id: b62845a8eb3773bd4de5396e8c47aef94cd7e600
1 parent edae9a3 commit 0bd374d

11 files changed: +79 / -37 lines changed

float8_experimental/config.py

Lines changed: 4 additions & 0 deletions
@@ -19,3 +19,7 @@
 # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
 # Only dynamic scaling is supported for now.
 enable_fsdp_fp8_all_gather = False
+
+# If True, use 'fnuz' float8 types for calculations.
+# Currently, ROCm only supports fnuz variants.
+use_fnuz_dtype = False
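A minimal sketch of how the new flag is meant to be consumed. Because the dtype aliases in float8_utils (see its hunk below) are computed once at import time, the flag only has an effect if it is set before those aliases are first imported; whether the package's own imports already pull them in is an assumption a user would need to verify:

# Assumption: run this before any module that reads the aliases has been imported.
import float8_experimental.config as config
config.use_fnuz_dtype = True  # prefer torch.float8_e4m3fnuz / e5m2fnuz (the ROCm-supported variants)

from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype
print(e4m3_dtype, e5m2_dtype)  # torch.float8_e4m3fnuz torch.float8_e5m2fnuz, if the flag took effect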

float8_experimental/float8_dynamic_linear.py

Lines changed: 5 additions & 7 deletions
@@ -22,7 +22,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
 from torch._prims_common import suggest_memory_format
 
 
@@ -46,9 +46,9 @@ def forward(
     def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
             return gradY, None
-        gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+        gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
+            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
         )
         return fp8_tensor, None
 
@@ -105,10 +105,8 @@ def cast_to_float8_e4m3fn(
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
-    return Float8Tensor.to_float8(
-        inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
-    )
+    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
+    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
 
 
 def cast_to_float8_e5m2_bw(
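The dynamic path picks a fresh scale from the tensor being cast: the e4m3 alias for activations and weights in the forward, the e5m2 alias for gradients in the backward. A rough standalone sketch of the scale-and-saturate idea that tensor_to_scale and Float8Tensor.to_float8 wrap (illustrative only, not the library's exact implementation):

import torch

def naive_dynamic_cast(t: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Scale so the largest magnitude maps near the dtype's max representable value.
    max_pos = torch.finfo(float8_dtype).max
    amax = t.abs().max().float().clamp(min=1e-12)
    scale = max_pos / amax
    # Saturate before casting so out-of-range values clamp instead of overflowing.
    return (t.float() * scale).clamp(-max_pos, max_pos).to(float8_dtype)

x = torch.randn(4, 4)
x_fp8 = naive_dynamic_cast(x, torch.float8_e4m3fn)  # e4m3_dtype resolves to the fnuz variant on the ROCm path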

float8_experimental/float8_linear.py

Lines changed: 12 additions & 7 deletions
@@ -21,7 +21,12 @@
     to_fp8_no_autograd,
 )
 
-from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax
+from float8_experimental.float8_utils import (
+    amax_history_to_scale,
+    e4m3_dtype,
+    e5m2_dtype,
+    tensor_to_amax,
+)
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -89,15 +94,15 @@ def backward(ctx, go):
             fp8_amax_history_dL_dY,
             fp8_scale_dL_dY,
             scale_fn_name,
-            torch.float8_e5m2,
+            e5m2_dtype,
             is_amax_initialized,
             reduce_amax=True,
         )
 
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
 
         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
+            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -236,14 +241,14 @@ def cast_x_to_float8(
             self.fp8_amax_history_x,
             self.fp8_scale_x,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             is_amax_initialized,
             reduce_amax=True,
         )
         x_fp8 = Float8Tensor.to_float8(
             x,
             self.fp8_scale_x,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_x,
             self.forward_config,
         )
@@ -259,15 +264,15 @@ def cast_w_to_float8(
             self.fp8_amax_history_w,
             self.fp8_scale_w,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             is_amax_initialized,
             reduce_amax=False,
        )
 
         w_fp8 = Float8Tensor.to_float8(
             w,
             self.fp8_scale_w,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_w,
             self.forward_config,
         )
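Float8Linear uses delayed scaling: the scale for each cast comes from a rolling history of amax values rather than from the current tensor. A hedged sketch of the "max" recipe that amax_history_to_scale is typically described as implementing (names and details are illustrative, not the library's exact code):

import torch

def amax_history_to_scale_sketch(
    amax_history: torch.Tensor,  # 1D buffer of recently observed amax values
    float8_dtype: torch.dtype,
    orig_dtype: torch.dtype,
) -> torch.Tensor:
    # "max" recipe: be conservative and scale for the largest recent amax.
    amax = amax_history.max().double().clamp(min=1e-12)
    scale = torch.finfo(float8_dtype).max / amax
    return scale.to(orig_dtype)

history = torch.tensor([0.5, 2.0, 1.25])
scale = amax_history_to_scale_sketch(history, torch.float8_e5m2, torch.float32)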

float8_experimental/float8_linear_utils.py

Lines changed: 8 additions & 4 deletions
@@ -14,7 +14,11 @@
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 
-from float8_experimental.float8_utils import amax_history_to_scale_stack
+from float8_experimental.float8_utils import (
+    amax_history_to_scale_stack,
+    e4m3_dtype,
+    e5m2_dtype,
+)
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
 
 log = logging.getLogger(__name__)
@@ -298,13 +302,13 @@ def inner_func():
 
         # Calculate the new scales from the updated history stacks
         new_x_scales = amax_history_to_scale_stack(
-            fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+            fp8_x_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
         )
         new_w_scales = amax_history_to_scale_stack(
-            fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+            fp8_w_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
         )
         new_dL_dY_scales = amax_history_to_scale_stack(
-            fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
+            fp8_dL_dY_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
        )
 
        # Iterate through the layers and update the scales
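These helpers are also the user-facing entry points for delayed scaling. A hedged usage sketch of where sync_float8_amax_and_scale_history fits in a training step; the function names appear in the test imports later in this diff, but the exact call signatures (for example the two-argument swap) are assumptions, and a CUDA device with FP8 support is assumed:

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()
# Assumed two-argument form: replace every nn.Linear with a Float8Linear.
swap_linear_with_float8_linear(model, Float8Linear)

opt = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    x = torch.randn(16, 64, device="cuda")
    model(x).sum().backward()
    # Recompute scales from the accumulated amax history before stepping.
    sync_float8_amax_and_scale_history(model)
    opt.step()
    opt.zero_grad()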

float8_experimental/float8_tensor.py

Lines changed: 6 additions & 2 deletions
@@ -9,7 +9,11 @@
 import torch
 
 import torch.distributed._functional_collectives as funcol
-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    tensor_to_amax,
+    to_fp8_saturated,
+)
 from torch.distributed._tensor import DTensor
 
 aten = torch.ops.aten
@@ -125,7 +129,7 @@ def forward(
         ctx,
         tensor: torch.Tensor,
         scale: torch.Tensor,
-        float8_dtype=torch.float8_e4m3fn,
+        float8_dtype=e4m3_dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         mm_config: Optional[ScaledMMConfig] = None,
     ):

float8_experimental/float8_utils.py

Lines changed: 9 additions & 2 deletions
@@ -6,6 +6,8 @@
 
 from typing import Literal, Tuple
 
+import float8_experimental.config as config
+
 import torch
 import torch.distributed as dist
 
@@ -16,7 +18,7 @@
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
 
-IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None
 FP8_TYPES = {
     torch.float8_e4m3fn,
     torch.float8_e5m2,
@@ -25,6 +27,11 @@
 }
 
 
+# User defined type for using the individual F8 type based on config
+e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
+e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
+
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@@ -148,7 +155,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor):
 
 
 def fp8_tensor_statistics(
-    tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
+    tensor: torch.Tensor, float8_dtype=e4m3_dtype
 ) -> Tuple[int, ...]:
     """Calculate FP8 tensor stats
 
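The practical difference between the aliases is the numeric range of the underlying types: the fnuz variants drop the standard special-value encodings and shift the exponent bias, so e4m3fnuz in particular has a smaller maximum than e4m3fn. A quick way to inspect this; the printed values come from torch.finfo rather than being hard-coded here:

import torch

for dtype in (
    torch.float8_e4m3fn,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2,
    torch.float8_e5m2fnuz,
):
    fi = torch.finfo(dtype)
    print(f"{dtype}: max={fi.max}, smallest normal={fi.tiny}, eps={fi.eps}")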
test/test_base.py

Lines changed: 18 additions & 8 deletions
@@ -30,6 +30,8 @@
 )
 from float8_experimental.float8_utils import (
     compute_error,
+    e4m3_dtype,
+    e5m2_dtype,
     fp8_tensor_statistics,
     FP8_TYPES,
     tensor_to_scale,
@@ -51,7 +53,7 @@ class TestFloat8Tensor(unittest.TestCase):
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = FP8_TYPES
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
@@ -60,7 +62,7 @@ def test_preserves_dtype(self) -> None:
             self.assertTrue(x3_hp.dtype == hp_dtype)
 
     def test_differentiable_casts(self) -> None:
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = (e4m3_dtype, e5m2_dtype)
         for f8_dtype in lp_dtypes:
             x = torch.randn(1).requires_grad_()
             grad = torch.randn(1)
@@ -73,8 +75,8 @@ def test_differentiable_casts(self) -> None:
 
     def test_split_cat(self):
         a = torch.rand(16, 16, dtype=torch.bfloat16)
-        scale = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)
+        scale = tensor_to_scale(a, e4m3_dtype)
+        fp8_a = Float8Tensor.to_float8(a, scale, e4m3_dtype)
 
         splits = torch.split(fp8_a, 16)
         catted = torch.cat(splits, dim=0)
@@ -313,7 +315,7 @@ class TestScaledMM:
     @pytest.mark.parametrize("use_fast_accum", [True, False])
     def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
-        input_dtype = torch.float8_e4m3fn
+        input_dtype = e4m3_dtype
         output_dtype = base_dtype
         compare_type = torch.float32
 
@@ -352,7 +354,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
-        fp8_dtype = torch.float8_e4m3fn
+        fp8_dtype = e4m3_dtype
         a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
         b = Float8Tensor.to_float8(
             x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)
@@ -387,7 +389,15 @@ def test_merge_configs(self):
 
 
 class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @pytest.mark.parametrize(
+        "float8_dtype",
+        [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2fnuz,
+        ],
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_small_amax_float16(self, float8_dtype):
         # If we calculate scale naively with FP8_MAX_POS / amax,
@@ -508,7 +518,7 @@ def __init__(self, dim: int):
 
     def test_fp8_tensor_statistics(self):
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = (e4m3_dtype, e5m2_dtype)
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.ones(4, 4, dtype=hp_dtype)
             tensor_len = x1_hp.numel()
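The test_small_amax_float16 case (now parametrized over the fnuz dtypes as well) guards against a scale computed as FP8_MAX_POS / amax overflowing to infinity when amax is tiny and the original dtype is float16. A hedged illustration of the hazard and a clamp-style fix; the numbers and the clamp below are illustrative, not the library's exact code:

import torch

amax = torch.tensor(1e-9)
max_pos = torch.finfo(torch.float8_e4m3fn).max  # the fnuz max applies on the ROCm path

naive_scale = (max_pos / amax).half()
print(naive_scale)  # inf: the scale itself is not representable in float16

fixed_scale = torch.clamp(max_pos / amax, max=torch.finfo(torch.float16).max).half()
print(fixed_scale)  # finite: clamped to float16's largest value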

test/test_compile.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_utils import e4m3_dtype, IS_ROCM
 
 from torch._dynamo.test_case import TestCase as DynamoTestCase
 from torch._dynamo.testing import CompileCounterWithBackend
@@ -116,7 +117,7 @@ def forward(self, x):
             x_fp8 = Float8Tensor.to_float8(
                 x,
                 self.fp8_scale_x,
-                torch.float8_e4m3fn,
+                e4m3_dtype,
                 self.fp8_amax_x,
                 ScaledMMConfig(),
             )
@@ -127,12 +128,14 @@ def forward(self, x):
             return x_fp8
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(IS_ROCM, "test doesn't currently work on the ROCm stack")
    def test_float8_with_graph_break_in_the_middle(self):
        """Test that having Float8Tensor object at the boundary of a subgraph"""
        cnts = CompileCounterWithBackend("inductor")
        mod = self.MockLinear(graph_break=True).cuda()
        compiled_mod = copy.deepcopy(mod)
        compiled_mod = torch.compile(compiled_mod, backend=cnts)
+        torch.manual_seed(0)
        x = torch.randn(16, 16, device="cuda")
        y_eager = mod(x)
        y_compiled = compiled_mod(x)

test/test_dtensor.py

Lines changed: 5 additions & 5 deletions
@@ -25,7 +25,7 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor.parallel import parallelize_module
@@ -64,7 +64,7 @@ def forward(self, x):
 
 def test_scaled_mm(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
     world_size = mesh.size()
 
     x_fp32 = torch.rand(size, size, device=device)
@@ -103,7 +103,7 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):
 
 def test_fp8_redistribute(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
     world_size = mesh.size()
 
     x_fp32 = torch.rand(size, size, device=device)
@@ -130,7 +130,7 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16):
 
 def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
 
     x_fp32 = torch.rand(size, size, device=device)
     dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
@@ -144,7 +144,7 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
 
 def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
 
     x_fp32 = torch.rand(size, size, device=device, requires_grad=True)
     local_weight = torch.rand(2 * size, size, device=device, requires_grad=True)

test/test_everything.sh

Lines changed: 6 additions & 0 deletions
@@ -2,13 +2,19 @@
 
 # terminate script on first error
 set -e
+IS_ROCM=$(rocm-smi --version || true)
 
 pytest test/test_base.py
 pytest test/test_sam.py
 pytest test/test_compile.py
+
+# These tests do not work on ROCm yet
+if [ -z "$IS_ROCM" ]
+then
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_dtensor.sh
 pytest test/test_fsdp2/test_fsdp2_eager.py
+fi
 
 echo "all tests successful"

test/test_sam.py

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,7 @@
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_utils import compute_error
+from float8_experimental.float8_utils import compute_error, IS_ROCM
 from transformers import SamModel
 
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -31,6 +31,7 @@ class TestFloat8SAMIntegrationTest:
     @pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16])
     @pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear])
     @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw(self, data_dtype, linear_type):
         model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda()
         # print(model)
