 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple

 import torch
 from torch.utils._pytree import tree_map

-from torchao.float8.float8_python_api import addmm_float8_unwrapped
 from torchao.float8.float8_tensor import Float8Tensor, choose_scaled_mm_config
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul

…

 FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


+# [Note] Usage of scales
+# The meaning of scale in this library can be found in the definition of Float8Tensor.
+# cuBLAS defines scale to always mean a multiplicative factor for the respective matrices:
+# for a, b going from fp8 -> fp32 we multiply by the inverse of the scale,
+# for the output going from fp32 -> fp8 we multiply by the scale.
+def addmm_float8_unwrapped(
+    a_data: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_data: torch.Tensor,
+    b_scale: torch.Tensor,
+    output_dtype: torch.dtype,
+    output_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    use_fast_accum: bool = False,
+) -> torch.Tensor:
+    """
+    This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
+    as inputs. This is used to standardize the logic between subclassed and non-subclassed
+    versions of the linear module.
+    """
+    a_inverse_scale = a_scale.reciprocal()
+    b_inverse_scale = b_scale.reciprocal()
+
+    post_inverse_scale = None
+    if (
+        a_scale.shape == (a_data.shape[0], 1)
+        and b_scale.shape == (1, b_data.shape[1])
+        and not use_fast_accum
+    ):
+        # The rowwise CUTLASS-based kernel is so slow without fast-accum that
+        # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
+        # manually afterwards (hoping Inductor will be able to fuse it).
+        post_inverse_scale = a_inverse_scale * b_inverse_scale
+        a_inverse_scale = a_inverse_scale.new_ones(())
+        b_inverse_scale = b_inverse_scale.new_ones(())
+
+    post_bias = None
+    if output_dtype == torch.float32:
+        # Bias is not supported by _scaled_mm when output is fp32
+        post_bias = bias
+        bias = None
+
+    output = torch._scaled_mm(
+        a_data,
+        b_data,
+        scale_a=a_inverse_scale,
+        scale_b=b_inverse_scale,
+        bias=bias,
+        scale_result=output_scale,
+        out_dtype=output_dtype,
+        use_fast_accum=use_fast_accum,
+    )
+
+    if post_inverse_scale is not None:
+        output *= post_inverse_scale
+    if post_bias is not None:
+        output += post_bias
+
+    return output
+
+
 def _assert_tensorwise_scale(aten_op, scale):
     assert (
         # TODO(future PR): figure out why tensorwise scaling can have
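
To make the scale convention in the "[Note] Usage of scales" comment concrete, here is a minimal sketch of how the newly inlined addmm_float8_unwrapped could be called with tensorwise scales. Everything in it is illustrative rather than part of the change: the tensor names (x_hp, w_hp, x_fp8, ...) are made up, it assumes a CUDA device with fp8 support, and it respects torch._scaled_mm's requirements that the second operand be column-major and that the matmul dimensions be multiples of 16.

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

# made-up high-precision inputs (dims kept as multiples of 16 for _scaled_mm)
x_hp = torch.randn(16, 32, device="cuda")
w_hp = torch.randn(32, 64, device="cuda")

# tensorwise scales map the original range onto the fp8 range: data_fp8 = x * scale
x_scale = E4M3_MAX / x_hp.abs().max()
w_scale = E4M3_MAX / w_hp.abs().max()
x_fp8 = (x_hp * x_scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
w_fp8 = (w_hp * w_scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

# going from fp8 back toward fp32 multiplies by the inverse of the scale, which is
# exactly what addmm_float8_unwrapped passes to torch._scaled_mm as scale_a/scale_b
out = addmm_float8_unwrapped(
    x_fp8,
    x_scale,
    w_fp8.t().contiguous().t(),  # force column-major layout for the second operand
    w_scale,
    output_dtype=torch.bfloat16,
)
# out is approximately (x_hp @ w_hp).to(torch.bfloat16)

The scales are the multiplicative factors that take the high-precision values into the fp8 range; the function hands their reciprocals to torch._scaled_mm so the accumulated product comes back out in the original range.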
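
The rowwise-fallback and fp32-bias branches can be exercised the same way. Continuing the sketch above (same made-up tensors and hardware assumptions; non-scalar scales additionally require a PyTorch build whose _scaled_mm accepts them), scales of shape (M, 1) and (1, N) together with use_fast_accum=False take the path where _scaled_mm runs with scalar ones and the combined inverse scale, plus any bias when the output is fp32, is applied after the matmul.

# rowwise scales: one scale per row of a, one per column of b
x_rowwise_scale = E4M3_MAX / x_hp.abs().amax(dim=1, keepdim=True)  # shape (16, 1)
w_rowwise_scale = E4M3_MAX / w_hp.abs().amax(dim=0, keepdim=True)  # shape (1, 64)
x_fp8_rw = (x_hp * x_rowwise_scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
w_fp8_rw = (w_hp * w_rowwise_scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

# with use_fast_accum=False these shapes hit the fallback: _scaled_mm runs with
# scalar ones as scales and the (16, 1) * (1, 64) inverse-scale product is
# multiplied into the output afterwards
out_rw = addmm_float8_unwrapped(
    x_fp8_rw,
    x_rowwise_scale,
    w_fp8_rw.t().contiguous().t(),
    w_rowwise_scale,
    output_dtype=torch.bfloat16,
    use_fast_accum=False,
)

# fp32 output: bias never reaches _scaled_mm; it is added after the matmul (post_bias path)
bias = torch.randn(64, device="cuda")
out_fp32 = addmm_float8_unwrapped(
    x_fp8_rw,
    x_rowwise_scale,
    w_fp8_rw.t().contiguous().t(),
    w_rowwise_scale,
    output_dtype=torch.float32,
    bias=bias,
    use_fast_accum=False,
)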