This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ee5d4f9

add test
1 parent 53921b2 commit ee5d4f9

3 files changed: +61 -2 lines changed

float8_experimental/float8_ops.py

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     b_data = b._data
 
     if a._mm_config.pad_inner_dim:
+        assert (
+            b._mm_config.pad_inner_dim
+        ), "Both mm configs must have pad_inner_dim set to True"
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
         b_data = pad_tensor_for_matmul(b_data, dims=0)
 

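Context: torch._scaled_mm rejects fp8 operands whose shared inner (K) dimension is not divisible by 16 (this is the RuntimeError the new test below expects), and the assertions added above guard the path where pad_tensor_for_matmul zero-pads that dimension on both operands. Below is a minimal, hypothetical sketch of what such padding looks like; the helper name and the use of torch.nn.functional.pad are illustrative assumptions, not the library's actual pad_tensor_for_matmul implementation.

# Hypothetical illustration (not the library's pad_tensor_for_matmul): zero-pad the
# shared inner dimension up to the next multiple of 16 so torch._scaled_mm's
# "divisible by 16" requirement is met without changing the matmul result.
import torch
import torch.nn.functional as F


def pad_dim_to_multiple_of_16(t: torch.Tensor, dim: int) -> torch.Tensor:
    pad_amount = -t.size(dim) % 16  # e.g. 41 -> 7 extra zeros to reach 48
    if pad_amount == 0:
        return t
    # F.pad's pad list runs from the last dimension backwards:
    # (last_left, last_right, second_to_last_left, second_to_last_right, ...)
    pad = [0] * (2 * t.dim())
    pad[2 * (t.dim() - 1 - dim) + 1] = pad_amount
    return F.pad(t, pad)


a = torch.randn(16, 41)
b = torch.randn(41, 128)
a_padded = pad_dim_to_multiple_of_16(a, dim=1)  # (16, 48)
b_padded = pad_dim_to_multiple_of_16(b, dim=0)  # (48, 128)
# Padding the shared K dim with zeros only adds zero terms to each dot product,
# so the product is unchanged.
torch.testing.assert_close(a @ b, a_padded @ b_padded)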
float8_experimental/float8_python_api.py

Lines changed: 0 additions & 1 deletion
@@ -9,7 +9,6 @@
99
to simplify the product code.
1010
"""
1111

12-
1312
from typing import Optional
1413

1514
import float8_experimental.float8_aten_api # noqa

test/test_base.py

Lines changed: 55 additions & 1 deletion
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import itertools
 import random
+import re
 import unittest
 import warnings
 
@@ -313,7 +314,7 @@ class TestScaledMM:
         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_fast_accum", [True, False])
-    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
+    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum, padded):
         torch.manual_seed(42)
         input_dtype = e4m3_dtype
         output_dtype = base_dtype
@@ -387,6 +388,59 @@ def test_merge_configs(self):
         assert c.use_fast_accum is True
         assert c.fp8_output is False
 
+    @unittest.skipIf(
+        not is_H100,
+        "H100 not available",
+    )
+    @pytest.mark.parametrize(
+        "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_fast_accum", [True, False])
+    def test_pad_inner_dim(self, base_dtype, use_fast_accum):
+        torch.manual_seed(42)
+        input_dtype = torch.float8_e4m3fn
+        compare_type = torch.float32
+
+        a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
+        b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
+
+        a_scale = tensor_to_scale(a, input_dtype).float()
+        b_scale = tensor_to_scale(b, input_dtype).float()
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+
+        with pytest.raises(
+            RuntimeError,
+            match=re.escape(
+                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
+            ),
+        ):
+            a_fp8 @ b_fp8
+
+        pad_config = ScaledMMConfig(False, use_fast_accum, False, True)
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config)
+        out_padded = a_fp8 @ b_fp8
+        out_padded.to(compare_type)
+
+        emulated_config = ScaledMMConfig(True, use_fast_accum, False, False)
+        a_fp8 = Float8Tensor.to_float8(
+            a, a_scale, input_dtype, mm_config=emulated_config
+        )
+        b_fp8 = Float8Tensor.to_float8(
+            b, b_scale, input_dtype, mm_config=emulated_config
+        )
+        out_emulated = a_fp8 @ b_fp8
+        out_emulated.to(compare_type)
+
+        if base_dtype in {torch.bfloat16, torch.float16}:
+            atol, rtol = 7e-2, 7e-2
+        else:
+            atol, rtol = 2e-3, 2e-3
+        torch.testing.assert_close(out_padded, out_emulated, atol=atol, rtol=rtol)
+
 
 class TestNumerics:
     @pytest.mark.parametrize(

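Usage note: a rough sketch of the padded fp8 matmul path the new test exercises. The ScaledMMConfig positional field order (emulate, use_fast_accum, fp8_output, pad_inner_dim) and the import paths are assumptions inferred from this diff, not confirmed library documentation.

# Sketch only; ScaledMMConfig field order and import locations are assumed from the
# constructor calls in the test above, not verified against the library.
# Requires a device with fp8 support (H100), matching the test's skip condition.
import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import tensor_to_scale

dtype = torch.float8_e4m3fn
# Assumed order: (emulate, use_fast_accum, fp8_output, pad_inner_dim)
pad_config = ScaledMMConfig(False, True, False, True)

a = torch.randn(16, 41, device="cuda", dtype=torch.bfloat16)
b = torch.randn(41, 128, device="cuda", dtype=torch.bfloat16)

# Both operands must carry pad_inner_dim=True, per the new assert in preprocess_addmm.
a_fp8 = Float8Tensor.to_float8(a, tensor_to_scale(a, dtype).float(), dtype, mm_config=pad_config)
b_fp8 = Float8Tensor.to_float8(b, tensor_to_scale(b, dtype).float(), dtype, mm_config=pad_config)

out = a_fp8 @ b_fp8  # inner dim 41 is zero-padded up to 48 before the scaled mm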