This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 6692e05: add an option to pad inner_dims

1 parent: 29e48ac

File tree: 6 files changed (+39, -12 lines)

float8_experimental/config.py

Lines changed: 7 additions & 0 deletions
@@ -19,3 +19,10 @@
 # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
 # Only dynamic scaling is supported for now.
 enable_fsdp_fp8_all_gather = False
+
+
+# If True, then prior to performing the fp8 scaled matmul we will pad the
+# inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for
+# _scaled_mm, which has the strong constraint that for M, N, K the dims N and K must be multiples of 16.
+# This can cause a memory spike, however, so we keep it off by default.
+pad_inner_dim = False
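
A minimal usage sketch (hedged: it assumes the config module is importable as float8_experimental.config and that Float8DynamicLinear.from_float accepts a plain nn.Linear, as the diff to float8_dynamic_linear.py below suggests). Enabling the flag before conversion is enough for it to be captured into both matmul configs:

    import torch.nn as nn

    import float8_experimental.config as config
    from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

    # Opt in before converting modules; K=100 below is not a multiple of 16,
    # which is exactly the case _scaled_mm would otherwise reject.
    config.pad_inner_dim = True

    lin = nn.Linear(100, 256, bias=False)
    fp8_lin = Float8DynamicLinear.from_float(lin, emulate=True)

    # The flag is read at conversion time and stored in both matmul configs.
    assert fp8_lin.forward_config.pad_inner_dim
    assert fp8_lin.backward_config.pad_inner_dim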

float8_experimental/float8_dynamic_linear.py

Lines changed: 6 additions & 2 deletions
@@ -88,8 +88,12 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             "bias": False,
         }
         new_mod = cls(**super_kwargs)
-        new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+        new_mod.forward_config = ScaledMMConfig(
+            emulate, not bool(emulate), False, config.pad_inner_dim
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
                 WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)

float8_experimental/float8_linear.py

Lines changed: 6 additions & 2 deletions
@@ -342,6 +342,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.create_buffers()
         # Defines the behavior of the matmul in the forward and backward
         # Forward we use fast_accum, backwards we do not
-        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+        new_mod.forward_config = ScaledMMConfig(
+            emulate, True if not emulate else False, False, config.pad_inner_dim
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
         return new_mod
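
For reference, a small sketch of the two configs this builds (a stand-in for the from_float logic above, assuming emulate=False and the padding flag enabled; field order comes from the ScaledMMConfig definition in float8_tensor.py below):

    from float8_experimental.float8_tensor import ScaledMMConfig

    emulate = False
    pad_inner_dim = True  # stand-in for config.pad_inner_dim

    # Forward uses fast accumulation when not emulating; backward never does.
    # Both inherit the padding flag, so the gradient matmuls are padded too.
    forward_config = ScaledMMConfig(emulate, not emulate, False, pad_inner_dim)
    backward_config = ScaledMMConfig(emulate, False, False, pad_inner_dim)

    assert forward_config.use_fast_accum and not backward_config.use_fast_accum
    assert forward_config.pad_inner_dim and backward_config.pad_inner_dim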

float8_experimental/float8_ops.py

Lines changed: 6 additions & 1 deletion
@@ -13,7 +13,8 @@
     merge_mm_configs,
     ScaledMMConfig,
 )
-from float8_experimental.float8_utils import is_row_major
+from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul
+
 from torch.utils._pytree import tree_map

 aten = torch.ops.aten

@@ -121,6 +122,10 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data

+    if a._mm_config.pad_inner_dim:
+        a_data = pad_tensor_for_matmul(a_data, dims=1)
+        b_data = pad_tensor_for_matmul(b_data, dims=0)
+
     if not is_row_major(a_data.stride()):
         a_data = a_data.contiguous()
     if is_row_major(b_data.stride()):
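
To see why a is padded along dims=1 and b along dims=0: the two operands share the inner K dimension, and zero-padding along that contraction dimension leaves the product unchanged. A quick sketch on plain float tensors, using the reworked pad_tensor_for_matmul from float8_utils.py further down (the shapes 64x100 and 100x32 are arbitrary examples):

    import torch

    from float8_experimental.float8_utils import pad_tensor_for_matmul

    a = torch.randn(64, 100)  # (M, K) with K=100 not 16-aligned
    b = torch.randn(100, 32)  # (K, N)

    a_pad = pad_tensor_for_matmul(a, dims=1)  # pad K: (64, 112)
    b_pad = pad_tensor_for_matmul(b, dims=0)  # pad K: (112, 32)

    # The extra rows/columns are zeros, so each dot product only gains
    # zero-valued terms and the result matches the unpadded matmul.
    torch.testing.assert_close(a_pad @ b_pad, a @ b)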

float8_experimental/float8_tensor.py

Lines changed: 4 additions & 2 deletions
@@ -19,10 +19,11 @@
 # emulate: whether to emulate the matmuls in fp32
 # use_fast_accum: whether to use the fast-accumulation option for scaled_mm
 # fp8_output: whether to output the result of the scaled_mm in fp8
+# pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16.
 ScaledMMConfig = namedtuple(
     "ScaledMMConfig",
-    ["emulate", "use_fast_accum", "fp8_output"],
-    defaults=[False, False, False],
+    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
+    defaults=[False, False, False, False],
 )

@@ -44,6 +45,7 @@ def merge_mm_configs(
         emulate=a_mm_config.emulate,
         use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum,
         fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output,
+        pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim,
     )
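
A short sketch of the merge semantics (assuming the two parameters of merge_mm_configs are named a_mm_config and b_mm_config, as the body above references): like the other behavior flags, pad_inner_dim survives the merge only when both operands opted in.

    from float8_experimental.float8_tensor import ScaledMMConfig, merge_mm_configs

    a_cfg = ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=True)
    b_cfg = ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=True)

    merged = merge_mm_configs(a_mm_config=a_cfg, b_mm_config=b_cfg)
    assert merged.pad_inner_dim       # both operands opted in to padding
    assert not merged.use_fast_accum  # the conservative AND drops fast accum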

float8_experimental/float8_utils.py

Lines changed: 10 additions & 5 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Literal, Tuple
+from typing import Iterable, Literal, Tuple, Union

 import torch
 import torch.distributed as dist

@@ -190,7 +190,9 @@ def get_min_alignment(size: int, alignment_value: int):
     return (1 + (size // alignment_value)) * alignment_value


-def pad_tensor_for_matmul(tensor: torch.Tensor, both: bool = False) -> torch.Tensor:
+def pad_tensor_for_matmul(
+    tensor: torch.Tensor, dims: Union[int, Iterable[int]]
+) -> torch.Tensor:
     """
     Pads a 2D tensor with zeros to ensure that its dimensions are multiples of 16, which is required for H100s.

@@ -204,9 +206,12 @@ def pad_tensor_for_matmul(tensor: torch.Tensor, both: bool = False) -> torch.Tensor:
     assert tensor.dim() == 2
     dim1, dim2 = tensor.shape

-    # Calculate aligned dimensions
-    dim2_aligned = get_min_alignment(dim2, 16)
-    dim1_aligned = get_min_alignment(dim1, 16) if both else dim1
+    if isinstance(dims, int):
+        dims = (dims,)
+
+    # Calculate aligned dimensions based on the specified dims
+    dim1_aligned = get_min_alignment(dim1, 16) if 0 in dims else dim1
+    dim2_aligned = get_min_alignment(dim2, 16) if 1 in dims else dim2

     # Check if padding is needed for either dimension
     if dim1 == dim1_aligned and dim2 == dim2_aligned:
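
A shape-level sketch of the new dims argument (the sizes 100 and 23 are arbitrary unaligned values): an int selects one dimension, an iterable may select both, and get_min_alignment rounds a size that is not already a multiple of 16 up to the next multiple.

    import torch

    from float8_experimental.float8_utils import get_min_alignment, pad_tensor_for_matmul

    t = torch.randn(100, 23)

    assert get_min_alignment(100, 16) == 112  # next multiple of 16 above 100
    assert pad_tensor_for_matmul(t, dims=0).shape == (112, 23)       # pad rows only
    assert pad_tensor_for_matmul(t, dims=1).shape == (100, 32)       # pad cols only
    assert pad_tensor_for_matmul(t, dims=(0, 1)).shape == (112, 32)  # pad both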

0 commit comments