
Commit 57120fa: [ROCm] Support for fnuz config

1 parent 5fc07fc

11 files changed, +65 -38 lines

float8_experimental/config.py

Lines changed: 5 additions & 0 deletions

@@ -19,3 +19,8 @@
 # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
 # Only dynamic scaling is supported for now.
 enable_fsdp_fp8_all_gather = False
+
+# If True, use 'fnuz' float8 types for calculations. If the backend
+# hardware does not support a particular dtype, the emulated implementation
+# of the dtype will be used. Currently, ROCm only supports fnuz variants.
+use_fnuz_dtype = False
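
For context, a hypothetical usage sketch (not part of this commit) of how the new flag is consumed. Assumption: the flag can be set before float8_experimental.float8_utils is first imported, since that module binds its dtype aliases at import time (see the float8_utils.py hunk below).

# Hypothetical usage sketch: opt in to the 'fnuz' variants globally.
# Assumption: float8_experimental.config is importable and the flag is set
# before float8_experimental.float8_utils is first imported, because the
# e4m3_dtype/e5m2_dtype aliases are evaluated at module import time.
import float8_experimental.config as config

config.use_fnuz_dtype = True  # ROCm currently only supports the fnuz variants

from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype

print(e4m3_dtype)  # torch.float8_e4m3fnuz when the flag was seen in time
print(e5m2_dtype)  # torch.float8_e5m2fnuz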

float8_experimental/float8_dynamic_linear.py

Lines changed: 5 additions & 5 deletions

@@ -22,7 +22,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype, e5m2_dtype
 from torch._prims_common import suggest_memory_format
 
 
@@ -46,9 +46,9 @@ def forward(
     def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
             return gradY, None
-        gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+        gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
+            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
         )
         return fp8_tensor, None
 
@@ -105,9 +105,9 @@ def cast_to_float8_e4m3fn(
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
+    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
     return Float8Tensor.to_float8(
-        inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
+        inpt_tensor, scale, e4m3_dtype, mm_config=mm_config
     )
 
 
float8_experimental/float8_linear.py

Lines changed: 10 additions & 10 deletions

@@ -26,7 +26,7 @@
     to_fp8_no_autograd,
 )
 
-from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax
+from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax, e4m3_dtype, e5m2_dtype
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -92,14 +92,14 @@ def backward(ctx, go):
             fp8_amax_history_dL_dY,
             fp8_scale_dL_dY,
             scale_fn_name,
-            torch.float8_e5m2,
+            e5m2_dtype,
             is_amax_initialized,
         )
 
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
 
         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
+            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -138,9 +138,9 @@ def __init__(self, *args, **kwargs):
         history_len = self.recipe.history_len
 
         # Default values for history buffers, see above TODO
-        default_x = torch.finfo(torch.float8_e4m3fn).max
-        default_w = torch.finfo(torch.float8_e4m3fn).max
-        default_dl_dy = torch.finfo(torch.float8_e5m2).max
+        default_x = torch.finfo(e4m3_dtype).max
+        default_w = torch.finfo(e4m3_dtype).max
+        default_dl_dy = torch.finfo(e5m2_dtype).max
 
         self.register_always_float32_buffer("fp8_amax_x", torch.tensor([default_x]))
         self.register_always_float32_buffer(
@@ -223,13 +223,13 @@ def cast_x_to_float8(
             self.fp8_amax_history_x,
             self.fp8_scale_x,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             is_amax_initialized,
         )
         x_fp8 = Float8Tensor.to_float8(
             x,
             self.fp8_scale_x,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_x,
             self.forward_config,
         )
@@ -245,14 +245,14 @@ def cast_w_to_float8(
             self.fp8_amax_history_w,
             self.fp8_scale_w,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             is_amax_initialized,
         )
 
         w_fp8 = Float8Tensor.to_float8(
             w,
             self.fp8_scale_w,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_w,
             self.forward_config,
         )
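
The history-buffer defaults above come from torch.finfo of the resolved dtype, and the 'fn' and 'fnuz' variants do not share the same representable range, so the substitution is not purely cosmetic. A small standalone check (not from this commit, assuming a PyTorch build that exposes the fnuz dtypes) illustrates the difference:

import torch

# Compare the ranges of the plain and fnuz float8 variants; with fnuz enabled,
# torch.finfo(e4m3_dtype).max (used for the amax buffer defaults above) changes.
for dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
              torch.float8_e5m2, torch.float8_e5m2fnuz):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}, smallest normal={info.tiny}")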

float8_experimental/float8_linear_utils.py

Lines changed: 4 additions & 4 deletions

@@ -14,7 +14,7 @@
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 
-from float8_experimental.float8_utils import amax_history_to_scale_stack
+from float8_experimental.float8_utils import amax_history_to_scale_stack, e4m3_dtype, e5m2_dtype
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
 
 log = logging.getLogger(__name__)
@@ -300,13 +300,13 @@ def inner_func():
 
         # Calculate the new scales from the updated history stacks
         new_x_scales = amax_history_to_scale_stack(
-            fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+            fp8_x_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
         )
         new_w_scales = amax_history_to_scale_stack(
-            fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+            fp8_w_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
        )
         new_dL_dY_scales = amax_history_to_scale_stack(
-            fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
+            fp8_dL_dY_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
         )
 
         # Iterate through the layers and update the scales

float8_experimental/float8_tensor.py

Lines changed: 2 additions & 2 deletions

@@ -9,7 +9,7 @@
 import torch
 
 import torch.distributed._functional_collectives as funcol
-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
+from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated, e4m3_dtype
 from torch.distributed._tensor import DTensor
 
 aten = torch.ops.aten
@@ -125,7 +125,7 @@ def forward(
         ctx,
         tensor: torch.Tensor,
         scale: torch.Tensor,
-        float8_dtype=torch.float8_e4m3fn,
+        float8_dtype=e4m3_dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         mm_config: Optional[ScaledMMConfig] = None,
     ):

float8_experimental/float8_utils.py

Lines changed: 9 additions & 2 deletions

@@ -9,14 +9,16 @@
 import torch
 import torch.distributed as dist
 
+import float8_experimental.config as config
+
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
 
 # avoid division by zero when calculating scale
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
 
-IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None
 FP8_TYPES = {
     torch.float8_e4m3fn,
     torch.float8_e5m2,
@@ -25,6 +27,11 @@
 }
 
 
+# User defined type for using the individual F8 type based on config
+e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
+e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
+
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@@ -148,7 +155,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor):
 
 
 def fp8_tensor_statistics(
-    tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
+    tensor: torch.Tensor, float8_dtype=e4m3_dtype
 ) -> Tuple[int, ...]:
     """Calculate FP8 tensor stats
 
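This commit leaves the choice to the user; a hypothetical helper (not part of this change) could mirror the IS_ROCM expression above to flip the flag automatically on ROCm builds:

import torch

import float8_experimental.config as config

# Hypothetical convenience snippet: enable the fnuz variants when running on a
# ROCm build, using the same detection expression as IS_ROCM above. Assumption:
# this runs before float8_utils is imported, since the dtype aliases are bound
# at import time.
if torch.cuda.is_available() and torch.version.hip is not None:
    config.use_fnuz_dtype = True
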
test/test_base.py

Lines changed: 13 additions & 8 deletions

@@ -34,6 +34,8 @@
     fp8_tensor_statistics,
     FP8_TYPES,
     tensor_to_scale,
+    e4m3_dtype,
+    e5m2_dtype,
 )
 
 random.seed(0)
@@ -52,7 +54,7 @@ class TestFloat8Tensor(unittest.TestCase):
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = FP8_TYPES
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
@@ -61,7 +63,7 @@ def test_preserves_dtype(self) -> None:
             self.assertTrue(x3_hp.dtype == hp_dtype)
 
     def test_differentiable_casts(self) -> None:
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = (e4m3_dtype, e5m2_dtype)
         for f8_dtype in lp_dtypes:
             x = torch.randn(1).requires_grad_()
             grad = torch.randn(1)
@@ -74,8 +76,8 @@ def test_differentiable_casts(self) -> None:
 
     def test_split_cat(self):
         a = torch.rand(16, 16, dtype=torch.bfloat16)
-        scale = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)
+        scale = tensor_to_scale(a, e4m3_dtype)
+        fp8_a = Float8Tensor.to_float8(a, scale, e4m3_dtype)
 
         splits = torch.split(fp8_a, 16)
         catted = torch.cat(splits, dim=0)
@@ -314,7 +316,7 @@ class TestScaledMM:
     @pytest.mark.parametrize("use_fast_accum", [True, False])
     def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
-        input_dtype = torch.float8_e4m3fn
+        input_dtype = e4m3_dtype
         output_dtype = base_dtype
         compare_type = torch.float32
 
@@ -360,7 +362,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
     def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
-        fp8_dtype = torch.float8_e4m3fn
+        fp8_dtype = e4m3_dtype
         a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
         b = Float8Tensor.to_float8(
             x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)
@@ -395,7 +397,10 @@ def test_merge_configs(self):
 
 
 class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn,
+                                              torch.float8_e5m2,
+                                              torch.float8_e4m3fnuz,
+                                              torch.float8_e5m2fnuz])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_small_amax_float16(self, float8_dtype):
         # If we calculate scale naively with FP8_MAX_POS / amax,
@@ -516,7 +521,7 @@ def __init__(self, dim: int):
 
     def test_fp8_tensor_statistics(self):
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = (e4m3_dtype, e5m2_dtype)
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
            x1_hp = torch.ones(4, 4, dtype=hp_dtype)
            tensor_len = x1_hp.numel()
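
Mirroring the updated test_split_cat above, a minimal sketch of casting through the resolved dtype alias (assuming a PyTorch build that exposes the selected float8 dtype):

import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

a = torch.rand(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(a, e4m3_dtype)  # per-tensor scale for the resolved float8 dtype
fp8_a = Float8Tensor.to_float8(a, scale, e4m3_dtype)

splits = torch.split(fp8_a, 16)  # ops dispatch through the Float8Tensor subclass
catted = torch.cat(splits, dim=0)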

test/test_compile.py

Lines changed: 4 additions & 1 deletion

@@ -22,6 +22,7 @@
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_utils import e4m3_dtype, IS_ROCM
 
 from torch._dynamo.test_case import TestCase as DynamoTestCase
 from torch._dynamo.testing import CompileCounterWithBackend
@@ -116,7 +117,7 @@ def forward(self, x):
         x_fp8 = Float8Tensor.to_float8(
             x,
             self.fp8_scale_x,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_x,
             ScaledMMConfig(),
         )
@@ -127,12 +128,14 @@ def forward(self, x):
         return x_fp8
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(IS_ROCM, "test doesn't currently work on the ROCm stack")
    def test_float8_with_graph_break_in_the_middle(self):
        """Test that having Float8Tensor object at the boundary of a subgraph"""
        cnts = CompileCounterWithBackend("inductor")
        mod = self.MockLinear(graph_break=True).cuda()
        compiled_mod = copy.deepcopy(mod)
        compiled_mod = torch.compile(compiled_mod, backend=cnts)
+        torch.manual_seed(0)
        x = torch.randn(16, 16, device="cuda")
        y_eager = mod(x)
        y_compiled = compiled_mod(x)

test/test_dtensor.py

Lines changed: 5 additions & 5 deletions

@@ -23,7 +23,7 @@
     Float8ColwiseParallel,
     Float8RowwiseParallel,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor.parallel import parallelize_module
@@ -53,7 +53,7 @@ def forward(self, x):
 
 def test_scaled_mm(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
     world_size = mesh.size()
 
     x_fp32 = torch.rand(size, size, device=device)
@@ -92,7 +92,7 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):
 
 def test_fp8_redistribute(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
     world_size = mesh.size()
 
     x_fp32 = torch.rand(size, size, device=device)
@@ -119,7 +119,7 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16):
 
 def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
 
     x_fp32 = torch.rand(size, size, device=device)
     dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
@@ -133,7 +133,7 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
 
 def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
 
     x_fp32 = torch.rand(size, size, device=device, requires_grad=True)
     local_weight = torch.rand(2 * size, size, device=device, requires_grad=True)

test/test_everything.sh

Lines changed: 6 additions & 0 deletions

@@ -2,13 +2,19 @@
 
 # terminate script on first error
 set -e
+IS_ROCM=$(rocm-smi --version)
 
 pytest test/test_base.py
 pytest test/test_sam.py
 pytest test/test_compile.py
+
+# These tests do not work on ROCm yet
+if [ -z "$IS_ROCM" ]
+then
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_dtensor.sh
 pytest test/test_fsdp2/test_fsdp2_eager.py
+fi
 
 echo "all tests successful"

test/test_sam.py

Lines changed: 2 additions & 1 deletion

@@ -18,7 +18,7 @@
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_utils import compute_error
+from float8_experimental.float8_utils import compute_error, IS_ROCM
 from transformers import SamModel
 
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -31,6 +31,7 @@ class TestFloat8SAMIntegrationTest:
     @pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16])
     @pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear])
     @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw(self, data_dtype, linear_type):
         model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda()
         # print(model)
