@@ -351,6 +351,24 @@ def replace_linear_weight_only_int8_per_channel(
                 child, device, node_type, groupsize
             )
 
+def linear_forward_int8(x, weight, scales):
+    scales = scales.view(scales.shape[0], -1)
+    n_groups = scales.shape[1]
+    # need a formulation / custom op for good performance
+    # on eager, CUDA compiled, CPU compiled and ET exported
+
+    # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
+    if n_groups == 1:
+        return F.linear(x, weight.to(dtype=x.dtype)) * scales
+
+    return F.linear(
+        x,
+        torch.mul(
+            weight.to(dtype=x.dtype).view(weight.shape[0], n_groups, -1),
+            scales.view(weight.shape[0], n_groups, -1)
+        ).view(weight.shape[0], -1),
+    )
+
 
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
@@ -471,25 +489,7 @@ def __init__(
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        scales = self.scales
-        weight = self.weight
-        scales = scales.view(scales.shape[0], -1)
-        no_groups = scales.shape[1]
-
-        # need a formulation / custom op for good performance
-        # on eager, CUDA compiled, CPU compiled and ET exported
-
-        # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
-        if scales.shape[1] == 1:
-            return F.linear(input, weight.to(dtype=input.dtype)) * self.scales
-        else:
-            return F.linear(
-                input,
-                (
-                    weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1)
-                    * scales.view(weight.shape[0], no_groups, -1)
-                ).view(weight.shape[0], -1),
-            )
+        return linear_forward_int8(input, self.weight, self.scales)
 
 
 #########################################################################
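
For reference, here is a minimal standalone sketch (not part of the commit) of what the two branches of linear_forward_int8 compute. The channel-wise path relies on the identity (x @ W^T) * s == x @ (W * s[:, None])^T, so the matmul can run on the upcast weight first and the per-row scales can be applied to the output; the group-wise path has no such shortcut and dequantizes the weight up front. The tensor shapes and the naive reference below are illustrative assumptions, not code from the repository.

# Illustrative sketch only -- not part of the commit. Shapes are assumptions.
import torch
import torch.nn.functional as F

out_features, in_features, groupsize, batch = 8, 16, 4, 2
weight_int8 = torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8)
x = torch.randn(batch, in_features)

# Channel-wise (n_groups == 1): one scale per output row. Scaling the matmul
# output matches scaling each weight row before the matmul.
scales_cw = torch.rand(out_features)
fast = F.linear(x, weight_int8.to(x.dtype)) * scales_cw
naive = F.linear(x, weight_int8.to(x.dtype) * scales_cw.unsqueeze(-1))
torch.testing.assert_close(fast, naive)

# Group-wise: one scale per (output row, group of `groupsize` input channels).
# As in the diff, dequantize the whole weight, then run a single F.linear.
n_groups = in_features // groupsize
scales_gw = torch.rand(out_features, n_groups)
dequant = torch.mul(
    weight_int8.to(x.dtype).view(out_features, n_groups, -1),
    scales_gw.view(out_features, n_groups, -1),
).view(out_features, -1)
print(F.linear(x, dequant).shape)  # torch.Size([2, 8])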