This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 0a24b01

[ROCm] Support for fnuz config
1 parent 1e9add3 commit 0a24b01

11 files changed, +61 -35 lines changed

float8_experimental/config.py

Lines changed: 4 additions & 0 deletions
@@ -19,3 +19,7 @@
 # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
 # Only dynamic scaling is supported for now.
 enable_fsdp_fp8_all_gather = False
+
+# If True, use 'fnuz' float8 types for calculations.
+# Currently, ROCm only supports fnuz variants.
+use_fnuz_dtype = False
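
A minimal usage sketch (hypothetical user code, not part of this commit): the e4m3_dtype/e5m2_dtype aliases added in float8_utils.py below are computed when that module is first imported, so the flag has to be flipped before anything imports it (the package __init__ may already pull it in, in which case the flag must be set even earlier in the process).

import float8_experimental.config as config

# Opt into the ROCm-friendly 'fnuz' float8 variants; this must run before
# float8_experimental.float8_utils is first imported, because the dtype
# aliases are resolved at import time from this flag.
config.use_fnuz_dtype = True

from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype

print(e4m3_dtype)  # torch.float8_e4m3fnuz when the flag is True
print(e5m2_dtype)  # torch.float8_e5m2fnuz when the flag is True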

float8_experimental/float8_dynamic_linear.py

Lines changed: 5 additions & 5 deletions
@@ -22,7 +22,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype, e5m2_dtype
 from torch._prims_common import suggest_memory_format


@@ -46,9 +46,9 @@ def forward(
     def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
             return gradY, None
-        gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+        gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
+            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
         )
         return fp8_tensor, None

@@ -105,9 +105,9 @@ def cast_to_float8_e4m3fn(
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
+    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
     return Float8Tensor.to_float8(
-        inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
+        inpt_tensor, scale, e4m3_dtype, mm_config=mm_config
     )

float8_experimental/float8_linear.py

Lines changed: 7 additions & 7 deletions
@@ -21,7 +21,7 @@
     to_fp8_no_autograd,
 )

-from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax
+from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax, e4m3_dtype, e5m2_dtype


 def _maybe_initialize_amaxes_scales_for_float8_cast(

@@ -89,15 +89,15 @@ def backward(ctx, go):
             fp8_amax_history_dL_dY,
             fp8_scale_dL_dY,
             scale_fn_name,
-            torch.float8_e5m2,
+            e5m2_dtype,
             is_amax_initialized,
             reduce_amax=True,
         )

         fp8_amax_dL_dY.fill_(tensor_to_amax(go))

         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
+            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads

@@ -236,14 +236,14 @@ def cast_x_to_float8(
             self.fp8_amax_history_x,
             self.fp8_scale_x,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             is_amax_initialized,
             reduce_amax=True,
         )
         x_fp8 = Float8Tensor.to_float8(
             x,
             self.fp8_scale_x,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_x,
             self.forward_config,
         )

@@ -259,15 +259,15 @@ def cast_w_to_float8(
             self.fp8_amax_history_w,
             self.fp8_scale_w,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             is_amax_initialized,
             reduce_amax=False,
         )

         w_fp8 = Float8Tensor.to_float8(
             w,
             self.fp8_scale_w,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_w,
             self.forward_config,
         )

float8_experimental/float8_linear_utils.py

Lines changed: 4 additions & 4 deletions
@@ -14,7 +14,7 @@
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear

-from float8_experimental.float8_utils import amax_history_to_scale_stack
+from float8_experimental.float8_utils import amax_history_to_scale_stack, e4m3_dtype, e5m2_dtype
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

 log = logging.getLogger(__name__)

@@ -298,13 +298,13 @@ def inner_func():

         # Calculate the new scales from the updated history stacks
         new_x_scales = amax_history_to_scale_stack(
-            fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+            fp8_x_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
         )
         new_w_scales = amax_history_to_scale_stack(
-            fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+            fp8_w_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
         )
         new_dL_dY_scales = amax_history_to_scale_stack(
-            fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
+            fp8_dL_dY_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
         )

         # Iterate through the layers and update the scales
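
The reason the dtype constant threads through every call site above is that the scale is derived from the target dtype's largest representable value. A simplified sketch of that relationship (not the library's exact amax_history_to_scale code; the helper name is made up, and EPS mirrors the constant defined in float8_utils.py):

import torch

EPS = 1e-12  # same guard value as float8_experimental.float8_utils.EPS

def amax_to_scale_sketch(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Map the observed absolute max onto the dtype's max representable value,
    # guarding against division by zero. The fn and fnuz dtypes have different
    # maxima, so the resulting scales differ.
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)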

float8_experimental/float8_tensor.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 import torch

 import torch.distributed._functional_collectives as funcol
-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
+from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated, e4m3_dtype
 from torch.distributed._tensor import DTensor

 aten = torch.ops.aten

@@ -125,7 +125,7 @@ def forward(
         ctx,
         tensor: torch.Tensor,
         scale: torch.Tensor,
-        float8_dtype=torch.float8_e4m3fn,
+        float8_dtype=e4m3_dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         mm_config: Optional[ScaledMMConfig] = None,
     ):
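
One implication of making e4m3_dtype the default argument above: Python evaluates default values once, when the module is imported, so the default reflects whatever config.use_fnuz_dtype was at that moment. A standalone illustration of that behaviour (not library code; the names are invented for the example):

import torch

USE_FNUZ = False  # stands in for config.use_fnuz_dtype at import time
e4m3_dtype = torch.float8_e4m3fn if not USE_FNUZ else torch.float8_e4m3fnuz

def to_float8(t: torch.Tensor, float8_dtype: torch.dtype = e4m3_dtype) -> torch.Tensor:
    # The default was captured when this module was defined; flipping USE_FNUZ
    # afterwards does not change it.
    return t.to(float8_dtype)

print(to_float8(torch.ones(2)).dtype)  # torch.float8_e4m3fn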

float8_experimental/float8_utils.py

Lines changed: 9 additions & 2 deletions
@@ -9,14 +9,16 @@
 import torch
 import torch.distributed as dist

+import float8_experimental.config as config
+
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html

 # avoid division by zero when calculating scale
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12

-IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None
 FP8_TYPES = {
     torch.float8_e4m3fn,
     torch.float8_e5m2,

@@ -25,6 +27,11 @@
 }


+# User defined type for using the individual F8 type based on config
+e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
+e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
+
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype

@@ -148,7 +155,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor):


 def fp8_tensor_statistics(
-    tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
+    tensor: torch.Tensor, float8_dtype=e4m3_dtype
 ) -> Tuple[int, ...]:
     """Calculate FP8 tensor stats
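
Why the fn/fnuz distinction matters for the scaling math: the 'fnuz' variants reserve no encodings for infinity or negative zero and use a shifted exponent bias, so their maximum representable values differ from torch.float8_e4m3fn and torch.float8_e5m2. A quick way to inspect the ranges (a standalone sketch, assuming a PyTorch build that defines the fnuz dtypes):

import torch

# The max values below are exactly what amax-based scaling divides by, which is
# why the dtype choice has to be consistent everywhere.
for dt in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
           torch.float8_e5m2, torch.float8_e5m2fnuz):
    print(dt, torch.finfo(dt).max)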

test/test_base.py

Lines changed: 13 additions & 8 deletions
@@ -34,6 +34,8 @@
     fp8_tensor_statistics,
     FP8_TYPES,
     tensor_to_scale,
+    e4m3_dtype,
+    e5m2_dtype,
 )

 random.seed(0)

@@ -52,7 +54,7 @@ class TestFloat8Tensor(unittest.TestCase):
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = FP8_TYPES
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)

@@ -61,7 +63,7 @@ def test_preserves_dtype(self) -> None:
             self.assertTrue(x3_hp.dtype == hp_dtype)

     def test_differentiable_casts(self) -> None:
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = (e4m3_dtype, e5m2_dtype)
         for f8_dtype in lp_dtypes:
             x = torch.randn(1).requires_grad_()
             grad = torch.randn(1)

@@ -74,8 +76,8 @@ def test_differentiable_casts(self) -> None:

     def test_split_cat(self):
         a = torch.rand(16, 16, dtype=torch.bfloat16)
-        scale = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)
+        scale = tensor_to_scale(a, e4m3_dtype)
+        fp8_a = Float8Tensor.to_float8(a, scale, e4m3_dtype)

         splits = torch.split(fp8_a, 16)
         catted = torch.cat(splits, dim=0)

@@ -314,7 +316,7 @@ class TestScaledMM:
     @pytest.mark.parametrize("use_fast_accum", [True, False])
     def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
-        input_dtype = torch.float8_e4m3fn
+        input_dtype = e4m3_dtype
         output_dtype = base_dtype
         compare_type = torch.float32

@@ -360,7 +362,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
-        fp8_dtype = torch.float8_e4m3fn
+        fp8_dtype = e4m3_dtype
         a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
         b = Float8Tensor.to_float8(
             x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)

@@ -395,7 +397,10 @@ def test_merge_configs(self):


 class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn,
+                                              torch.float8_e5m2,
+                                              torch.float8_e4m3fnuz,
+                                              torch.float8_e5m2fnuz])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_small_amax_float16(self, float8_dtype):
         # If we calculate scale naively with FP8_MAX_POS / amax,

@@ -516,7 +521,7 @@ def __init__(self, dim: int):

     def test_fp8_tensor_statistics(self):
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = (e4m3_dtype, e5m2_dtype)
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.ones(4, 4, dtype=hp_dtype)
             tensor_len = x1_hp.numel()

test/test_compile.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_utils import e4m3_dtype, IS_ROCM

 from torch._dynamo.test_case import TestCase as DynamoTestCase
 from torch._dynamo.testing import CompileCounterWithBackend

@@ -116,7 +117,7 @@ def forward(self, x):
         x_fp8 = Float8Tensor.to_float8(
             x,
             self.fp8_scale_x,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_x,
             ScaledMMConfig(),
         )

@@ -127,12 +128,14 @@ def forward(self, x):
         return x_fp8

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(IS_ROCM, "test doesn't currently work on the ROCm stack")
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
         mod = self.MockLinear(graph_break=True).cuda()
         compiled_mod = copy.deepcopy(mod)
         compiled_mod = torch.compile(compiled_mod, backend=cnts)
+        torch.manual_seed(0)
         x = torch.randn(16, 16, device="cuda")
         y_eager = mod(x)
         y_compiled = compiled_mod(x)

test/test_dtensor.py

Lines changed: 5 additions & 5 deletions
@@ -25,7 +25,7 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor.parallel import parallelize_module

@@ -64,7 +64,7 @@ def forward(self, x):

 def test_scaled_mm(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
     world_size = mesh.size()

     x_fp32 = torch.rand(size, size, device=device)

@@ -103,7 +103,7 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):

 def test_fp8_redistribute(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
     world_size = mesh.size()

     x_fp32 = torch.rand(size, size, device=device)

@@ -130,7 +130,7 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16):

 def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype

     x_fp32 = torch.rand(size, size, device=device)
     dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])

@@ -144,7 +144,7 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):

 def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype

     x_fp32 = torch.rand(size, size, device=device, requires_grad=True)
     local_weight = torch.rand(2 * size, size, device=device, requires_grad=True)

test/test_everything.sh

Lines changed: 6 additions & 0 deletions
@@ -2,13 +2,19 @@

 # terminate script on first error
 set -e
+IS_ROCM=$(rocm-smi --version)

 pytest test/test_base.py
 pytest test/test_sam.py
 pytest test/test_compile.py
+
+# These tests do not work on ROCm yet
+if [ -z "$IS_ROCM" ]
+then
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_dtensor.sh
 pytest test/test_fsdp2/test_fsdp2_eager.py
+fi

 echo "all tests successful"

test/test_sam.py

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,7 @@
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_utils import compute_error
+from float8_experimental.float8_utils import compute_error, IS_ROCM
 from transformers import SamModel

 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

@@ -31,6 +31,7 @@ class TestFloat8SAMIntegrationTest:
     @pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16])
     @pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear])
     @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw(self, data_dtype, linear_type):
         model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda()
         # print(model)
