This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 8f50785

add test
1 parent 6692e05 commit 8f50785

2 files changed: +61 -1 lines changed


float8_experimental/float8_ops.py

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     b_data = b._data
 
     if a._mm_config.pad_inner_dim:
+        assert (
+            b._mm_config.pad_inner_dim
+        ), "Both mm configs must have pad_inner_dim set to True"
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
         b_data = pad_tensor_for_matmul(b_data, dims=0)
 
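For context on what the padded path does: a minimal sketch of inner-dim padding, assuming the requirement (surfaced in the test's error message below) that the shared matmul dimension be divisible by 16. The helper name is hypothetical, standing in for the repo's pad_tensor_for_matmul:

import torch

def pad_dim_to_multiple_of_16(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical stand-in for pad_tensor_for_matmul: zero-pads `dim`
    # up to the next multiple of 16; the extra zero rows/columns leave
    # the matmul result unchanged.
    pad = (16 - t.size(dim) % 16) % 16
    if pad == 0:
        return t
    pad_shape = list(t.shape)
    pad_shape[dim] = pad
    return torch.cat([t, t.new_zeros(pad_shape)], dim=dim)

With the new asserts, mismatched mm configs or mismatched inner dims now fail loudly before any padding happens.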

test/test_base.py

Lines changed: 55 additions & 1 deletion
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import itertools
 import random
+import re
 import unittest
 import warnings
 
@@ -312,7 +313,7 @@ class TestScaledMM:
         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_fast_accum", [True, False])
-    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
+    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum, padded):
         torch.manual_seed(42)
         input_dtype = torch.float8_e4m3fn
         output_dtype = base_dtype
@@ -393,6 +394,59 @@ def test_merge_configs(self):
         assert c.use_fast_accum is True
         assert c.fp8_output is False
 
+    @unittest.skipIf(
+        not is_H100,
+        "CUDA not available",
+    )
+    @pytest.mark.parametrize(
+        "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_fast_accum", [True, False])
+    def test_pad_inner_dim(self, base_dtype, use_fast_accum):
+        torch.manual_seed(42)
+        input_dtype = torch.float8_e4m3fn
+        compare_type = torch.float32
+
+        a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
+        b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
+
+        a_scale = tensor_to_scale(a, input_dtype).float()
+        b_scale = tensor_to_scale(b, input_dtype).float()
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+
+        with pytest.raises(
+            RuntimeError,
+            match=re.escape(
+                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
+            ),
+        ):
+            a_fp8 @ b_fp8
+
+        pad_config = ScaledMMConfig(False, use_fast_accum, False, True)
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config)
+        out_padded = a_fp8 @ b_fp8
+        out_padded.to(compare_type)
+
+        emulated_conifg = ScaledMMConfig(True, use_fast_accum, False, False)
+        a_fp8 = Float8Tensor.to_float8(
+            a, a_scale, input_dtype, mm_config=emulated_conifg
+        )
+        b_fp8 = Float8Tensor.to_float8(
+            b, b_scale, input_dtype, mm_config=emulated_conifg
+        )
+        out_emualted = a_fp8 @ b_fp8
+        out_emualted.to(compare_type)
+
+        if base_dtype in {torch.bfloat16, torch.float16}:
+            atol, rtol = 7e-2, 7e-2
+        else:
+            atol, rtol = 2e-3, 2e-3
+        torch.testing.assert_close(out_padded, out_emualted, atol=atol, rtol=rtol)
+
 
 class TestNumerics:
     @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
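The test's core premise is that zero-padding the shared inner dimension (41 -> 48 here) should not change the matmul result beyond fp8 rounding, so the padded fp8 output can be compared against the emulated one. A quick full-precision sanity check of that identity (plain PyTorch, not repo code):

import torch

a = torch.randn(16, 41)
b = torch.randn(41, 128)
pad = (16 - 41 % 16) % 16  # 7, so the shared inner dim becomes 48
a_pad = torch.cat([a, torch.zeros(16, pad)], dim=1)
b_pad = torch.cat([b, torch.zeros(pad, 128)], dim=0)
# Every extra product term multiplies a zero, so the results match.
torch.testing.assert_close(a @ b, a_pad @ b_pad, atol=1e-5, rtol=1e-5)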

0 commit comments
