@@ -351,6 +351,23 @@ def replace_linear_weight_only_int8_per_channel(
                 child, device, node_type, groupsize
             )
 
+def linear_forward_int8(x, weight, scales):
+    n_groups = scales.numel() // scales.shape[0]
+    # need a formulation / custom op for good performance
+    # on eager, CUDA compiled, CPU compiled and ET exported
+
+    # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
+    if n_groups == 1:
+        return F.linear(x, weight.to(dtype=x.dtype)) * scales
+
+    return F.linear(
+        x,
+        (
+            weight.to(dtype=x.dtype).view(weight.shape[0], n_groups, -1)
+            * scales.view(weight.shape[0], n_groups, -1)
+        ).view(weight.shape[0], -1),
+    )
+
 
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
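The groupwise branch above views each weight row as `(n_groups, group_size)` blocks and multiplies each block by its group scale before the matmul, which is equivalent to dequantizing the whole weight up front. A minimal sketch checking that equivalence (all shapes and names here are illustrative, not from the patch):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes, not from the patch.
out_features, in_features, group_size = 8, 32, 16
n_groups = in_features // group_size

x = torch.randn(4, in_features)
weight = torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8)
scales = torch.rand(out_features, n_groups)

# Groupwise path of linear_forward_int8: scale each (row, group) block of the
# upcast weight, then do a single matmul.
w_scaled = (
    weight.to(x.dtype).view(out_features, n_groups, -1)
    * scales.view(out_features, n_groups, -1)
).view(out_features, -1)
out = F.linear(x, w_scaled)

# Reference: dequantize the full weight first (repeat each scale across its
# group), then do an ordinary float matmul.
w_dequant = weight.to(x.dtype) * scales.repeat_interleave(group_size, dim=1)
assert torch.allclose(out, F.linear(x, w_dequant))
```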
@@ -471,25 +488,7 @@ def __init__(
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        scales = self.scales
-        weight = self.weight
-        scales = scales.view(scales.shape[0], -1)
-        no_groups = scales.shape[1]
-
-        # need a formulation / custom op for good performance
-        # on eager, CUDA compiled, CPU compiled and ET exported
-
-        # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
-        if scales.shape[1] == 1:
-            return F.linear(input, weight.to(dtype=input.dtype)) * self.scales
-        else:
-            return F.linear(
-                input,
-                (
-                    weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1)
-                    * scales.view(weight.shape[0], no_groups, -1)
-                ).view(weight.shape[0], -1),
-            )
+        return linear_forward_int8(input, self.weight, self.scales)
 
 
 #########################################################################
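On the caller side, channel-wise quantization stores one scale per output row, so `scales.numel() // scales.shape[0] == 1` and `linear_forward_int8` takes the fast path: one matmul on the upcast int8 weight followed by a single columnwise multiply. A usage sketch with a hypothetical `quantize_per_channel` helper (the patch's actual quantization lives in `WeightOnlyInt8QuantHandler`, not shown here; `linear_forward_int8` from the patch is assumed in scope):

```python
import torch
import torch.nn.functional as F

def quantize_per_channel(w: torch.Tensor):
    # Hypothetical helper: symmetric int8 quantization with one scale per
    # output row (channel-wise), i.e. scales has shape (out_features,).
    scales = w.abs().amax(dim=1) / 127.0
    w_int8 = torch.round(w / scales[:, None]).to(torch.int8)
    return w_int8, scales

w = torch.randn(8, 32)
x = torch.randn(4, 32)
w_int8, scales = quantize_per_channel(w)

# n_groups == scales.numel() // scales.shape[0] == 1, so the channel-wise
# fast path runs: F.linear on the upcast weight, then one scale multiply.
out = linear_forward_int8(x, w_int8, scales)
print((out - F.linear(x, w)).abs().max())  # small quantization error
```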