This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b4473ed

add test
1 parent 6692e05 commit b4473ed

File tree

float8_experimental/float8_ops.py
test/test_base.py

2 files changed: +57 lines, -1 line


float8_experimental/float8_ops.py

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     b_data = b._data
 
     if a._mm_config.pad_inner_dim:
+        assert (
+            b._mm_config.pad_inner_dim
+        ), "Both mm configs must have pad_inner_dim set to True"
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
         b_data = pad_tensor_for_matmul(b_data, dims=0)
 
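For context: the guarded branch above routes a_data and b_data through pad_tensor_for_matmul so that the shared inner dimension becomes a multiple of 16, which is the divisibility requirement the new test's expected RuntimeError refers to. Below is a minimal sketch of that rounding, assuming the helper zero-pads the named dim up to the next multiple of 16; the name pad_dim_to_multiple_of_16 is a hypothetical stand-in, not the library's helper.

    # Sketch only: assumed padding behavior, not the library implementation.
    import torch
    import torch.nn.functional as F

    def pad_dim_to_multiple_of_16(t: torch.Tensor, dim: int) -> torch.Tensor:
        # Zero-pad `dim` of a 2D tensor up to the next multiple of 16.
        extra = (-t.size(dim)) % 16
        if extra == 0:
            return t
        # F.pad pads from the last dim backwards: (left, right, top, bottom) for 2D.
        pad = [0, extra, 0, 0] if dim == 1 else [0, 0, 0, extra]
        return F.pad(t, pad)

    # Shapes from the new test: (16, 41) @ (41, 128); 41 is not a multiple of 16.
    a, b = torch.randn(16, 41), torch.randn(41, 128)
    a_p = pad_dim_to_multiple_of_16(a, dim=1)  # (16, 48)
    b_p = pad_dim_to_multiple_of_16(b, dim=0)  # (48, 128)
    assert torch.allclose(a_p @ b_p, a @ b)  # zero rows/cols do not change the product

Because the padding is zeros along the shared inner dimension, the padded product matches the unpadded one, which is what lets the new test compare the padded fp8 result against an emulated reference within tolerance.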
test/test_base.py

Lines changed: 51 additions & 1 deletion
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import itertools
 import random
+import re
 import unittest
 import warnings
 
@@ -312,7 +313,7 @@ class TestScaledMM:
         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_fast_accum", [True, False])
-    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
+    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum, padded):
         torch.manual_seed(42)
         input_dtype = torch.float8_e4m3fn
         output_dtype = base_dtype
@@ -393,6 +394,55 @@ def test_merge_configs(self):
         assert c.use_fast_accum is True
         assert c.fp8_output is False
 
+    @pytest.mark.parametrize(
+        "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_fast_accum", [True, False])
+    def test_pad_inner_dim(self, base_dtype, use_fast_accum):
+        torch.manual_seed(42)
+        input_dtype = torch.float8_e4m3fn
+        compare_type = torch.float32
+
+        a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
+        b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
+
+        a_scale = tensor_to_scale(a, input_dtype).float()
+        b_scale = tensor_to_scale(b, input_dtype).float()
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+
+        with pytest.raises(
+            RuntimeError,
+            match=re.escape(
+                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
+            ),
+        ):
+            a_fp8 @ b_fp8
+
+        pad_config = ScaledMMConfig(False, use_fast_accum, False, True)
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config)
+        out_padded = a_fp8 @ b_fp8
+        out_padded.to(compare_type)
+
+        emulated_conifg = ScaledMMConfig(True, use_fast_accum, False, False)
+        a_fp8 = Float8Tensor.to_float8(
+            a, a_scale, input_dtype, mm_config=emulated_conifg
+        )
+        b_fp8 = Float8Tensor.to_float8(
+            b, b_scale, input_dtype, mm_config=emulated_conifg
+        )
+        out_emualted = a_fp8 @ b_fp8
+        out_emualted.to(compare_type)
+
+        if base_dtype in {torch.bfloat16, torch.float16}:
+            atol, rtol = 7e-2, 7e-2
+        else:
+            atol, rtol = 2e-3, 2e-3
+        torch.testing.assert_close(out_padded, out_emualted, atol=atol, rtol=rtol)
+
 
 class TestNumerics:
     @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
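The two configs in test_pad_inner_dim are built positionally. Reading them against their use (a real, padded fp8 matmul versus an emulated fallback), the field order appears to be (emulate, use_fast_accum, fp8_output, pad_inner_dim); this is an inference, not something the diff states. A self-contained sketch of that reading, using a hypothetical stand-in for the library's ScaledMMConfig:

    # Hypothetical stand-in for float8_experimental's ScaledMMConfig, illustration only.
    from collections import namedtuple

    ScaledMMConfig = namedtuple(
        "ScaledMMConfig",
        ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],  # assumed field order
        defaults=[False, False, False, False],
    )

    # The test's configs restated with keywords (use_fast_accum is parametrized in the test):
    pad_config = ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=True)
    emulated_config = ScaledMMConfig(emulate=True, use_fast_accum=True, fp8_output=False, pad_inner_dim=False)

Under that reading, the test asserts that padding the inner dimension on the real fp8 path agrees with the emulated path to within the stated tolerances.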
