@@ -4,9 +4,10 @@
 import fire
 
 import torch
-import torch.utils.benchmark as benchmark
 from float8_experimental.float8_utils import pad_tensor_for_matmul
 from tabulate import tabulate
+from torch._inductor.utils import do_bench_using_profiling
+from tqdm import tqdm
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
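
Note: the FLOP count used later in the script (tops = 2 * M * N * K) follows from this comment: each of the M * N output elements of C takes K multiplies and K adds. A quick worked example, using the first shape added to gen_configs below:

    M, K, N = 8193, 2501, 5008
    flops = 2 * M * K * N  # = 205,234,781,088, roughly 0.21 TFLOP per matmul
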
@@ -26,14 +27,9 @@
 
 
 def benchmark_fn_in_usec(f, *args, **kwargs):
-    # Manual warmup
-    for _ in range(4):
-        f(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-    )
-    measurement = t0.blocked_autorange()
-    return measurement.mean * 1e6
+    no_args = lambda: f(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3
 
 
 def get_tops_info(tops, time, peak_tops):
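
Note: the rewritten helper times the closure with inductor's profiler-based timer instead of torch.utils.benchmark; do_bench_using_profiling appears to report milliseconds, hence the * 1e3 to convert to microseconds. A minimal usage sketch, assuming a CUDA device (the shapes here are arbitrary):

    A = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    usec = benchmark_fn_in_usec(torch.mm, A, B)
    print(f"bf16 matmul: {usec:.1f} us")
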
@@ -51,16 +47,15 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
 
     A_pad = pad_tensor_for_matmul(A_fp8, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t()  # mem copy
+    B_pad = pad_tensor_for_matmul(B_fp8, dims=0).contiguous().t()  # mem copy
 
-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)
 
 
 def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
+    # We are only going to test the shape-preserving path
     A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B, dims=[0, 1])  # mem copy
+    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy
 
     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
@@ -70,9 +65,7 @@ def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
 
     B_pad = B_pad.t().contiguous().t()  # mem copy
 
-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)
 
 
 def do_hp_matmul(A, B):
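
Note: both fp8 paths now pad only the shared K dimension (A along dim 1, B along dim 0), so the torch._scaled_mm output is already M x N and the old slicing of the result is gone. An illustrative sketch, assuming pad_tensor_for_matmul pads the requested dims up to the next multiple of 16:

    A = torch.randn(8193, 2501, device="cuda")
    B = torch.randn(2501, 5008, device="cuda")
    A_pad = pad_tensor_for_matmul(A, dims=1)  # K: 2501 -> 2512
    B_pad = pad_tensor_for_matmul(B, dims=0)  # K: 2501 -> 2512
    # A_pad @ B_pad is (8193, 5008), the same M x N as A @ B would be
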
@@ -92,7 +85,18 @@ def __iter__(self):
 
 
 def gen_configs():
-    shapes = [(8192, 2500, 5000), (64, 255, 4096)]
+    shapes = [
+        (8193, 2501, 5008),
+        (65, 256, 4096),
+        (1023, 1024, 2512),
+        (4095, 512, 10000),
+        (2047, 3073, 8192),
+        (511, 769, 7504),
+        (127, 4097, 12288),
+        (32769, 16, 15024),
+        (9217, 8191, 20480),
+        (16385, 1025, 25008),
+    ]
     output_dtype = torch.bfloat16
     fp8_dtype = torch.float8_e4m3fn
     return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
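
Note: in the new shape list every N is already a multiple of 16, while most M and K values are misaligned (8193, 2501, 4097, ...), so these configs exercise the K-padding above without any padding of N. A quick check over the tuples:

    assert all(n % 16 == 0 for (_, _, n) in shapes)
    assert any(k % 16 != 0 for (_, k, _) in shapes)
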
@@ -112,7 +116,7 @@ def run(compile: bool = False, n_limit: Optional[int] = None):
         "Ref % Peak",
         "FP8 % Peak",
     ]
-    for experiment in experiments:
+    for experiment in tqdm(experiments):
         M, K, N, output_dtype, fp8_dtype = experiment
         tops = 2 * M * N * K
 