import torch
import torch.utils.benchmark as benchmark
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
+from float8_experimental.float8_linear_utils import (
+    get_float8_linear,
+    LinearType,
+    sync_float8_amax_and_scale_history,
+)
from float8_experimental.float8_tensor import ScaledMMConfig
from tqdm import tqdm
@@ -35,6 +40,13 @@
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}

+# prevent splitting columns when printing a data frame
+pd.set_option("display.expand_frame_repr", False)
+# print the entire data frame
+pd_print_full_ctx = pd.option_context(
+    "display.max_rows", None, "display.max_columns", None
+)
+


def benchmark_torch_function_in_microseconds(
    func: Callable,
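
Note on the pandas display settings added in the hunk above: pd.set_option("display.expand_frame_repr", False) changes a global default so a wide frame is printed as one block instead of being split across several, while pd.option_context(...) only constructs a context manager; the unlimited row/column settings it names take effect solely inside a "with" block. A minimal usage sketch (the toy frame is made up purely for illustration):

import pandas as pd

pd.set_option("display.expand_frame_repr", False)  # global: do not split wide frames
full_ctx = pd.option_context("display.max_rows", None, "display.max_columns", None)

df = pd.DataFrame({"name": ["example"], "speedup": [1.0]})  # toy frame, illustrative only
with full_ctx:
    print(df)  # inside the context, no row/column truncation is applied
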
@@ -57,6 +69,7 @@ class Experiment:
    dtype: torch.dtype
    compiled: bool
    use_fast_accum: bool
+    linear_type: str

    # 3 Times since we are calculating forward backward
    @property
@@ -79,9 +92,12 @@ def float8_pct_top_peak(self):


def main(
-    sweep_path: Path,
-    compile: bool,
+    sweep_path: Optional[Path] = None,
+    compile: bool = False,
    n_limit: Optional[int] = None,
+    fast_accum_filter: Optional[bool] = None,
+    shape_name_filter: Optional[str] = None,
+    linear_type_filter: Optional[str] = None,
):
    device = "cuda"
    print(f"Compile is set to | {compile}")
@@ -95,20 +111,33 @@ def main(
        "ffn.w2": (3584, 8192),
    }
    input_bias = False
-    ref_dtypes = [torch.bfloat16, torch.float16]
-    use_fast_accum = [True, False]
+    if fast_accum_filter is not None:
+        use_fast_accum = [fast_accum_filter]
+    else:
+        use_fast_accum = [True, False]
+    if linear_type_filter is not None:
+        linear_types = [linear_type_filter]
+    else:
+        linear_types = ["delayed", "dynamic"]
+    if shape_name_filter is not None:
+        k = shape_name_filter
+        name_to_shapes_70b = {k: name_to_shapes_70b[k]}
    experiment_list: List[Experiment] = []
-    for idx, (dtype, fast_accum, (name, (K, N))) in enumerate(
-        tqdm(list(product(ref_dtypes, use_fast_accum, name_to_shapes_70b.items())))
+    dtype = torch.bfloat16
+    for idx, (fast_accum, (name, (K, N)), linear_type) in enumerate(
+        tqdm(list(product(use_fast_accum, name_to_shapes_70b.items(), linear_types)))
    ):
        if n_limit is not None and idx >= n_limit:
            break
        linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
            device=device, dtype=dtype
        )
+        linear_type_enum = (
+            LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
+        )

-        linear_float8 = Float8Linear.from_float(
-            copy.deepcopy(linear_ref), emulate=False
+        linear_float8 = get_float8_linear(
+            linear_type_enum, copy.deepcopy(linear_ref), emulate=False
        )
        if fast_accum:
            linear_float8.forward_config = ScaledMMConfig(False, True, False)
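
The sweep above now pins the reference dtype to bfloat16 and instead enumerates fast-accum settings, weight shapes, and linear types. get_float8_linear(linear_type_enum, module, emulate=False) presumably returns a Float8Linear for LinearType.DELAYED and a Float8DynamicLinear for LinearType.DYNAMIC, which would explain the new Float8DynamicLinear import; that dispatch is an inference from the imports, not something this diff shows. A small sketch of the experiment grid the product(...) call builds, using only the one shape visible in this hunk:

from itertools import product

use_fast_accum = [True, False]
linear_types = ["delayed", "dynamic"]
name_to_shapes_70b = {"ffn.w2": (3584, 8192)}  # the script defines more shapes

# 2 fast-accum settings x 1 shape x 2 linear types = 4 configurations
for fast_accum, (name, (K, N)), linear_type in product(
    use_fast_accum, name_to_shapes_70b.items(), linear_types
):
    print(name, K, N, fast_accum, linear_type)
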
@@ -120,9 +149,16 @@ def main(
        input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
        ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

-        def float8_forw_backward():
-            sync_float8_amax_and_scale_history(linear_float8)
-            linear_float8(input_tensor).sum().backward()
+        if linear_type_enum == LinearType.DELAYED:
+
+            def float8_forw_backward():
+                sync_float8_amax_and_scale_history(linear_float8)
+                linear_float8(input_tensor).sum().backward()
+
+        else:
+
+            def float8_forw_backward():
+                linear_float8(input_tensor).sum().backward()

        def n_times(n, fn, *args, **kwargs):
            def wrapper(*args, **kwargs):
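
The branch above captures the difference being benchmarked: delayed scaling keeps amax/scale history that has to be synchronized before every forward/backward, so sync_float8_amax_and_scale_history is deliberately included in the timed closure, while dynamic scaling derives scales from the live tensors and needs no separate sync step. A restatement as a hypothetical helper, with the reasoning spelled out in comments (the factory function is illustrative, not part of the script):

from float8_experimental.float8_linear_utils import (
    LinearType,
    sync_float8_amax_and_scale_history,
)

def make_float8_forw_backward(linear_float8, input_tensor, linear_type_enum):
    if linear_type_enum == LinearType.DELAYED:
        def forw_backward():
            # delayed scaling: sync the amax/scale buffers first; this per-iteration
            # cost is part of what the benchmark is meant to measure
            sync_float8_amax_and_scale_history(linear_float8)
            linear_float8(input_tensor).sum().backward()
    else:
        def forw_backward():
            # dynamic scaling: scales come from the current tensors inline,
            # so the closure is just the forward/backward pass
            linear_float8(input_tensor).sum().backward()
    return forw_backward
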
@@ -162,6 +198,7 @@ def wrapper(*args, **kwargs):
            dtype,
            compile,
            use_fast_accum=fast_accum,
+            linear_type=linear_type,
        )
        print(experiment)
        print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -173,6 +210,7 @@ def wrapper(*args, **kwargs):
        "M",
        "K",
        "N",
+        "linear_type",
        "ref_dtype",
        "compiled",
        "use_fast_accum",
@@ -191,6 +229,7 @@ def wrapper(*args, **kwargs):
                experiment.shape[0],
                experiment.shape[1],
                experiment.shape[2],
+                experiment.linear_type,
                experiment.dtype,
                experiment.compiled,
                experiment.use_fast_accum,
@@ -219,28 +258,40 @@ def wrapper(*args, **kwargs):
        [
            "name",
            "shape",
-            "ref_dtype",
+            "linear_type",
            "compiled",
            "use_fast_accum",
            "ref_time_sec",
            "pt_fp8_time_sec",
            "pt_fp8_speedup",
        ]
    ]
-    print(data_pd_simple)
+    with pd_print_full_ctx:
+        print(data_pd_simple)

-    sweep_path = sweep_path.with_suffix(".csv")
-    data_pd.to_csv(sweep_path)
+    if sweep_path is not None:
+        sweep_path = sweep_path.with_suffix(".csv")
+        data_pd.to_csv(sweep_path)


def invoke_main() -> None:
    parser = argparse.ArgumentParser()
-    parser.add_argument("-o", "--output_path", type=str, required=True)
+    parser.add_argument("-o", "--output_path", type=str, required=False)
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("-n", "--n_limit", type=int, required=False)
+    parser.add_argument("--fast_accum_filter", type=bool, required=False)
+    parser.add_argument("--shape_name_filter", type=str, required=False)
+    parser.add_argument("--linear_type_filter", type=str, required=False)
    args = parser.parse_args()
-    output_path = Path(args.output_path)
-    main(output_path, args.compile, args.n_limit)
+    output_path = Path(args.output_path) if args.output_path is not None else None
+    main(
+        output_path,
+        args.compile,
+        args.n_limit,
+        args.fast_accum_filter,
+        args.shape_name_filter,
+        args.linear_type_filter,
+    )


if __name__ == "__main__":
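
With the new arguments, a sweep can be narrowed from the command line or by calling main() directly; a sketch of the latter, in keyword form of the same call invoke_main() makes (the values are chosen purely as an example):

# rough equivalent of running the script with:
#   --compile -n 4 --shape_name_filter ffn.w2 --linear_type_filter dynamic
main(
    sweep_path=None,            # no CSV is written when no output path is given
    compile=True,
    n_limit=4,
    fast_accum_filter=None,     # sweep both True and False
    shape_name_filter="ffn.w2",
    linear_type_filter="dynamic",
)

One caveat on the flags as written: argparse's type=bool does not parse "False"; any non-empty string passed to --fast_accum_filter is treated as True, so restricting the sweep to fast_accum=False from the CLI requires passing an empty string (or calling main() with fast_accum_filter=False).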