Commit 7d87946

[1/x] float8 cleanup: remove float8_python_api (#1779)
Update [ghstack-poisoned]
1 parent 8706d3f commit 7d87946

File tree

3 files changed: +63 −78


test/float8/test_base.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
 )
-from torchao.float8.float8_python_api import addmm_float8_unwrapped
+from torchao.float8.float8_ops import addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_dynamic,

torchao/float8/float8_ops.py

Lines changed: 62 additions & 2 deletions
@@ -3,12 +3,11 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple

 import torch
 from torch.utils._pytree import tree_map

-from torchao.float8.float8_python_api import addmm_float8_unwrapped
 from torchao.float8.float8_tensor import Float8Tensor, choose_scaled_mm_config
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul

@@ -18,6 +17,67 @@
 FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


+# [Note] Usage of scales
+# The meaning of scale in this library can be found in the definition of Float8Tensor.
+# cuBLAS defines scale to always mean a multiplicative factor for the respective matrices.
+# For a, b going from fp8 -> fp32 we multiply by the inverse of the scale.
+# For output going from fp32 -> fp8 we multiply by the scale.
+def addmm_float8_unwrapped(
+    a_data: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_data: torch.Tensor,
+    b_scale: torch.Tensor,
+    output_dtype: torch.dtype,
+    output_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    use_fast_accum: bool = False,
+) -> torch.Tensor:
+    """
+    This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
+    as inputs. This is used to standardize the logic between subclassed and non-subclassed
+    versions of the linear module.
+    """
+    a_inverse_scale = a_scale.reciprocal()
+    b_inverse_scale = b_scale.reciprocal()
+
+    post_inverse_scale = None
+    if (
+        a_scale.shape == (a_data.shape[0], 1)
+        and b_scale.shape == (1, b_data.shape[1])
+        and not use_fast_accum
+    ):
+        # The rowwise CUTLASS-based kernel is so slow without fast-accum that
+        # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
+        # manually afterwards (hoping Inductor will be able to fuse it).
+        post_inverse_scale = a_inverse_scale * b_inverse_scale
+        a_inverse_scale = a_inverse_scale.new_ones(())
+        b_inverse_scale = b_inverse_scale.new_ones(())
+
+    post_bias = None
+    if output_dtype == torch.float32:
+        # Bias is not supported by _scaled_mm when output is fp32
+        post_bias = bias
+        bias = None
+
+    output = torch._scaled_mm(
+        a_data,
+        b_data,
+        scale_a=a_inverse_scale,
+        scale_b=b_inverse_scale,
+        bias=bias,
+        scale_result=output_scale,
+        out_dtype=output_dtype,
+        use_fast_accum=use_fast_accum,
+    )
+
+    if post_inverse_scale is not None:
+        output *= post_inverse_scale
+    if post_bias is not None:
+        output += post_bias
+
+    return output
+
+
 def _assert_tensorwise_scale(aten_op, scale):
     assert (
         # TODO(future PR): figure out why tensorwise scaling can have
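
The rowwise branch above relies on a simple algebraic identity: scaling a's rows and b's columns before the matmul is equivalent to multiplying the unscaled product by the broadcasted outer product of the two scale vectors afterwards. A minimal sketch in plain PyTorch of that identity (fp32 stands in for fp8 here; shapes and names are illustrative, not from this commit):

import torch

M, K, N = 4, 8, 3
a = torch.randn(M, K)
b = torch.randn(K, N)
a_inverse_scale = torch.rand(M, 1) + 0.5  # rowwise, shape (M, 1)
b_inverse_scale = torch.rand(1, N) + 0.5  # columnwise, shape (1, N)

# Scaling the inputs before the matmul ...
scaled_before = (a * a_inverse_scale) @ (b * b_inverse_scale)
# ... matches scaling the unscaled product afterwards, which is
# what the post_inverse_scale path does after torch._scaled_mm.
scaled_after = (a @ b) * (a_inverse_scale * b_inverse_scale)
torch.testing.assert_close(scaled_before, scaled_after)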

torchao/float8/float8_python_api.py

Lines changed: 0 additions & 75 deletions
This file was deleted.
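
For reference, a minimal usage sketch of the relocated helper, assuming a CUDA device whose torch._scaled_mm supports fp8 (e.g. H100); the shapes, the e4m3 dtype choice, and the amax-based scale computation are illustrative assumptions, not part of this commit:

import torch
from torchao.float8.float8_ops import addmm_float8_unwrapped

M, K, N = 16, 32, 16
a_hp = torch.randn(M, K, device="cuda")
b_hp = torch.randn(K, N, device="cuda")

# Tensorwise scales: map each tensor's amax to the fp8 max.
# Per the note in float8_ops.py, hp -> fp8 multiplies by the scale.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
a_scale = fp8_max / a_hp.abs().max()
b_scale = fp8_max / b_hp.abs().max()

a_data = (a_hp * a_scale).to(torch.float8_e4m3fn)
# torch._scaled_mm expects the second operand in column-major layout.
b_data = (b_hp * b_scale).to(torch.float8_e4m3fn).t().contiguous().t()

# Approximately a_hp @ b_hp, computed in fp8.
out = addmm_float8_unwrapped(
    a_data, a_scale, b_data, b_scale, output_dtype=torch.bfloat16
)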
