
Commit 8605a11

Use _weight_int8pack_mm for CPU + eager (#472)
* Use _weight_int8pack_mm for CPU + eager
* Skip for older PyTorch versions
1 parent 9b710b0 commit 8605a11

File tree

1 file changed (+13, -1)

quantize.py

Lines changed: 13 additions & 1 deletion
@@ -351,14 +351,26 @@ def replace_linear_weight_only_int8_per_channel(
                child, device, node_type, groupsize
            )

+
def linear_forward_int8(x, weight, scales):
    n_groups = scales.numel() // scales.shape[0]
    # need a formulation / custom op for good performance
    # on eager, CUDA compiled, CPU compiled and ET exported

    # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
    if n_groups == 1:
-        return F.linear(x, weight.to(dtype=x.dtype)) * scales
+        if (
+            torch.compiler.is_compiling()
+            or x.device.type != "cpu"
+            or torch.__version__ < "2.4"
+        ):
+            return F.linear(x, weight.to(dtype=x.dtype)) * scales
+        # Use int8pack_mm for CPU eager
+        return torch.ops.aten._weight_int8pack_mm(
+            x.reshape(-1, x.shape[-1]),
+            weight,
+            scales,
+        ).reshape(x.shape[:-1] + (weight.shape[0],))

    return F.linear(
        x,
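
For readers who want to sanity-check the new path, here is a minimal sketch (not part of the commit) comparing the generic F.linear fallback with the CPU-eager fast path the diff introduces. The shapes, dtypes, and the final comparison are illustrative assumptions; it requires PyTorch >= 2.4 on CPU, where aten._weight_int8pack_mm is available.

# Minimal sketch, not from quantize.py: shapes and dtypes are illustrative assumptions.
# Assumes PyTorch >= 2.4 on CPU, where aten._weight_int8pack_mm exists.
import torch
import torch.nn.functional as F

x = torch.randn(4, 64, dtype=torch.bfloat16)                   # activations, shape (m, k)
weight = torch.randint(-128, 128, (32, 64), dtype=torch.int8)  # channel-wise int8 weight, shape (n, k)
scales = torch.rand(32, dtype=torch.bfloat16)                  # per-output-channel scales, shape (n,)

# Generic path: dequantize the weight to the activation dtype, then scale the output.
ref = F.linear(x, weight.to(dtype=x.dtype)) * scales

# Fast path added by this commit for CPU eager: int8 matmul with per-channel scaling.
out = torch.ops.aten._weight_int8pack_mm(x, weight, scales)

# The two paths should agree up to bfloat16 rounding.
print((ref - out).abs().max())

The guard in the diff keeps the F.linear path when torch.compile is tracing, when the tensor is not on CPU, or when the PyTorch version is older than 2.4, so only plain CPU eager execution goes through the packed int8 kernel.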
