Commit d226c33

mikekgfb authored and malfet committed
move int8 linear class and function into qops.py (#534)
1 parent a653080 commit d226c33

2 files changed: +91 −69 lines (qops.py, quantize.py)

qops.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+
+def linear_int8(input, weight, scales):
+    n_groups = scales.numel() // scales.shape[0]
+
+    # we special-case channel-wise, because we know how to make that fast
+    if n_groups == 1:
+        if (
+            torch.compiler.is_compiling()
+            or input.device.type != "cpu"
+            or torch.__version__ < "2.4"
+        ):
+            return F.linear(input, weight.to(dtype=input.dtype)) * scales
+        # Use int8pack_mm for CPU eager
+        return torch.ops.aten._weight_int8pack_mm(
+            input.reshape(-1, input.shape[-1]),
+            weight,
+            scales,
+        ).reshape(input.shape[:-1] + (weight.shape[0],))
+
+    return F.linear(
+        input,
+        (
+            weight.to(dtype=input.dtype).view(weight.shape[0], n_groups, -1)
+            * scales.view(weight.shape[0], n_groups, -1)
+        ).view(weight.shape[0], -1),
+    )
+
+
+class LinearInt8(nn.Module):
+    __constants__ = ["in_features", "out_features"]
+    in_features: int
+    out_features: int
+    weight: torch.Tensor
+    scales: torch.Tensor
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        bias=None,
+        device=None,
+        dtype=None,
+        *,
+        weight: Optional[torch.Tensor] = None,
+        scales: Optional[torch.Tensor] = None,
+        groupsize: Optional[int] = None,
+    ):
+        super().__init__()
+        if dtype is None:
+            dtype = torch.get_default_dtype()
+
+        if device is None:
+            device = "cpu"
+
+        if device == "executorch":
+            device = "cpu"
+
+        assert not bias, "Bias is not supported by LinearInt8"
+        self.in_features = in_features
+        self.out_features = out_features
+
+        assert (weight is None) == (
+            scales is None
+        ), "must specify both weights and scales, or neither"
+        if weight is None:
+            weight = torch.empty(
+                (out_features, in_features), dtype=torch.int8, device=device
+            )
+            if groupsize is None or (groupsize == 0):
+                scales = torch.empty(out_features, dtype=dtype, device=device)
+            else:
+                n_groups = (in_features + groupsize - 1) // groupsize
+                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
+
+        self.register_buffer("weight", weight.to(device))
+        self.register_buffer("scales", scales.to(device))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return linear_int8(input, self.weight, self.scales)
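
For context, here is a minimal sketch of how the new module could be populated from an existing float nn.Linear using symmetric per-channel (groupsize 0) quantization. The helper name quantize_per_channel and the 127-based scale computation are illustrative assumptions, not part of this commit; the commit itself only defines linear_int8 and LinearInt8 as shown above.

import torch
import torch.nn as nn

from qops import LinearInt8


def quantize_per_channel(linear: nn.Linear) -> LinearInt8:
    # Hypothetical helper: symmetric per-channel int8 quantization,
    # one scale per output channel (matches the groupsize=0 layout above).
    w = linear.weight.detach()
    scales = (w.abs().amax(dim=1) / 127.0).clamp(min=1e-6)  # avoid div-by-zero rows
    q = torch.round(w / scales[:, None]).clamp(-128, 127).to(torch.int8)
    return LinearInt8(
        linear.in_features,
        linear.out_features,
        weight=q,
        scales=scales.to(w.dtype),
    )


fp_layer = nn.Linear(64, 32, bias=False)
int8_layer = quantize_per_channel(fp_layer)
y = int8_layer(torch.randn(4, 64))  # dispatches to linear_int8 in forward()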

quantize.py

Lines changed: 5 additions & 69 deletions
@@ -24,6 +24,7 @@
     use_et_backend,
 )

+from qops import LinearInt8 as WeightOnlyInt8Linear

 #########################################################################
 ### torchchat quantization API ###
@@ -377,7 +378,10 @@ def replace_linear_weight_only_int8_per_channel(
                     module,
                     name,
                     WeightOnlyInt8Linear(
-                        device, child.in_features, child.out_features, groupsize
+                        in_features=child.in_features,
+                        out_features=child.out_features,
+                        device=device,
+                        groupsize=groupsize,
                     ),
                 )
         else:
@@ -386,35 +390,6 @@ def replace_linear_weight_only_int8_per_channel(
     )


-def linear_forward_int8(x, weight, scales):
-    n_groups = scales.numel() // scales.shape[0]
-    # need a formulation / custom op for good performance
-    # on eager, CUDA compiled, CPU compiled and ET exported
-
-    # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
-    if n_groups == 1:
-        if (
-            torch.compiler.is_compiling()
-            or x.device.type != "cpu"
-            or torch.__version__ < "2.4"
-        ):
-            return F.linear(x, weight.to(dtype=x.dtype)) * scales
-        # Use int8pack_mm for CPU eager
-        return torch.ops.aten._weight_int8pack_mm(
-            x.reshape(-1, x.shape[-1]),
-            weight,
-            scales,
-        ).reshape(x.shape[:-1] + (weight.shape[0],))
-
-    return F.linear(
-        x,
-        (
-            weight.to(dtype=x.dtype).view(weight.shape[0], n_groups, -1)
-            * scales.view(weight.shape[0], n_groups, -1)
-        ).view(weight.shape[0], -1),
-    )
-
-
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
@@ -499,45 +474,6 @@ def quantized_model(self) -> nn.Module:
         return self.model_


-class WeightOnlyInt8Linear(torch.nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-
-    def __init__(
-        self,
-        device,
-        in_features: int,
-        out_features: int,
-        groupsize: Optional[int] = None,
-        bias: bool = True,
-        dtype=None,
-    ) -> None:
-        super().__init__()
-        # print(f"group size: {groupsize}")
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.register_buffer(
-            "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
-        )
-        dtype = get_precision()
-        if groupsize is None or (groupsize == 0):
-            self.register_buffer(
-                "scales", torch.ones(out_features, dtype=dtype, device=device)
-            )
-        else:
-            groups = (in_features + groupsize - 1) // groupsize
-            self.register_buffer(
-                "scales", torch.ones(out_features, groups, dtype=dtype, device=device)
-            )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_forward_int8(input, self.weight, self.scales)
-
-
 #########################################################################
 ##### embedding table quantization ######

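After this change, quantize.py builds the int8 layers through the imported alias with keyword arguments only. Below is a condensed sketch of that swap step; the function name swap_linear_for_int8 and the unconditional recursion are simplifications of the replace_linear_weight_only_int8_per_channel call site shown in the diff above, not code from this commit.

import torch.nn as nn

from qops import LinearInt8 as WeightOnlyInt8Linear


def swap_linear_for_int8(module: nn.Module, device="cpu", groupsize=None):
    # Replace every float nn.Linear child with an empty int8 shell;
    # the quantization handler fills in the real weight/scales afterwards.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                WeightOnlyInt8Linear(
                    in_features=child.in_features,
                    out_features=child.out_features,
                    device=device,
                    groupsize=groupsize,
                ),
            )
        else:
            swap_linear_for_int8(child, device, groupsize)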