Skip to content

Commit 2ca47b6

Browse files
committed
[BE] Introduce linear_forward_int8
Similar to `linear_forward_int4` to be replaced with special CPU op later
1 parent 06ac1da commit 2ca47b6

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

quantize.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,24 @@ def replace_linear_weight_only_int8_per_channel(
351351
child, device, node_type, groupsize
352352
)
353353

354+
def linear_forward_int8(x, weight, scales):
355+
scales = scales.view(scales.shape[0], -1)
356+
n_groups = scales.shape[1]
357+
# need a formulation / custom op for good performance
358+
# on eager, CUDA compiled, CPU compiled and ET exported
359+
360+
# for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
361+
if n_groups == 1:
362+
return F.linear(x, weight.to(dtype=x.dtype)) * scales
363+
364+
return F.linear(
365+
x,
366+
torch.mul(
367+
weight.to(dtype=x.dtype).view(weight.shape[0], n_groups, -1),
368+
scales.view(weight.shape[0], n_groups, -1)
369+
).view(weight.shape[0], -1),
370+
)
371+
354372

355373
class WeightOnlyInt8QuantHandler(QuantHandler):
356374
def __init__(
@@ -471,25 +489,7 @@ def __init__(
471489
)
472490

473491
def forward(self, input: torch.Tensor) -> torch.Tensor:
474-
scales = self.scales
475-
weight = self.weight
476-
scales = scales.view(scales.shape[0], -1)
477-
no_groups = scales.shape[1]
478-
479-
# need a formulation / custom op for good performance
480-
# on eager, CUDA compiled, CPU compiled and ET exported
481-
482-
# for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
483-
if scales.shape[1] == 1:
484-
return F.linear(input, weight.to(dtype=input.dtype)) * self.scales
485-
else:
486-
return F.linear(
487-
input,
488-
(
489-
weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1)
490-
* scales.view(weight.shape[0], no_groups, -1)
491-
).view(weight.shape[0], -1),
492-
)
492+
return linear_forward_int8(input, self.weight, self.scales)
493493

494494

495495
#########################################################################

0 commit comments

Comments
 (0)