1 file changed: 13 additions and 1 deletion
```diff
@@ -351,14 +351,26 @@ def replace_linear_weight_only_int8_per_channel(
                 child, device, node_type, groupsize
             )
 
+
 def linear_forward_int8(x, weight, scales):
     n_groups = scales.numel() // scales.shape[0]
     # need a formulation / custom op for good performance
     # on eager, CUDA compiled, CPU compiled and ET exported
 
     # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
     if n_groups == 1:
-        return F.linear(x, weight.to(dtype=x.dtype)) * scales
+        if (
+            torch.compiler.is_compiling()
+            or x.device.type != "cpu"
+            or torch.__version__ < "2.4"
+        ):
+            return F.linear(x, weight.to(dtype=x.dtype)) * scales
+        # Use int8pack_mm for CPU eager
+        return torch.ops.aten._weight_int8pack_mm(
+            x.reshape(-1, x.shape[-1]),
+            weight,
+            scales,
+        ).reshape(x.shape[:-1] + (weight.shape[0],))
 
     return F.linear(
         x,
```
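For context, here is a minimal, self-contained sketch of what the two branches compute for the channel-wise case (`n_groups == 1`). The shapes and tolerances are illustrative assumptions, not values from the PR; it assumes PyTorch >= 2.4 running eagerly on CPU, where `torch.ops.aten._weight_int8pack_mm` takes a 2D activation, a plain int8 weight of shape `(out_features, in_features)`, and one scale per output channel:

```python
# Sketch only: shapes and tolerances are made up for illustration.
# Compares the fallback path (dequantize weights, then F.linear, then
# rescale) against the fused CPU-eager int8 path added by this diff.
import torch
import torch.nn.functional as F

batch, seq, in_features, out_features = 2, 4, 64, 32
x = torch.randn(batch, seq, in_features, dtype=torch.bfloat16)
weight = torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8)
scales = torch.rand(out_features, dtype=torch.bfloat16)  # one scale per output channel

# Fallback branch: dequantize the int8 weights to the activation dtype,
# do a regular matmul, then apply the per-channel scales.
ref = F.linear(x, weight.to(dtype=x.dtype)) * scales

# CPU-eager branch: the fused op only accepts a 2D input, so flatten the
# leading batch/sequence dims and restore them afterwards.
out = torch.ops.aten._weight_int8pack_mm(
    x.reshape(-1, x.shape[-1]), weight, scales
).reshape(x.shape[:-1] + (weight.shape[0],))

# The two paths may differ slightly in low precision, hence loose tolerances.
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)
```

The reshape dance exists because `_weight_int8pack_mm` only accepts a 2D input, while `F.linear` broadcasts over leading batch dimensions; flattening to `(-1, in_features)` and restoring `x.shape[:-1] + (out_features,)` afterwards keeps the two paths shape-compatible.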