"""

import dataclasses
+import enum

from typing import Optional

import float8_experimental.config as config

import torch

+from float8_experimental.float8_dynamic_linear import (
+    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_e5m2_dynamic_bw,
+)
+
from float8_experimental.float8_tensor import (
    Float8Tensor,
    ScaledMMConfig,
@@ -125,20 +131,54 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


+class TensorScalingType(enum.Enum):
+    DELAYED = "delayed"
+    DYNAMIC = "dynamic"
+
+    def short_str(self):
+        if self is TensorScalingType.DELAYED:
+            return "del"
+        else:
+            assert self is TensorScalingType.DYNAMIC
+            return "dyn"
+
+
class Float8Linear(torch.nn.Linear):
    """
    A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
    scales in way friendly to delayed scaling.
    """

    def __init__(self, *args, **kwargs):
+        """
+        Additional arguments on top of `torch.nn.Linear`'s arguments:
+        * `delayed_scaling_recipe`: configuration for delayed scaling
+        * `scaling_type_x`: delayed vs dynamic scaling for `x`
+        * `scaling_type_w`: delayed vs dynamic scaling for `w`
+        * `scaling_type_dL_dY`: delayed vs dynamic scaling for `dL_dY`
+        """
+
        delayed_scaling_recipe = kwargs.pop(
            "delayed_scaling_recipe", DelayedScalingRecipe()
        )
        # Amax scales should always be kept as float32.
        self.always_float32_buffers = set()
+        scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
+        scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
+        scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
        super().__init__(*args, **kwargs)

+        # Defines the scaling behavior of x, w, dL_dY
+        self.scaling_type_x = scaling_type_x
+        self.scaling_type_w = scaling_type_w
+        self.scaling_type_dL_dY = scaling_type_dL_dY
+        # Convenience flag to skip code related to delayed scaling
+        self.has_any_delayed_scaling = (
+            self.scaling_type_x is TensorScalingType.DELAYED
+            or self.scaling_type_w is TensorScalingType.DELAYED
+            or self.scaling_type_dL_dY is TensorScalingType.DELAYED
+        )
+
        # TODO(future): have a unique recipe per buffer instead of one per
        # module, saving implementing that until we need it.
        # TODO(future): serialization for recipes
@@ -175,37 +215,44 @@ def create_buffers(self):
        # Default values for history buffers, see above TODO
        history_len = self.recipe.history_len
        device = self.weight.device
+        # TODO(future PR): dtype values below don't have the other float8
+        # flavors, fix it
        default_x = torch.finfo(torch.float8_e4m3fn).max
        default_w = torch.finfo(torch.float8_e4m3fn).max
        default_dl_dy = torch.finfo(torch.float8_e5m2).max

-        self.register_always_float32_buffer(
-            "fp8_amax_x", torch.tensor([default_x], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_history_x", torch.zeros(history_len, device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_scale_x", torch.tensor([1.0], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_w", torch.tensor([default_w], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_history_w", torch.zeros(history_len, device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_scale_w", torch.tensor([1.0], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_dL_dY", torch.tensor([default_dl_dy], device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_amax_history_dL_dY", torch.zeros(history_len, device=device)
-        )
-        self.register_always_float32_buffer(
-            "fp8_scale_dL_dY", torch.tensor([1.0], device=device)
-        )
+        # Note: for now, create all the buffers if any are needed, to postpone
+        # the work to make the scale and amax syncing and history calculation
+        # handle a heterogeneous setup. We can do that work later if benchmarks
+        # show it is worth doing.
+        if self.has_any_delayed_scaling:
+            self.register_always_float32_buffer(
+                "fp8_amax_x", torch.tensor([default_x], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_history_x", torch.zeros(history_len, device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_scale_x", torch.tensor([1.0], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_w", torch.tensor([default_w], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_history_w", torch.zeros(history_len, device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_scale_w", torch.tensor([1.0], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_dL_dY", torch.tensor([default_dl_dy], device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_amax_history_dL_dY", torch.zeros(history_len, device=device)
+            )
+            self.register_always_float32_buffer(
+                "fp8_scale_dL_dY", torch.tensor([1.0], device=device)
+            )

    def register_always_float32_buffer(
        self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
@@ -234,61 +281,77 @@ def cast_x_to_float8(
            autocast_dtype = torch.get_autocast_gpu_dtype()
            x = x.to(autocast_dtype)

-        scale_fn_name = self.recipe.scale_fn_name
-        _maybe_initialize_amaxes_scales_for_float8_cast(
-            x,
-            self.fp8_amax_x,
-            self.fp8_amax_history_x,
-            self.fp8_scale_x,
-            scale_fn_name,
-            e4m3_dtype,
-            is_amax_initialized,
-            reduce_amax=True,
-        )
-        x_fp8 = Float8Tensor.to_float8(
-            x,
-            self.fp8_scale_x,
-            e4m3_dtype,
-            self.fp8_amax_x,
-            self.forward_config,
-        )
+        if self.scaling_type_x is TensorScalingType.DELAYED:
+            scale_fn_name = self.recipe.scale_fn_name
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                x,
+                self.fp8_amax_x,
+                self.fp8_amax_history_x,
+                self.fp8_scale_x,
+                scale_fn_name,
+                e4m3_dtype,
+                is_amax_initialized,
+                reduce_amax=True,
+            )
+            x_fp8 = Float8Tensor.to_float8(
+                x,
+                self.fp8_scale_x,
+                e4m3_dtype,
+                self.fp8_amax_x,
+                self.forward_config,
+            )
+        else:
+            assert self.scaling_type_x is TensorScalingType.DYNAMIC
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
        return x_fp8

    def cast_w_to_float8(
        self, w: torch.Tensor, is_amax_initialized: bool
    ) -> torch.Tensor:
-        scale_fn_name = self.recipe.scale_fn_name
-        _maybe_initialize_amaxes_scales_for_float8_cast(
-            w,
-            self.fp8_amax_w,
-            self.fp8_amax_history_w,
-            self.fp8_scale_w,
-            scale_fn_name,
-            e4m3_dtype,
-            is_amax_initialized,
-            reduce_amax=False,
-        )
+        if self.scaling_type_w is TensorScalingType.DELAYED:
+            scale_fn_name = self.recipe.scale_fn_name
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                w,
+                self.fp8_amax_w,
+                self.fp8_amax_history_w,
+                self.fp8_scale_w,
+                scale_fn_name,
+                e4m3_dtype,
+                is_amax_initialized,
+                reduce_amax=False,
+            )

-        w_fp8 = Float8Tensor.to_float8(
-            w,
-            self.fp8_scale_w,
-            e4m3_dtype,
-            self.fp8_amax_w,
-            self.forward_config,
-        )
+            w_fp8 = Float8Tensor.to_float8(
+                w,
+                self.fp8_scale_w,
+                e4m3_dtype,
+                self.fp8_amax_w,
+                self.forward_config,
+            )
+        else:
+            assert self.scaling_type_w is TensorScalingType.DYNAMIC
+            # TODO(future): also support FSDP integration in delayed scaling path
+            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
+                w_fp8 = self.weight
+            else:
+                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
        return w_fp8

    def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
-        scale_fn_name = self.recipe.scale_fn_name
-        y = NoopFwToFloat8E5M2Bw.apply(
-            y,
-            self.fp8_amax_dL_dY,
-            self.fp8_amax_history_dL_dY,
-            self.fp8_scale_dL_dY,
-            scale_fn_name,
-            self.is_amax_initialized,
-            self.backward_config,
-        )
+        if self.scaling_type_dL_dY is TensorScalingType.DELAYED:
+            scale_fn_name = self.recipe.scale_fn_name
+            y = NoopFwToFloat8E5M2Bw.apply(
+                y,
+                self.fp8_amax_dL_dY,
+                self.fp8_amax_history_dL_dY,
+                self.fp8_scale_dL_dY,
+                scale_fn_name,
+                self.is_amax_initialized,
+                self.backward_config,
+            )
+        else:
+            assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
        return y

    def float8_pre_forward(self, x):
@@ -313,7 +376,8 @@ def float8_post_forward(self):
            self.amax_and_scale_synced = False

    def forward(self, x):
-        self.float8_pre_forward(x)
+        if self.has_any_delayed_scaling:
+            self.float8_pre_forward(x)

        x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
@@ -326,11 +390,29 @@ def forward(self, x):
        if self.bias is not None:
            y = y + self.bias.to(y.dtype)

-        self.float8_post_forward()
+        if self.has_any_delayed_scaling:
+            self.float8_post_forward()
        return y

+    def extra_repr(self):
+        # example: in_features=32, out_features=16, bias=True
+        s = super().extra_repr()
+        # add scaling settings without using too many characters
+        scaling = f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
+
+        s = f'{s}, scaling="{scaling}"'
+        # example: in_features=32, out_features=16, bias=True, scaling="x:del,w:del,dldy:dyn"
+        return s
+
    @classmethod
-    def from_float(cls, mod, emulate: bool = False):
+    def from_float(
+        cls,
+        mod,
+        emulate: bool = False,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    ):
        """
        Create an nn.Linear with fp8 compute from a regular nn.Linear

@@ -339,14 +421,22 @@ def from_float(cls, mod, emulate: bool = False):
            emulate (bool): whether to emulate fp8 matmul logic in float32
        """
        with torch.device("meta"):
-            new_mod = cls(mod.in_features, mod.out_features, bias=False)
+            new_mod = cls(
+                mod.in_features,
+                mod.out_features,
+                bias=False,
+                scaling_type_x=scaling_type_x,
+                scaling_type_w=scaling_type_w,
+                scaling_type_dL_dY=scaling_type_dL_dY,
+            )
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        # need to create buffers again when moving from meta device to
        # real device
        new_mod.create_buffers()
        # Defines the behavior of the matmul in the forward and backward
        # Forward we use fast_accum, backwards we do not
+        # TODO(future PR): move below to the constructor
        new_mod.forward_config = ScaledMMConfig(
            emulate, True if not emulate else False, False, config.pad_inner_dim
        )
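
For context, here is a minimal usage sketch of the per-tensor scaling knobs this diff adds to `Float8Linear.from_float`. It is not part of the diff: the import path `float8_experimental.float8_linear` and the choice of `emulate=True` and layer sizes are assumptions for illustration only.

```python
# Sketch (assumptions noted above): convert a plain nn.Linear to Float8Linear,
# choosing delayed vs dynamic scaling per tensor via TensorScalingType.
import torch

from float8_experimental.float8_linear import Float8Linear, TensorScalingType

# a regular linear layer to convert
m = torch.nn.Linear(32, 16)

# delayed scaling for x and w, dynamic scaling for dL_dY
m_fp8 = Float8Linear.from_float(
    m,
    emulate=True,  # emulate the fp8 matmul in float32 (no fp8 hardware needed)
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)

# extra_repr now reports the scaling settings, e.g.
# in_features=32, out_features=16, bias=True, scaling="x:del,w:del,dldy:dyn"
print(m_fp8)
```

With all three types set to `DYNAMIC`, `has_any_delayed_scaling` is False, so the amax/scale buffers are never created and the pre/post-forward bookkeeping for delayed scaling is skipped entirely.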