This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 57136bd

drisspg authored and facebook-github-bot committed
Add config to enable padding on inner dims for scaled_mm inputs (#145)
Summary: This adds simple utilities to enable scaled_mm to work with matrices whose dimensions are not multiples of 16, by padding the inputs to the function. It also adds a script that can be used to explore the performance of different shapes. Running `python benchmarks/bench_padding.py` produces output like this:

```Shell
**************************************TOPs**************************************
Shape               Ref Dtype         Ref Tops    FP8 Tops  Ref % Peak  FP8 % Peak
------------------  --------------  ----------  ----------  ----------  ----------
(8193x2501x5008)    torch.bfloat16     5.1e+14    8.17e+14       0.258       0.206
(65x253x4096)       torch.bfloat16    1.07e+13    8.21e+12       0.005       0.002
(1023x1029x2512)    torch.bfloat16    7.08e+13    1.98e+14       0.036       0.05
(4095x511x10000)    torch.bfloat16     9.4e+13    5.52e+14       0.047       0.139
(2047x3073x8192)    torch.bfloat16    1.14e+14    6.16e+14       0.058       0.156
(511x769x7504)      torch.bfloat16    8.37e+13    1.68e+14       0.042       0.043
(127x4097x12288)    torch.bfloat16    8.61e+13    8.55e+13       0.043       0.022
(32769x15x15024)    torch.bfloat16    1.48e+13    3.27e+13       0.007       0.008
(9217x8191x20480)   torch.bfloat16     1.2e+14    1.07e+15       0.061       0.271
(16385x1025x25008)  torch.bfloat16    1.05e+14    8.11e+14       0.053       0.205

*********************************Speed Results**********************************
+----------------------+----------------+------------+------------+-----------+
| Shape                | Ref Dtype      |   Ref Time |   FP8 Time |   Speedup |
+======================+================+============+============+===========+
| (8193, 2501, 5008)   | torch.bfloat16 |    402.215 |    251.246 |  1.60088  |
+----------------------+----------------+------------+------------+-----------+
| (65, 253, 4096)      | torch.bfloat16 |    12.5471 |    16.4149 |  0.764373 |
+----------------------+----------------+------------+------------+-----------+
| (1023, 1029, 2512)   | torch.bfloat16 |    74.7011 |    26.6719 |  2.80074  |
+----------------------+----------------+------------+------------+-----------+
| (4095, 511, 10000)   | torch.bfloat16 |    445.42  |    75.8169 |  5.87494  |
+----------------------+----------------+------------+------------+-----------+
| (2047, 3073, 8192)   | torch.bfloat16 |    901.602 |   167.263  |  5.39033  |
+----------------------+----------------+------------+------------+-----------+
| (511, 769, 7504)     | torch.bfloat16 |    70.5006 |    35.0095 |  2.01376  |
+----------------------+----------------+------------+------------+-----------+
| (127, 4097, 12288)   | torch.bfloat16 |   148.589  |   149.542  |  0.993628 |
+----------------------+----------------+------------+------------+-----------+
| (32769, 15, 15024)   | torch.bfloat16 |   996.979  |   451.53   |  2.208    |
+----------------------+----------------+------------+------------+-----------+
| (9217, 8191, 20480)  | torch.bfloat16 |  25781.6   |  2886.31   |  8.93238  |
+----------------------+----------------+------------+------------+-----------+
| (16385, 1025, 25008) | torch.bfloat16 |   8037.08  |  1036.24   |  7.75598  |
+----------------------+----------------+------------+------------+-----------+
```

## Example workflows that this really helps

```Python
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda")

# Convert all torch.nn.Linear modules to Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
import float8_experimental

float8_experimental.config.pad_inner_dim = True
swap_linear_with_float8_linear(model, Float8DynamicLinear)

# Wrap model with Fully Sharded Data Parallel (FSDP)
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import os

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
dist.init_process_group(backend='nccl', init_method='env://')

# model = FSDP(model, use_orig_params=True)

# optionally compile the model
# model = torch.compile(model)

# Prepare your dataset and dataloader (customize this part as needed)
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer):
        self.encodings = tokenizer(
            texts, return_tensors='pt', padding=True, truncation=True, max_length=512
        )

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

# Example text data
texts = ["Example text input 1.", "Example text input 2.", "Example text input 3."]
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset, batch_size=2)

# Set up the optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4)

# Training loop
model.train()
for epoch in range(3):  # Loop over the dataset multiple times
    for i, batch in enumerate(dataloader):
        inputs = {k: v.to(model.device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**inputs, labels=inputs['input_ids'])
        loss = outputs.loss

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch {epoch + 1}, Step {i + 1}, Loss: {loss.item()}')

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_model")
print("Training complete!")
```

Pull Request resolved: #145

Reviewed By: vkuzo

Differential Revision: D58958442

Pulled By: drisspg

fbshipit-source-id: 5a4c8661e974699ce3f83748fca1ce1f0ad65d3b
1 parent d4ade87 commit 57136bd

File tree

9 files changed: +366, -9 lines changed


benchmarks/bench_padding.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
from dataclasses import dataclass
from typing import Optional

import fire

import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate
from torch._inductor.utils import do_bench_using_profiling
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}


def benchmark_fn_in_usec(f, *args, **kwargs):
    no_args = lambda: f(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def get_tops_info(tops, time, peak_tops):
    time_sec = time / 1e6
    tops_sec = float(tops) / time_sec
    pct_top_peak = tops_sec / peak_tops
    return tops_sec, pct_top_peak


def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

    a_config = ScaledMMConfig(
        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
    )
    b_config = ScaledMMConfig(
        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
    )

    a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
    b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)

    return a_fp8 @ b_fp8


def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
    # Breaks with compile due to trying to pad on fp8 dtype
    # return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
    A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy

    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

    A_pad = A_pad.to(fp8_dtype)  # mem copy
    B_pad = B_pad.to(fp8_dtype)  # mem copy

    B_pad = B_pad.t().contiguous().t()  # mem copy

    return torch._scaled_mm(
        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
    )


def do_hp_matmul(A, B):
    return torch.matmul(A, B)


def do_aligned_bf16_matmul(A, B):
    A_pad = pad_tensor_for_matmul(A, dims=1)
    B_pad = pad_tensor_for_matmul(B, dims=0)
    return torch.matmul(A_pad, B_pad)


@dataclass
class Experiment_config:
    M: int
    K: int
    N: int
    output_dtype: torch.dtype
    fp8_dtype: torch.dtype

    def __iter__(self):
        return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype))


def gen_configs():
    shapes = [
        (8193, 2501, 5008),
        (65, 253, 4096),
        (1023, 1029, 2512),
        (4095, 511, 10000),
        (2047, 3073, 8192),
        (511, 769, 7504),
        (127, 4097, 12288),
        (32769, 15, 15024),
        (9217, 8191, 20480),
        (16385, 1025, 25008),
    ]
    output_dtype = torch.bfloat16
    fp8_dtype = torch.float8_e4m3fn
    return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]


@torch.no_grad()
def run(compile: bool = False, n_limit: Optional[int] = None):
    device = "cuda"
    experiments = gen_configs()
    results = []
    tops_table = []
    tops_headers = [
        "Shape",
        "Ref Dtype",
        "Ref Tops",
        "Aligned BF16 Tops",
        "FP8 Tops",
        "Ref % Peak",
        "Aligned BF16 % Peak",
        "FP8 % Peak",
    ]

    for experiment in tqdm(experiments):
        M, K, N, output_dtype, fp8_dtype = experiment
        tops = 2 * M * N * K

        A_base = torch.rand(M, K, device=device, dtype=output_dtype)
        B_base = torch.rand(K, N, device=device, dtype=output_dtype)

        hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul
        aligned_bf16_func = (
            torch.compile(do_aligned_bf16_matmul) if compile else do_aligned_bf16_matmul
        )
        fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul

        ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base)
        aligned_bf16_time = benchmark_fn_in_usec(aligned_bf16_func, A_base, B_base)
        fp8_time = benchmark_fn_in_usec(
            fp8_func, A_base, B_base, fp8_dtype, output_dtype
        )

        ref_tops_sec, ref_pct_top_peak = get_tops_info(
            tops, ref_time, dtype_to_peak_tops[output_dtype]
        )
        aligned_bf16_tops_sec, aligned_bf16_pct_top_peak = get_tops_info(
            tops, aligned_bf16_time, dtype_to_peak_tops[torch.bfloat16]
        )
        fp8_tops_sec, fp8_pct_top_peak = get_tops_info(
            tops, fp8_time, dtype_to_peak_tops[fp8_dtype]
        )
        tops_table.append(
            [
                f"({M}x{K}x{N})",
                f"{output_dtype}",
                f"{ref_tops_sec:.2E}",
                f"{aligned_bf16_tops_sec:.2E}",
                f"{fp8_tops_sec:.2E}",
                f"{ref_pct_top_peak:.3f}",
                f"{aligned_bf16_pct_top_peak:.3f}",
                f"{fp8_pct_top_peak:.3f}",
            ]
        )
        results.append(
            [
                (M, K, N),
                output_dtype,
                ref_time,
                aligned_bf16_time,
                fp8_time,
                ref_time / aligned_bf16_time,
                ref_time / fp8_time,
            ]
        )

    print("TOPs".center(80, "*"))
    print(tabulate(tops_table, headers=tops_headers))
    print("Speed Results".center(80, "*"))
    headers = [
        "Shape",
        "Ref Dtype",
        "Ref Time",
        "Aligned BF16 Time",
        "FP8 Time",
        "Aligned BF16 Speedup",
        "FP8 Speedup",
    ]
    print(tabulate(results, headers=headers, tablefmt="grid"))


if __name__ == "__main__":
    fire.Fire(run)
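
The entry point is wired through python-fire (`fire.Fire(run)`), so the `run()` arguments map to command-line flags in the usual fire way; the exact invocation below is an assumption based on that convention, not something stated in the commit:

```Shell
# Defaults: eager mode, all shapes
python benchmarks/bench_padding.py

# Compile the matmul wrappers first
python benchmarks/bench_padding.py --compile True
```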

float8_experimental/config.py

Lines changed: 6 additions & 0 deletions
@@ -23,3 +23,9 @@
 # If True, use 'fnuz' float8 types for calculations.
 # Currently, ROCm only supports fnuz variants.
 use_fnuz_dtype = False
+
+# If True, then prior to performing the fp8 scaled matmul we will pad the
+# inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed because
+# _scaled_mm has the strong constraint that for an (M, K) x (K, N) matmul, N and K must be multiples of 16.
+# This can cause a memory spike, however, so we keep this off by default.
+pad_inner_dim = False
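
For reference, this is a module-level flag that is read when linear modules are converted (see the `from_float` changes below). A minimal sketch of flipping it, condensed from the workflow in the summary above:

```Python
import float8_experimental
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# Enable zero-padding of the inner dim before any modules are swapped, so the
# ScaledMMConfigs built in from_float() pick up pad_inner_dim=True.
float8_experimental.config.pad_inner_dim = True
swap_linear_with_float8_linear(model, Float8DynamicLinear)  # `model` is any nn.Module with nn.Linear children
```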

float8_experimental/float8_dynamic_linear.py

Lines changed: 13 additions & 2 deletions
@@ -88,8 +88,19 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             "bias": False,
         }
         new_mod = cls(**super_kwargs)
-        new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+
+        new_mod.forward_config = ScaledMMConfig(
+            emulate=emulate,
+            use_fast_accum=not bool(emulate),
+            fp8_output=False,
+            pad_inner_dim=config.pad_inner_dim,
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate=emulate,
+            use_fast_accum=False,
+            fp8_output=False,
+            pad_inner_dim=config.pad_inner_dim,
+        )
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
                 WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)

float8_experimental/float8_linear.py

Lines changed: 6 additions & 2 deletions
@@ -347,6 +347,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.create_buffers()
         # Defines the behavior of the matmul in the forward and backward
         # Forward we use fast_accum, backwards we do not
-        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
-        new_mod.backward_config = ScaledMMConfig(emulate, False)
+        new_mod.forward_config = ScaledMMConfig(
+            emulate, True if not emulate else False, False, config.pad_inner_dim
+        )
+        new_mod.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
         return new_mod

float8_experimental/float8_ops.py

Lines changed: 12 additions & 1 deletion
@@ -13,7 +13,8 @@
     merge_mm_configs,
     ScaledMMConfig,
 )
-from float8_experimental.float8_utils import is_row_major
+from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul
+
 from torch.utils._pytree import tree_map

 aten = torch.ops.aten
@@ -121,6 +122,16 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data

+    if a._mm_config.pad_inner_dim:
+        assert (
+            b._mm_config.pad_inner_dim
+        ), "Both mm configs must have pad_inner_dim set to True"
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
+        a_data = pad_tensor_for_matmul(a_data, dims=1)
+        b_data = pad_tensor_for_matmul(b_data, dims=0)
+
     if not is_row_major(a_data.stride()):
         a_data = a_data.contiguous()
     if is_row_major(b_data.stride()):
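
The reason this preprocessing is safe: zero-padding the shared inner dimension of both operands only adds zero terms to every dot product, so the matmul result is unchanged up to floating-point reassociation. A small standalone check of that idea in plain PyTorch; the helper name here is hypothetical and exists only for illustration (the commit's `pad_tensor_for_matmul` plays this role inside `preprocess_addmm`):

```Python
import torch

def pad_inner_to_multiple_of_16(a: torch.Tensor, b: torch.Tensor):
    # Hypothetical illustration-only helper: round K up to a multiple of 16
    # and zero-pad dim 1 of `a` and dim 0 of `b` accordingly.
    k = a.size(1)
    k_aligned = (k + 15) // 16 * 16
    a_pad = torch.nn.functional.pad(a, (0, k_aligned - k))        # pad columns of a
    b_pad = torch.nn.functional.pad(b, (0, 0, 0, k_aligned - k))  # pad rows of b
    return a_pad, b_pad

A = torch.randn(65, 253)
B = torch.randn(253, 4096)
A_pad, B_pad = pad_inner_to_multiple_of_16(A, B)
assert A_pad.shape == (65, 256) and B_pad.shape == (256, 4096)
# Same product as the unpadded matmul (up to tiny rounding differences).
torch.testing.assert_close(A @ B, A_pad @ B_pad)
```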

float8_experimental/float8_python_api.py

Lines changed: 0 additions & 1 deletion
@@ -9,7 +9,6 @@
 to simplify the product code.
 """

-
 from typing import Optional

 import float8_experimental.float8_aten_api  # noqa

float8_experimental/float8_tensor.py

Lines changed: 4 additions & 2 deletions
@@ -23,10 +23,11 @@
 # emulate: whether to emulate the matmuls in fp32
 # use_fast_accum: whether to use the fast-accumulation option for scaled_mm
 # fp8_output: whether to output the result of the scaled_mm in fp8
+# pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16.
 ScaledMMConfig = namedtuple(
     "ScaledMMConfig",
-    ["emulate", "use_fast_accum", "fp8_output"],
-    defaults=[False, False, False],
+    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
+    defaults=[False, False, False, False],
 )

@@ -48,6 +49,7 @@ def merge_mm_configs(
         emulate=a_mm_config.emulate,
         use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum,
         fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output,
+        pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim,
     )
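
Because the new field is appended with a default of `False`, existing call sites that build a `ScaledMMConfig` without it keep their old meaning; a quick sketch of that invariant (assuming the namedtuple as defined in this diff):

```Python
from float8_experimental.float8_tensor import ScaledMMConfig

# Old-style construction still works; the trailing fields fall back to their defaults.
cfg = ScaledMMConfig(emulate=False, use_fast_accum=True)
assert cfg.fp8_output is False and cfg.pad_inner_dim is False
```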

float8_experimental/float8_utils.py

Lines changed: 67 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Literal, Tuple
+from typing import Iterable, Literal, Tuple, Union

 import float8_experimental.config as config

@@ -179,3 +179,69 @@ def fp8_tensor_statistics(
 def is_row_major(stride):
     assert len(stride) == 2, "is_row_major only supports 2D tensors"
     return stride[0] > stride[1] and stride[1] == 1
+
+
+def _get_min_alignment(size: int, alignment_value: int) -> int:
+    """
+    Returns the smallest multiple of `alignment_value` that is greater than or equal to `size`.
+
+    Args:
+        size: The size of the data to be aligned.
+        alignment_value: The alignment value to be used.
+
+    Returns:
+        int: The aligned size.
+
+    Usage:
+    ```
+    >>> _get_min_alignment(10, 8)
+    16
+    ```
+    """
+    if size % alignment_value == 0:
+        return size
+    return (1 + (size // alignment_value)) * alignment_value
+
+
+def pad_tensor_for_matmul(
+    tensor: torch.Tensor, dims: Union[int, Iterable[int]]
+) -> torch.Tensor:
+    """
+    Pads a 2D tensor with zeros so that the selected dimensions are multiples of 16,
+    which is required by `torch._scaled_mm`.
+
+    Args:
+        tensor: The tensor to pad.
+        dims: The dimension(s) to pad: 0, 1, or both.
+
+    Returns:
+        torch.Tensor: The padded tensor.
+
+    Usage:
+    ```
+    >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape
+    torch.Size([16, 10])
+    >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape
+    torch.Size([10, 16])
+    >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape
+    torch.Size([16, 16])
+    ```
+    """
+    assert tensor.dim() == 2
+    dim1, dim2 = tensor.shape
+
+    if isinstance(dims, int):
+        dims = (dims,)
+
+    # Calculate aligned dimensions based on the specified dims
+    dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1
+    dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2
+
+    # Check if padding is needed for either dimension
+    if dim1 == dim1_aligned and dim2 == dim2_aligned:
+        return tensor
+
+    # Calculate padding values for both dimensions
+    pad_dim1 = dim1_aligned - dim1
+    pad_dim2 = dim2_aligned - dim2
+
+    return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))
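
A few doctest-style checks of the new helpers, with shapes taken from the docstrings above (this assumes the functions import as shown in the diff):

```Python
import torch
from float8_experimental.float8_utils import _get_min_alignment, pad_tensor_for_matmul

assert _get_min_alignment(10, 8) == 16    # from the docstring example
assert _get_min_alignment(16, 16) == 16   # exact multiples are returned unchanged
assert _get_min_alignment(17, 16) == 32

t = torch.randn(10, 10)
assert pad_tensor_for_matmul(t, dims=1).shape == (10, 16)
assert pad_tensor_for_matmul(t, dims=(0, 1)).shape == (16, 16)
# The padded region is all zeros, so downstream matmuls are unaffected.
assert pad_tensor_for_matmul(t, dims=(0, 1))[10:, :].abs().sum() == 0
```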

0 commit comments
