from dataclasses import dataclass
from typing import Optional

import fire

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}
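
# Quick roofline sanity check (an estimate derived from the peaks above, not a
# measured number): for the (8192, 2500, 5000) shape used below,
# 2 * M * K * N ~= 2.05e11 FLOP, so an ideal kernel would need ~52 us at the
# fp8 tensor-core peak (3958e12 FLOP/s) and ~103 us at the fp16/bf16 peak;
# real kernels land at some fraction of this.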


def benchmark_fn_in_usec(f, *args, **kwargs):
    # Manual warmup
    for _ in range(4):
        f(*args, **kwargs)
    # torch.utils.benchmark.Timer handles CUDA synchronization; blocked_autorange()
    # picks the number of runs adaptively and reports the mean per-run time in
    # seconds, which we convert to microseconds.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    measurement = t0.blocked_autorange()
    return measurement.mean * 1e6


def get_tops_info(tops, time, peak_tops):
    # `time` is in microseconds; convert to seconds before computing the
    # achieved FLOP/s and the fraction of peak.
    time_sec = time / 1e6
    tops_sec = float(tops) / time_sec
    pct_top_peak = tops_sec / peak_tops
    return tops_sec, pct_top_peak
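

# Padding notes (my reading of the helpers below, not verified against the
# float8_experimental internals): torch._scaled_mm requires the matmul
# dimensions to be multiples of 16, so pad_tensor_for_matmul is assumed to
# right-pad each relevant dim up to the next multiple of 16 (e.g. an inner dim
# of 2500 becomes 2512), and the result is sliced back down to the original
# M x N. The [0] index is there because, on the PyTorch builds this script
# targets, torch._scaled_mm returns an (output, amax) tuple.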
def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
    A_fp8 = A.to(fp8_dtype)
    B_fp8 = B.to(fp8_dtype).t()  # view

    A_pad = pad_tensor_for_matmul(A_fp8)  # mem copy
    B_pad = pad_tensor_for_matmul(B_fp8, both=True).contiguous().t()  # mem copy

    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
        : A.shape[0], : B.shape[1]
    ]


def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
    A_pad = pad_tensor_for_matmul(A)  # mem copy
    B_pad = pad_tensor_for_matmul(B, both=True)  # mem copy

    A_pad = A_pad.to(fp8_dtype)  # mem copy
    B_pad = B_pad.to(fp8_dtype)  # mem copy

    B_pad = B_pad.t().contiguous().t()  # mem copy

    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
        : A.shape[0], : B.shape[1]
    ]


def do_hp_matmul(A, B):
    return torch.matmul(A, B)


@dataclass
class Experiment_config:
    M: int
    K: int
    N: int
    output_dtype: torch.dtype
    fp8_dtype: torch.dtype

    def __iter__(self):
        return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype))
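

# The shapes below appear to be chosen so that K and/or N are not multiples of
# 16 (2500, 5000, and 10 are all unaligned), which forces the padding path in
# the fp8 matmuls above to actually do work.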
def gen_configs():
    shapes = [(8192, 2500, 5000), (4096, 10, 4096)]
    output_dtype = torch.float32
    fp8_dtype = torch.float8_e4m3fn
    return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]


@torch.no_grad()
def run(compile: bool = False, n_limit: Optional[int] = None):
    device = "cuda"
    experiments = gen_configs()
    if n_limit is not None:
        # Optionally limit how many experiments run.
        experiments = experiments[:n_limit]
    results = []
    tops_table = []
    tops_headers = [
        "Shape",
        "Ref Dtype",
        "Ref Tops",
        "FP8 Tops",
        "Ref % Peak",
        "FP8 % Peak",
    ]
    for experiment in experiments:
        M, K, N, output_dtype, fp8_dtype = experiment
        tops = 2 * M * N * K

        A_base = torch.rand(M, K, device=device, dtype=output_dtype)
        B_base = torch.rand(K, N, device=device, dtype=output_dtype)

        hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul
        # Note the asymmetry: the compiled path benchmarks the pad-first
        # variant (pad in high precision, then cast), while the eager path
        # casts to fp8 before padding.
        fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul

        ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base)
        fp8_time = benchmark_fn_in_usec(
            fp8_func, A_base, B_base, fp8_dtype, output_dtype
        )

        ref_tops_sec, ref_pct_top_peak = get_tops_info(
            tops, ref_time, dtype_to_peak_tops[output_dtype]
        )
        fp8_tops_sec, fp8_pct_top_peak = get_tops_info(
            tops, fp8_time, dtype_to_peak_tops[fp8_dtype]
        )
        tops_table.append(
            [
                f"({M}x{K}x{N})",
                f"{output_dtype}",
                f"{ref_tops_sec:.2E}",
                f"{fp8_tops_sec:.2E}",
                f"{ref_pct_top_peak:.3f}",
                f"{fp8_pct_top_peak:.3f}",
            ]
        )
        results.append(
            [(M, K, N), output_dtype, ref_time, fp8_time, ref_time / fp8_time]
        )

    print("TOPs".center(80, "*"))
    print(tabulate(tops_table, headers=tops_headers))
    print("Speed Results".center(80, "*"))
    headers = ["Shape", "Ref Dtype", "Ref Time (us)", "FP8 Time (us)", "Speedup"]
    print(tabulate(results, headers=headers, tablefmt="grid"))


if __name__ == "__main__":
    fire.Fire(run)
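
# Example invocation (assuming this file is saved as benchmarks/bench_padding.py;
# Fire maps run()'s keyword arguments onto CLI flags):
#
#   python benchmarks/bench_padding.py --compile=True --n_limit=1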