Skip to content

Commit 2b46b72

Browse files
mikekgfb
authored and malfet committed
user linear_int8 (#135)
1 parent b20db51 commit 2b46b72

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

quantize.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
import torch.nn as nn
1414
import torch.nn.functional as F
15+
import quantized_ops
1516

1617

1718
try:
@@ -471,19 +472,12 @@ def __init__(
471472
self.register_buffer("scales", torch.ones(out_features, groups, dtype=torch.bfloat16))
472473

473474
def forward(self, input: torch.Tensor) -> torch.Tensor:
474-
scales = self.scales
475-
weight = self.weight
476-
scales = scales.view(scales.shape[0], -1)
477-
no_groups = scales.shape[1]
478-
479-
# need a formulation / custom op for good performance on both eager, CUDA compiled, CPU compiled and ET exported
480-
# maybe use IR-based rewriting?
481-
482-
# for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
483-
if scales.shape[1] == 1:
484-
return F.linear(input, weight.to(dtype=input.dtype)) * self.scales
485-
else:
486-
return F.linear(input, (weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1) * scales.view(weight.shape[0], no_groups, -1)).view(weight.shape[0], -1))
475+
return torch.ops.torchat.linear_int8(
476+
input,
477+
self.weight,
478+
self.scales,
479+
None
480+
)
487481

488482

489483
#########################################################################

0 commit comments

Comments
 (0)