This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit edae9a3

drisspg authored and facebook-github-bot committed
Updates with new scaled-mm api (#284)
Summary: This updates the calls to `_scaled_mm` to the new signature from this PR: pytorch/pytorch#128683. This is needed to unblock inductor work on scaled_mm.

```Shell
❯ ./test/test_everything.sh
.
.
.
test/test_fsdp2/test_fsdp2_eager.py ....... [100%]
================================ 7 passed in 27.66s =================================
all tests successful
```

Pull Request resolved: #284

Reviewed By: y-sq

Differential Revision: D58709092

Pulled By: drisspg

fbshipit-source-id: ab330506621e9240f495be965748066d494d7b50
1 parent 1e9add3 commit edae9a3
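
Note: the signature change this commit adapts to can be sketched as follows. This is a minimal, hypothetical usage sketch, not code from this commit; the shapes, unit scales, and bfloat16 output dtype are illustrative, and it assumes a CUDA device with fp8 support (e.g. H100).

```python
import torch

device = "cuda"  # assumes a GPU with fp8 support

# Illustrative operands: A row-major, B in the column-major layout _scaled_mm expects.
A = torch.randn(16, 32, device=device).to(torch.float8_e4m3fn)
B = torch.randn(32, 16, device=device).to(torch.float8_e4m3fn).t().contiguous().t()
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

# Old call shape (before pytorch/pytorch#128683): scales were optional keyword
# arguments and the op returned an (output, output_amax) tuple:
#   out, out_amax = torch._scaled_mm(A, B, out_dtype=torch.bfloat16)
#
# New call shape, used throughout this commit: scales are passed explicitly and
# a single output tensor is returned.
out = torch._scaled_mm(
    A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=False
)
```

The per-file diffs below update the benchmarks, the emulated aten op, the op dispatch layer, the Python wrapper, and the tests to this call shape.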

File tree

6 files changed (+23 additions, -30 deletions)


benchmarks/bench_linear_float8.py

Lines changed: 0 additions & 2 deletions
@@ -14,8 +14,6 @@

 import torch
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     get_float8_linear,
     LinearType,

benchmarks/bench_matmul.py

Lines changed: 5 additions & 1 deletion
@@ -101,7 +101,11 @@ def run(n_limit: Optional[int] = None):
     B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()

     def do_matmul(A, B):
-        return torch._scaled_mm(A, B, out_dtype=d3, use_fast_accum=False)
+        scale_a = torch.tensor([1], device=device)
+        scale_b = torch.tensor([1], device=device)
+        return torch._scaled_mm(
+            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+        )

     fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
         tops, dtype_to_peak_tops[d1], do_matmul, A, B

float8_experimental/float8_aten_api.py

Lines changed: 3 additions & 4 deletions
@@ -10,7 +10,6 @@

 import torch

-from float8_experimental.float8_utils import tensor_to_amax
 from torch.library import Library


@@ -26,7 +25,7 @@ def mm_float8_emulated(
     m2_fp32 = m2.float() / s2
     m3_fp32 = torch.mm(m1_fp32, m2_fp32)

-    return m3_fp32.to(dtype3), tensor_to_amax(m3_fp32)
+    return m3_fp32.to(dtype3)


 #
@@ -38,7 +37,7 @@ def mm_float8_emulated(
 lib = Library("aten", "FRAGMENT")

 lib.define(
-    "mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> (Tensor, Tensor)"
+    "mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> Tensor"
 )
 lib.impl("mm_float8_emulated", mm_float8_emulated, "CPU")
 lib.impl("mm_float8_emulated", mm_float8_emulated, "CUDA")
@@ -47,4 +46,4 @@ def mm_float8_emulated(
 @torch.library.impl(lib, "mm_float8_emulated", "Meta")
 def _mm_float8_emulated_meta(m1, s1, m2, s2, dtype3):
     out = torch.mm(m1.float(), m2.float()).to(dtype3)
-    return out, torch.empty(1, device="meta")
+    return out
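
Since `mm_float8_emulated` now returns a single tensor, its callers drop the `[0]` indexing (see float8_ops.py below). A minimal usage sketch of the updated op; the shapes and unit scales are illustrative, and it assumes a PyTorch build with CPU casts for the fp8 dtypes (the op itself is registered for both CPU and CUDA above).

```python
import torch
import float8_experimental.float8_aten_api  # noqa  (importing registers the op)

# Illustrative fp8 operands and unit scales; runs on CPU since the op is emulated.
m1 = torch.randn(16, 32).to(torch.float8_e4m3fn)
m2 = torch.randn(32, 16).to(torch.float8_e4m3fn)
s1 = torch.tensor(1.0)
s2 = torch.tensor(1.0)

# Previously: out, out_amax = torch.ops.aten.mm_float8_emulated(...)
out = torch.ops.aten.mm_float8_emulated(m1, s1, m2, s2, torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([16, 16]) torch.bfloat16
```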

float8_experimental/float8_ops.py

Lines changed: 4 additions & 4 deletions
@@ -147,8 +147,8 @@ def float8_mm(aten_op, args, kwargs=None):
     if mm_config.emulate:
         return torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
-        )[0]
-    tensor_out, amax = addmm_float8_unwrapped(
+        )
+    tensor_out = addmm_float8_unwrapped(
         a_data,
         a_scale,
         b_data,
@@ -180,9 +180,9 @@ def float8_addmm(aten_op, args, kwargs=None):
     if mm_config.emulate:
         out = torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
-        )[0]
+        )
         return out + bias
-    tensor_out, amax = addmm_float8_unwrapped(
+    tensor_out = addmm_float8_unwrapped(
         a_data,
         a_scale,
         b_data,

float8_experimental/float8_python_api.py

Lines changed: 9 additions & 9 deletions
@@ -10,7 +10,7 @@
 """


-from typing import Optional, Tuple
+from typing import Optional

 import float8_experimental.float8_aten_api  # noqa

@@ -31,7 +31,7 @@ def addmm_float8_unwrapped(
     output_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     use_fast_accum: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     """
     This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
     as inputs. This is used to standardize the logic between subclassed and non subclassed
@@ -41,25 +41,25 @@
     b_inverse_scale = b_scale.reciprocal()
     if output_dtype == torch.float32 and bias is not None:
         # Bias is not supported by _scaled_mm when output is fp32
-        output, output_amax = torch._scaled_mm(
+        output = torch._scaled_mm(
             a_data,
             b_data,
-            out_dtype=output_dtype,
             scale_a=a_inverse_scale,
             scale_b=b_inverse_scale,
             scale_result=output_scale,
+            out_dtype=output_dtype,
             use_fast_accum=use_fast_accum,
         )
         output += bias
-        return output, output_amax
-    output, output_amax = torch._scaled_mm(
+        return output
+    output = torch._scaled_mm(
         a_data,
         b_data,
-        bias=bias,
-        out_dtype=output_dtype,
         scale_a=a_inverse_scale,
         scale_b=b_inverse_scale,
+        bias=bias,
         scale_result=output_scale,
+        out_dtype=output_dtype,
         use_fast_accum=use_fast_accum,
     )
-    return output, output_amax
+    return output
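
Correspondingly, callers of `addmm_float8_unwrapped` now receive a single tensor, as the test update below reflects. A minimal sketch of the new calling convention; the shapes and unit scales are illustrative, the raw fp8 tensors stand in for `Float8Tensor._data`/`._scale`, and it assumes a CUDA device with fp8 support.

```python
import torch
from float8_experimental.float8_python_api import addmm_float8_unwrapped

device = "cuda"  # the underlying torch._scaled_mm requires fp8-capable hardware
a_data = torch.randn(16, 32, device=device).to(torch.float8_e4m3fn)
# Second operand in the column-major layout _scaled_mm expects.
b_data = torch.randn(32, 16, device=device).to(torch.float8_e4m3fn).t().contiguous().t()
one = torch.tensor(1.0, device=device)

# Previously: out, out_amax = addmm_float8_unwrapped(...)
out = addmm_float8_unwrapped(
    a_data,
    one,  # a_scale
    b_data,
    one,  # b_scale
    output_dtype=torch.bfloat16,
    use_fast_accum=False,
)
print(out.shape, out.dtype)  # torch.Size([16, 16]) torch.bfloat16
```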

test/test_base.py

Lines changed: 2 additions & 10 deletions
@@ -29,7 +29,6 @@
     ScaledMMConfig,
 )
 from float8_experimental.float8_utils import (
-    amax_to_scale,
     compute_error,
     fp8_tensor_statistics,
     FP8_TYPES,
@@ -327,29 +326,22 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
         b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)

-        out_scaled_mm, output_amax_scaled = addmm_float8_unwrapped(
+        out_scaled_mm = addmm_float8_unwrapped(
             a_fp8._data,
             a_fp8._scale,
             b_fp8._data,
             b_fp8._scale,
             output_dtype=output_dtype,
             use_fast_accum=use_fast_accum,
         )
-        out_emulated, output_amax_emulated = torch.ops.aten.mm_float8_emulated(
+        out_emulated = torch.ops.aten.mm_float8_emulated(
             a_fp8._data, a_fp8._scale, b_fp8._data, b_fp8._scale, output_dtype
         )

         if output_dtype != base_dtype:
             out_scaled_mm = out_scaled_mm.to(compare_type)
             out_emulated = out_emulated.to(compare_type)

-            out_scaled_mm = out_scaled_mm / amax_to_scale(
-                output_amax_scaled, input_dtype
-            )
-            out_emulated = out_emulated / amax_to_scale(
-                output_amax_emulated, input_dtype
-            )
-
         if base_dtype in {torch.bfloat16, torch.float16}:
             atol, rtol = 7e-2, 7e-2
         else:
