|
5 | 5 | # LICENSE file in the root directory of this source tree.
6 | 6 | import itertools
7 | 7 | import random
| 8 | +import re
8 | 9 | import unittest
9 | 10 | import warnings
10 | 11 |
@@ -312,7 +313,7 @@ class TestScaledMM:
312 | 313 |         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
313 | 314 |     )
314 | 315 |     @pytest.mark.parametrize("use_fast_accum", [True, False])
315 | | -    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
| 316 | +    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum, padded):
316 | 317 |         torch.manual_seed(42)
317 | 318 |         input_dtype = torch.float8_e4m3fn
318 | 319 |         output_dtype = base_dtype
@@ -393,6 +394,55 @@ def test_merge_configs(self):
393 | 394 |         assert c.use_fast_accum is True
394 | 395 |         assert c.fp8_output is False
395 | 396 |
| 397 | +    @pytest.mark.parametrize(
| 398 | +        "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
| 399 | +    )
| 400 | +    @pytest.mark.parametrize("use_fast_accum", [True, False])
| 401 | +    def test_pad_inner_dim(self, base_dtype, use_fast_accum):
| 402 | +        torch.manual_seed(42)
| 403 | +        input_dtype = torch.float8_e4m3fn
| 404 | +        compare_type = torch.float32
| 405 | +
| 406 | +        a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
| 407 | +        b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
| 408 | +
| 409 | +        a_scale = tensor_to_scale(a, input_dtype).float()
| 410 | +        b_scale = tensor_to_scale(b, input_dtype).float()
| 411 | +
| 412 | +        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
| 413 | +        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
| 414 | +
| 415 | +        with pytest.raises(
| 416 | +            RuntimeError,
| 417 | +            match=re.escape(
| 418 | +                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
| 419 | +            ),
| 420 | +        ):
| 421 | +            a_fp8 @ b_fp8
| 422 | +
| 423 | +        pad_config = ScaledMMConfig(False, use_fast_accum, False, True)
| 424 | +
| 425 | +        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config)
| 426 | +        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config)
| 427 | +        out_padded = a_fp8 @ b_fp8
| 428 | +        out_padded = out_padded.to(compare_type)
| 429 | +
| 430 | +        emulated_config = ScaledMMConfig(True, use_fast_accum, False, False)
| 431 | +        a_fp8 = Float8Tensor.to_float8(
| 432 | +            a, a_scale, input_dtype, mm_config=emulated_config
| 433 | +        )
| 434 | +        b_fp8 = Float8Tensor.to_float8(
| 435 | +            b, b_scale, input_dtype, mm_config=emulated_config
| 436 | +        )
| 437 | +        out_emulated = a_fp8 @ b_fp8
| 438 | +        out_emulated = out_emulated.to(compare_type)
| 439 | +
| 440 | +        if base_dtype in {torch.bfloat16, torch.float16}:
| 441 | +            atol, rtol = 7e-2, 7e-2
| 442 | +        else:
| 443 | +            atol, rtol = 2e-3, 2e-3
| 444 | +        torch.testing.assert_close(out_padded, out_emulated, atol=atol, rtol=rtol)
| 445 | +
396 | 446 |
397 | 447 | class TestNumerics:
398 | 448 |     @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
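
For context, a minimal standalone sketch of the padded fp8 matmul path that the new `test_pad_inner_dim` exercises, stripped of the pytest harness. The import paths and the positional meaning of the four `ScaledMMConfig` fields (read here as `emulate`, `use_fast_accum`, `fp8_output`, `pad_inner_dim`) are assumptions inferred from the diff, not confirmed by it:

```python
# Sketch only: an fp8 matmul whose inner dim (41) is not a multiple of 16,
# made legal by opting into inner-dim padding through the mm config.
# The import paths below are assumptions; adjust to wherever Float8Tensor,
# ScaledMMConfig, and tensor_to_scale live in this repo.
import torch

from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import tensor_to_scale

input_dtype = torch.float8_e4m3fn
a = torch.randn(16, 41, device="cuda", dtype=torch.bfloat16)
b = torch.randn(41, 128, device="cuda", dtype=torch.bfloat16)

# Assumed field order: (emulate, use_fast_accum, fp8_output, pad_inner_dim).
pad_config = ScaledMMConfig(False, True, False, True)

a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()

a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config)
b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config)

# Without pad_inner_dim=True this matmul raises "Expected trailing dimension
# of mat1 to be divisible by 16 ...", which is exactly what the test asserts.
out = (a_fp8 @ b_fp8).to(torch.float32)
```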
|
|