@@ -4,9 +4,10 @@
 import fire
 
 import torch
-import torch.utils.benchmark as benchmark
 from float8_experimental.float8_utils import pad_tensor_for_matmul
 from tabulate import tabulate
+from torch._inductor.utils import do_bench_using_profiling
+from tqdm import tqdm
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
@@ -26,14 +27,9 @@
 
 
 def benchmark_fn_in_usec(f, *args, **kwargs):
-    # Manual warmup
-    for _ in range(4):
-        f(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-    )
-    measurement = t0.blocked_autorange()
-    return measurement.mean * 1e6
+    no_args = lambda: f(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3
 
 
 def get_tops_info(tops, time, peak_tops):
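
The timing helper above now goes through PyTorch's profiler-based benchmark utility instead of torch.utils.benchmark.Timer. A minimal usage sketch, assuming torch._inductor.utils.do_bench_using_profiling takes a zero-argument callable and returns a mean runtime in milliseconds (which is what the * 1e3 conversion to microseconds implies); it also needs a CUDA device:

# Usage sketch for the profiler-based timer (assumed behavior: accepts a
# 0-arg callable, returns milliseconds).
import torch
from torch._inductor.utils import do_bench_using_profiling

def bench_usec(f, *args, **kwargs):
    no_args = lambda: f(*args, **kwargs)            # bind args so the profiler times a 0-arg call
    return do_bench_using_profiling(no_args) * 1e3  # ms -> usec

a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
print(f"bf16 4096^3 matmul: {bench_usec(torch.mm, a, b):.1f} usec")
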
@@ -51,16 +47,17 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
 
     A_pad = pad_tensor_for_matmul(A_fp8, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t()  # mem copy
+    B_pad = pad_tensor_for_matmul(B_fp8, dims=0).contiguous().t()  # mem copy
 
-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    return torch._scaled_mm(
+        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
+    )
 
 
 def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
+    # We are only going to test the shape preserving
     A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B, dims=[0, 1])  # mem copy
+    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy
 
     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
@@ -70,9 +67,9 @@ def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
 
     B_pad = B_pad.t().contiguous().t()  # mem copy
 
-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    return torch._scaled_mm(
+        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
+    )
 
 
 def do_hp_matmul(A, B):
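
Both fp8 paths now pad only the contraction dimension K: A along dim 1 and B along dim 0. Appending zero columns to A and zero rows to B only adds zero terms to each dot product, so the result is unchanged and the output already has shape (M, N); that is why the old [: A.shape[0], : B.shape[1]] slice, one more memory copy, can be dropped. The calls also switch on use_fast_accum=True for the fp8 gemm. A small numerical sketch of the zero-padding argument, using a plain matmul as a stand-in and an assumed 16-element alignment:

# Sketch: zero-padding only K leaves A @ B unchanged, so no output slice is needed.
import torch
import torch.nn.functional as F

M, K, N, align = 65, 253, 4096, 16   # assumed 16-element alignment
k_pad = (align - K % align) % align  # 3 extra zero rows/cols here

A = torch.randn(M, K)
B = torch.randn(K, N)
A_pad = F.pad(A, (0, k_pad))         # (M, K + k_pad): zero columns on the right
B_pad = F.pad(B, (0, 0, 0, k_pad))   # (K + k_pad, N): zero rows at the bottom

print(A_pad.shape, B_pad.shape, (A @ B - A_pad @ B_pad).abs().max())  # diff ~ 0
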
@@ -92,7 +89,18 @@ def __iter__(self):
 
 
 def gen_configs():
-    shapes = [(8192, 2500, 5000), (64, 255, 4096)]
+    shapes = [
+        (8193, 2501, 5008),
+        (65, 253, 4096),
+        (1023, 1029, 2512),
+        (4095, 511, 10000),
+        (2047, 3073, 8192),
+        (511, 769, 7504),
+        (127, 4097, 12288),
+        (32769, 15, 15024),
+        (9217, 8191, 20480),
+        (16385, 1025, 25008),
+    ]
     output_dtype = torch.bfloat16
     fp8_dtype = torch.float8_e4m3fn
     return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
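
In the new shape list, none of the M and K values are multiples of 16 while every N already is, so the padding path has real work to do on the inner dimension. A quick sketch of the implied padding, assuming pad_tensor_for_matmul rounds the requested dims up to a multiple of 16 (the usual alignment for cuBLASLt fp8 gemms):

# Sketch: how much K grows under an assumed round-up-to-16 padding rule.
def round_up(x, align=16):
    return ((x + align - 1) // align) * align

for M, K, N in [(8193, 2501, 5008), (65, 253, 4096), (32769, 15, 15024)]:
    K_pad = round_up(K)
    extra = 100 * (K_pad - K) / K
    print(f"M={M:6d}  K={K:5d} -> {K_pad:5d} (+{extra:.1f}% flops)  N={N} (16-aligned)")
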
@@ -112,7 +120,7 @@ def run(compile: bool = False, n_limit: Optional[int] = None):
         "Ref % Peak",
         "FP8 % Peak",
     ]
-    for experiment in experiments:
+    for experiment in tqdm(experiments):
         M, K, N, output_dtype, fp8_dtype = experiment
         tops = 2 * M * N * K
 
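
For context on the two "% Peak" columns: per the comment at the top of the file, an M x K by K x N matmul costs 2 * M * K * N operations (one multiply and one add per accumulation step), and utilization is the achieved rate divided by the device's peak throughput. A minimal sketch with hypothetical numbers standing in for the script's measured time and per-dtype peak table:

# Sketch: deriving "% Peak" from op count, runtime, and a peak-throughput figure.
def pct_peak(M, K, N, time_usec, peak_ops_per_sec):
    ops = 2 * M * K * N                 # multiply + add per (m, k, n) triple
    achieved = ops / (time_usec / 1e6)  # ops per second
    return 100 * achieved / peak_ops_per_sec

# Hypothetical: an 8193 x 2501 x 5008 fp8 matmul timed at 80 usec vs. a ~4e15 op/s peak.
print(f"{pct_peak(8193, 2501, 5008, 80.0, 4e15):.1f}% of peak")
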