Skip to content

Commit 5770b42

Browse files
committed
[BE] Introduce linear_forward_int8 (#432)
Similar to `linear_forward_int4` to be replaced with special CPU op later
1 parent b3a2672 commit 5770b42

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

quantize.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,23 @@ def replace_linear_weight_only_int8_per_channel(
351351
child, device, node_type, groupsize
352352
)
353353

def linear_forward_int8(x, weight, scales):
    """Weight-only int8 linear layer.

    Dequantizes the int8 `weight` using per-group `scales` and applies
    F.linear. `scales` is interpreted as (out_features, n_groups): the
    number of groups per output row is inferred as
    scales.numel() // scales.shape[0].
    """
    groups_per_row = scales.numel() // scales.shape[0]
    # need a formulation / custom op for good performance
    # on eager, CUDA compiled, CPU compiled and ET exported

    # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
    if groups_per_row != 1:
        # groupwise: scale the dequantized weight group-by-group, then matmul
        out_rows = weight.shape[0]
        dequant = weight.to(dtype=x.dtype).view(out_rows, groups_per_row, -1)
        dequant = dequant * scales.view(out_rows, groups_per_row, -1)
        return F.linear(x, dequant.reshape(out_rows, -1))
    # channel-wise: cheaper to scale the matmul output than the weight
    return F.linear(x, weight.to(dtype=x.dtype)) * scales
354371

355372
class WeightOnlyInt8QuantHandler(QuantHandler):
356373
def __init__(
@@ -471,25 +488,7 @@ def __init__(
471488
)
472489

473490
def forward(self, input: torch.Tensor) -> torch.Tensor:
    # Delegate to the module-level helper so the dequantize+matmul
    # formulation can later be swapped for a special CPU op in one place
    # (mirrors the existing linear_forward_int4 pattern).
    return linear_forward_int8(input, self.weight, self.scales)
493492

494493

495494
#########################################################################

0 commit comments

Comments
 (0)