     use_et_backend,
 )
 
+from qops import LinearInt8 as WeightOnlyInt8Linear
 
 #########################################################################
 ### torchchat quantization API ###
@@ -377,7 +378,10 @@ def replace_linear_weight_only_int8_per_channel(
                     module,
                     name,
                     WeightOnlyInt8Linear(
-                        device, child.in_features, child.out_features, groupsize
+                        in_features=child.in_features,
+                        out_features=child.out_features,
+                        device=device,
+                        groupsize=groupsize,
                     ),
                 )
         else:
@@ -386,35 +390,6 @@ def replace_linear_weight_only_int8_per_channel(
             )
 
 
-def linear_forward_int8(x, weight, scales):
-    n_groups = scales.numel() // scales.shape[0]
-    # need a formulation / custom op for good performance
-    # on eager, CUDA compiled, CPU compiled and ET exported
-
-    # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
-    if n_groups == 1:
-        if (
-            torch.compiler.is_compiling()
-            or x.device.type != "cpu"
-            or torch.__version__ < "2.4"
-        ):
-            return F.linear(x, weight.to(dtype=x.dtype)) * scales
-        # Use int8pack_mm for CPU eager
-        return torch.ops.aten._weight_int8pack_mm(
-            x.reshape(-1, x.shape[-1]),
-            weight,
-            scales,
-        ).reshape(x.shape[:-1] + (weight.shape[0],))
-
-    return F.linear(
-        x,
-        (
-            weight.to(dtype=x.dtype).view(weight.shape[0], n_groups, -1)
-            * scales.view(weight.shape[0], n_groups, -1)
-        ).view(weight.shape[0], -1),
-    )
-
-
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
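The `linear_forward_int8` removed in the hunk above is the dequantize-then-matmul that the relocated `qops.LinearInt8` is presumably expected to keep. As a reference, here is a minimal standalone sketch of that math, reconstructed from the deleted lines; the function name `int8_weight_only_linear` is ours for illustration and is not part of the PR, and the CPU `_weight_int8pack_mm` fast path is omitted.

```python
import torch
import torch.nn.functional as F


def int8_weight_only_linear(x, weight, scales):
    # Reference sketch of the removed math, not the qops implementation.
    # weight: int8, shape (out_features, in_features)
    # scales: (out_features,) for channel-wise, (out_features, n_groups) for group-wise
    n_groups = scales.numel() // scales.shape[0]
    if n_groups == 1:
        # Channel-wise: dequantize to x's dtype and scale the output per row.
        return F.linear(x, weight.to(dtype=x.dtype)) * scales
    # Group-wise: scale each contiguous group of input columns before the matmul.
    out_features = weight.shape[0]
    w = weight.to(dtype=x.dtype).view(out_features, n_groups, -1)
    w = (w * scales.view(out_features, n_groups, -1)).view(out_features, -1)
    return F.linear(x, w)
```

Channel-wise scaling commutes with the matmul, which is why the per-row scale can be applied to the output after `F.linear`; group-wise scales act on column groups, so the weights must be rescaled before the matmul.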
@@ -499,45 +474,6 @@ def quantized_model(self) -> nn.Module:
         return self.model_
 
 
-class WeightOnlyInt8Linear(torch.nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-
-    def __init__(
-        self,
-        device,
-        in_features: int,
-        out_features: int,
-        groupsize: Optional[int] = None,
-        bias: bool = True,
-        dtype=None,
-    ) -> None:
-        super().__init__()
-        # print(f"group size: {groupsize}")
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.register_buffer(
-            "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
-        )
-        dtype = get_precision()
-        if groupsize is None or (groupsize == 0):
-            self.register_buffer(
-                "scales", torch.ones(out_features, dtype=dtype, device=device)
-            )
-        else:
-            groups = (in_features + groupsize - 1) // groupsize
-            self.register_buffer(
-                "scales", torch.ones(out_features, groups, dtype=dtype, device=device)
-            )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_forward_int8(input, self.weight, self.scales)
-
-
 #########################################################################
 ##### embedding table quantization ######
 
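Because the class is now imported from `qops` under its old name, call sites keep working unchanged. A hypothetical usage sketch, assuming `qops.LinearInt8` accepts the keyword arguments shown in this diff and mirrors the buffers of the removed class (int8 `weight` of shape `(out_features, in_features)` and per-channel `scales` when `groupsize` is None or 0); the sizes below are made up.

```python
import torch
from qops import LinearInt8 as WeightOnlyInt8Linear  # local qops module in this repo

# groupsize=None (or 0) selects per-channel scales, as in the removed class.
layer = WeightOnlyInt8Linear(
    in_features=4096,
    out_features=4096,
    device="cpu",
    groupsize=None,
)
# Forward is expected to dequantize the int8 weights and apply the scales.
y = layer(torch.randn(2, 4096))
```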