
move int8 linear class and function into qops.py #534


Merged · 1 commit merged on Apr 29, 2024
86 changes: 86 additions & 0 deletions qops.py
@@ -0,0 +1,86 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


def linear_int8(input, weight, scales):
    n_groups = scales.numel() // scales.shape[0]

    # we special-case channel-wise, because we know how to make that fast
    if n_groups == 1:
        if (
            torch.compiler.is_compiling()
            or input.device.type != "cpu"
            or torch.__version__ < "2.4"
        ):
            return F.linear(input, weight.to(dtype=input.dtype)) * scales
        # Use int8pack_mm for CPU eager
        return torch.ops.aten._weight_int8pack_mm(
            input.reshape(-1, input.shape[-1]),
            weight,
            scales,
        ).reshape(input.shape[:-1] + (weight.shape[0],))

    return F.linear(
        input,
        (
            weight.to(dtype=input.dtype).view(weight.shape[0], n_groups, -1)
            * scales.view(weight.shape[0], n_groups, -1)
        ).view(weight.shape[0], -1),
    )


class LinearInt8(nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor
    scales: torch.Tensor

    def __init__(
        self,
        in_features,
        out_features,
        bias=None,
        device=None,
        dtype=None,
        *,
        weight: Optional[torch.Tensor] = None,
        scales: Optional[torch.Tensor] = None,
        groupsize: Optional[int] = None,
    ):
        super().__init__()
        if dtype is None:
            dtype = torch.get_default_dtype()

        if device is None:
            device = "cpu"

        if device == "executorch":
            device = "cpu"

        assert not bias, "Bias is not supported by LinearInt8"
        self.in_features = in_features
        self.out_features = out_features

        assert (weight is None) == (
            scales is None
        ), "must specify both weights and scales, or neither"
        if weight is None:
            weight = torch.empty(
                (out_features, in_features), dtype=torch.int8, device=device
            )
            if groupsize is None or (groupsize == 0):
                scales = torch.empty(out_features, dtype=dtype, device=device)
            else:
                n_groups = (in_features + groupsize - 1) // groupsize
                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)

        self.register_buffer("weight", weight.to(device))
        self.register_buffer("scales", scales.to(device))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return linear_int8(input, self.weight, self.scales)
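
For context, a minimal usage sketch of the relocated module (illustrative only, not part of this diff). The shapes and the per-channel int8 quantization of the example weight are assumptions made for the sketch; in this repo the int8 weights and scales are produced elsewhere, e.g. by WeightOnlyInt8QuantHandler in quantize.py.

# Illustrative sketch, assuming qops.py is importable; not part of the PR.
import torch
import torch.nn.functional as F

from qops import LinearInt8

out_features, in_features = 8, 16
float_weight = torch.randn(out_features, in_features)

# One scale per output channel (symmetric int8), purely for illustration.
scales = float_weight.abs().amax(dim=1) / 127.0
int8_weight = (
    torch.round(float_weight / scales[:, None]).clamp(-128, 127).to(torch.int8)
)

layer = LinearInt8(in_features, out_features, weight=int8_weight, scales=scales)

x = torch.randn(2, in_features)
y = layer(x)  # calls linear_int8(x, int8_weight, scales) under the hood

# Reference: dequantize the weight, then run a regular float linear.
y_ref = F.linear(x, int8_weight.to(x.dtype) * scales[:, None])
print(torch.allclose(y, y_ref, atol=1e-4))

With group-wise scales of shape (out_features, n_groups), the same call takes the second branch of linear_int8, which dequantizes the weight group by group via the broadcasted view before calling F.linear.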
74 changes: 5 additions & 69 deletions quantize.py
@@ -24,6 +24,7 @@
use_et_backend,
)

from qops import LinearInt8 as WeightOnlyInt8Linear

#########################################################################
### torchchat quantization API ###
@@ -377,7 +378,10 @@ def replace_linear_weight_only_int8_per_channel(
module,
name,
WeightOnlyInt8Linear(
device, child.in_features, child.out_features, groupsize
in_features=child.in_features,
out_features=child.out_features,
device=device,
groupsize=groupsize,
),
)
else:
@@ -386,35 +390,6 @@
)


def linear_forward_int8(x, weight, scales):
    n_groups = scales.numel() // scales.shape[0]
    # need a formulation / custom op for good performance
    # on eager, CUDA compiled, CPU compiled and ET exported

    # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
    if n_groups == 1:
        if (
            torch.compiler.is_compiling()
            or x.device.type != "cpu"
            or torch.__version__ < "2.4"
        ):
            return F.linear(x, weight.to(dtype=x.dtype)) * scales
        # Use int8pack_mm for CPU eager
        return torch.ops.aten._weight_int8pack_mm(
            x.reshape(-1, x.shape[-1]),
            weight,
            scales,
        ).reshape(x.shape[:-1] + (weight.shape[0],))

    return F.linear(
        x,
        (
            weight.to(dtype=x.dtype).view(weight.shape[0], n_groups, -1)
            * scales.view(weight.shape[0], n_groups, -1)
        ).view(weight.shape[0], -1),
    )


class WeightOnlyInt8QuantHandler(QuantHandler):
    def __init__(
        self,
@@ -499,45 +474,6 @@ def quantized_model(self) -> nn.Module:
        return self.model_


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        device,
        in_features: int,
        out_features: int,
        groupsize: Optional[int] = None,
        bias: bool = True,
        dtype=None,
    ) -> None:
        super().__init__()
        # print(f"group size: {groupsize}")

        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight",
            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
        )
        dtype = get_precision()
        if groupsize is None or (groupsize == 0):
            self.register_buffer(
                "scales", torch.ones(out_features, dtype=dtype, device=device)
            )
        else:
            groups = (in_features + groupsize - 1) // groupsize
            self.register_buffer(
                "scales", torch.ones(out_features, groups, dtype=dtype, device=device)
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return linear_forward_int8(input, self.weight, self.scales)


#########################################################################
##### embedding table quantization ######
