This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 53921b2: add an option to pad inner_dims
1 parent ef603c5 commit 53921b2

File tree: 6 files changed, +38 -12 lines


float8_experimental/config.py

Lines changed: 6 additions & 0 deletions
@@ -23,3 +23,9 @@
 # If True, use 'fnuz' float8 types for calculations.
 # Currently, ROCm only supports fnuz variants.
 use_fnuz_dtype = False
+
+# If True, then prior to performing the fp8 scaled matmul we will pad the
+# inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for
+# _scaled_mm, since it has the strong constraint that for M, N, K, the N and K
+# dimensions must be multiples of 16.
+# This can cause a memory spike, however, so we keep this off by default.
+pad_inner_dim = False
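
For context, the constraint this flag works around: per the comment above, _scaled_mm requires the N and K dimensions of the matmul to be multiples of 16, while M is not restricted the same way. A minimal illustration (scaled_mm_shape_ok is a hypothetical helper written just for this example, not part of the repo):

def scaled_mm_shape_ok(M: int, N: int, K: int) -> bool:
    # Per the comment above, _scaled_mm needs N and K to be multiples of 16.
    return N % 16 == 0 and K % 16 == 0

assert scaled_mm_shape_ok(M=17, N=512, K=1024)       # odd M is fine
assert not scaled_mm_shape_ok(M=64, N=512, K=1000)   # K=1000 would be padded to 1008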

float8_experimental/float8_dynamic_linear.py

Lines changed: 6 additions & 2 deletions
@@ -88,8 +88,12 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             "bias": False,
         }
         new_mod = cls(**super_kwargs)
-        new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+        new_mod.forward_config = ScaledMMConfig(
+            emulate, not bool(emulate), False, config.pad_inner_dim
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
                 WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
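
As a usage sketch (assuming the import paths match the file layout shown in this commit), the flag is read from the module-level config at conversion time, so it would typically be flipped before calling from_float:

import torch.nn as nn

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# Opt in before converting; from_float reads config.pad_inner_dim when it
# builds the forward/backward ScaledMMConfig.
config.pad_inner_dim = True

lin = nn.Linear(1000, 512, bias=False)  # inner dim K=1000 is not a multiple of 16
fp8_lin = Float8DynamicLinear.from_float(lin, emulate=False)

print(fp8_lin.forward_config)   # use_fast_accum=True, pad_inner_dim=True
print(fp8_lin.backward_config)  # use_fast_accum=False, pad_inner_dim=True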

float8_experimental/float8_linear.py

Lines changed: 6 additions & 2 deletions
@@ -347,6 +347,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.create_buffers()
         # Defines the behavior of the matmul in the forward and backward
         # Forward we use fast_accum, backwards we do not
-        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+        new_mod.forward_config = ScaledMMConfig(
+            emulate, True if not emulate else False, False, config.pad_inner_dim
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
         return new_mod
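
For reference, the four positional arguments passed here line up with the ScaledMMConfig fields defined in float8_tensor.py below: (emulate, use_fast_accum, fp8_output, pad_inner_dim). A small self-contained sketch of the forward/backward asymmetry, using a local stand-in for the namedtuple:

from collections import namedtuple

# Local stand-in mirroring the ScaledMMConfig definition in float8_tensor.py.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
    defaults=[False, False, False, False],
)

emulate, pad_inner_dim = False, True
forward_config = ScaledMMConfig(emulate, not emulate, False, pad_inner_dim)
backward_config = ScaledMMConfig(emulate, False, False, pad_inner_dim)

# Forward uses fast accumulation (when not emulating); backward never does.
assert forward_config.use_fast_accum and not backward_config.use_fast_accum
assert forward_config.pad_inner_dim and backward_config.pad_inner_dim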

float8_experimental/float8_ops.py

Lines changed: 6 additions & 1 deletion
@@ -13,7 +13,8 @@
     merge_mm_configs,
     ScaledMMConfig,
 )
-from float8_experimental.float8_utils import is_row_major
+from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul
+
 from torch.utils._pytree import tree_map
 
 aten = torch.ops.aten
@@ -121,6 +122,10 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data
 
+    if a._mm_config.pad_inner_dim:
+        a_data = pad_tensor_for_matmul(a_data, dims=1)
+        b_data = pad_tensor_for_matmul(b_data, dims=0)
+
     if not is_row_major(a_data.stride()):
         a_data = a_data.contiguous()
     if is_row_major(b_data.stride()):
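
Padding only the inner (K) dimension is numerically safe: the appended zeros contribute nothing to each dot product, so the padded matmul produces the same result as the unpadded one. A quick sanity check in plain torch, using torch.nn.functional.pad directly rather than the repo helper:

import torch
import torch.nn.functional as F

a = torch.randn(64, 1000)   # (M, K), K not a multiple of 16
b = torch.randn(1000, 128)  # (K, N)

pad_k = (16 - a.shape[1] % 16) % 16  # 8 extra zeros along K

a_padded = F.pad(a, (0, pad_k))        # pad last dim of a   -> (64, 1008)
b_padded = F.pad(b, (0, 0, 0, pad_k))  # pad first dim of b  -> (1008, 128)

torch.testing.assert_close(a @ b, a_padded @ b_padded)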

float8_experimental/float8_tensor.py

Lines changed: 4 additions & 2 deletions
@@ -23,10 +23,11 @@
 # emulate: whether to emulate the matmuls in fp32
 # use_fast_accum: whether to use the fast-accumulation option for scaled_mm
 # fp8_output: whether to output the result of the scaled_mm in fp8
+# pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16.
 ScaledMMConfig = namedtuple(
     "ScaledMMConfig",
-    ["emulate", "use_fast_accum", "fp8_output"],
-    defaults=[False, False, False],
+    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
+    defaults=[False, False, False, False],
 )
@@ -48,6 +49,7 @@ def merge_mm_configs(
         emulate=a_mm_config.emulate,
         use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum,
         fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output,
+        pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim,
     )
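
The merge keeps a flag only when both operands' configs agree, so padding is applied only if both tensors were created with pad_inner_dim=True. A small sketch of that AND semantics (assuming ScaledMMConfig is importable from float8_experimental.float8_tensor, where it is defined above):

from float8_experimental.float8_tensor import ScaledMMConfig

a_cfg = ScaledMMConfig(use_fast_accum=True, pad_inner_dim=True)
b_cfg = ScaledMMConfig(use_fast_accum=False, pad_inner_dim=True)

# Mirror the merge shown above: a flag survives only if both sides set it.
merged = ScaledMMConfig(
    emulate=a_cfg.emulate,
    use_fast_accum=a_cfg.use_fast_accum and b_cfg.use_fast_accum,
    fp8_output=a_cfg.fp8_output and b_cfg.fp8_output,
    pad_inner_dim=a_cfg.pad_inner_dim and b_cfg.pad_inner_dim,
)
assert merged.pad_inner_dim and not merged.use_fast_accum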

float8_experimental/float8_utils.py

Lines changed: 10 additions & 5 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Literal, Tuple
+from typing import Iterable, Literal, Tuple, Union
 
 import float8_experimental.config as config
@@ -197,7 +197,9 @@ def get_min_alignment(size: int, alignment_value: int):
     return (1 + (size // alignment_value)) * alignment_value
 
 
-def pad_tensor_for_matmul(tensor: torch.Tensor, both: bool = False) -> torch.Tensor:
+def pad_tensor_for_matmul(
+    tensor: torch.Tensor, dims: Union[int, Iterable[int]]
+) -> torch.Tensor:
     """
     Pads a 2D tensor with zeros to ensure that its dimensions are multiples of 16, which is required for H100s.
@@ -211,9 +213,12 @@ def pad_tensor_for_matmul(tensor: torch.Tensor, both: bool = False) -> torch.Tensor:
     assert tensor.dim() == 2
     dim1, dim2 = tensor.shape
 
-    # Calculate aligned dimensions
-    dim2_aligned = get_min_alignment(dim2, 16)
-    dim1_aligned = get_min_alignment(dim1, 16) if both else dim1
+    if isinstance(dims, int):
+        dims = (dims,)
+
+    # Calculate aligned dimensions based on the specified dims
+    dim1_aligned = get_min_alignment(dim1, 16) if 0 in dims else dim1
+    dim2_aligned = get_min_alignment(dim2, 16) if 1 in dims else dim2
 
     # Check if padding is needed for either dimension
     if dim1 == dim1_aligned and dim2 == dim2_aligned:
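
Putting the two utils together, here is a self-contained approximation of get_min_alignment and the new dims-aware pad_tensor_for_matmul. The final padding step is not shown in the hunk above, so the copy-into-a-zero-buffer approach below is an assumption of this sketch, and the _sketch names are local stand-ins rather than the repo's functions:

from typing import Iterable, Union

import torch

def get_min_alignment_sketch(size: int, alignment_value: int) -> int:
    # Smallest multiple of alignment_value that is >= size.
    if size % alignment_value == 0:
        return size
    return (1 + (size // alignment_value)) * alignment_value

def pad_tensor_for_matmul_sketch(
    tensor: torch.Tensor, dims: Union[int, Iterable[int]]
) -> torch.Tensor:
    # Zero-pads the requested dims of a 2D tensor up to multiples of 16.
    assert tensor.dim() == 2
    dim1, dim2 = tensor.shape
    if isinstance(dims, int):
        dims = (dims,)
    dim1_aligned = get_min_alignment_sketch(dim1, 16) if 0 in dims else dim1
    dim2_aligned = get_min_alignment_sketch(dim2, 16) if 1 in dims else dim2
    if dim1 == dim1_aligned and dim2 == dim2_aligned:
        return tensor
    # Assumed padding step: copy the original values into a zero buffer.
    padded = tensor.new_zeros(dim1_aligned, dim2_aligned)
    padded[:dim1, :dim2] = tensor
    return padded

a = torch.randn(64, 1000)
assert pad_tensor_for_matmul_sketch(a, dims=1).shape == (64, 1008)       # pad only K
assert pad_tensor_for_matmul_sketch(a, dims=(0, 1)).shape == (64, 1008)  # 64 is already aligned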
