
Commit ff864d6

mikekgfb authored and malfet committed
add quantized ops (#119)
* add quantized ops
* updates
1 parent 17a9c86 commit ff864d6

File tree

1 file changed: +131 -0 lines changed

quantized_ops.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import torch.nn.functional as F  # used by embedding_int8 and linear_int8 below
from torch.library import impl, impl_abstract

torchat_lib = torch.library.Library("torchat", "DEF")

torchat_lib.define(
    "embedding_int8(Tensor input, Tensor weight, "
    "Tensor scales) -> Tensor",
)

@impl(torchat_lib, "embedding_int8", "CompositeExplicitAutograd")
def embedding_int8(
    input: torch.Tensor,
    weight: torch.Tensor,
    scales: torch.Tensor,
) -> torch.Tensor:
    indices = input
    # embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
    group_size = weight.size(1) // (
        scales.size(1) if scales.dim() == 2 else 1
    )
    # ET definition (disabled): dequantize the whole table group-wise, then
    # do a plain embedding lookup.  weight_zero_points, weight_quant_min, and
    # weight_quant_max would need to be supplied to enable this path.
    if False:
        weight_zero_points = None
        weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
            weight,
            scales,
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
            weight.dtype,
            group_size,
            scales.dtype,
        )
        return torch.ops.aten.embedding.default(weight, indices)

    # look up the quantized rows and their per-group scales ...
    scales = scales.view(weight.shape[0], -1)
    result_weights = F.embedding(indices, weight)
    result_scales = F.embedding(indices, scales)

    # ... then dequantize: view each row as (num_groups, group_size) and
    # multiply by the matching per-group scale
    rw_view = result_weights.to(dtype=result_scales.dtype).view(
        tuple(result_weights.shape[:-1]) + (scales.shape[1], -1)
    )
    rs_view = result_scales.view(
        tuple(result_scales.shape[:-1]) + (scales.shape[1], 1)
    )

    r = rw_view * rs_view
    return r.view(indices.size() + (-1,))
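
For reference, a call through the dispatcher might look like the following minimal sketch; the shapes, values, and group size are made up for illustration and are not part of the commit.

# hypothetical smoke test for torchat.embedding_int8
vocab, dim, group_size = 8, 16, 8
weight = torch.randint(-128, 127, (vocab, dim), dtype=torch.int8)
scales = torch.rand(vocab, dim // group_size)  # one scale per group
tokens = torch.tensor([[1, 3], [5, 7]])
out = torch.ops.torchat.embedding_int8(tokens, weight, scales)
print(out.shape)  # torch.Size([2, 2, 16])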

torchat_lib.define(
    "linear_int8(Tensor input, Tensor weight, Tensor scales, "
    "Tensor? bias = None) -> Tensor",
)

@impl(torchat_lib, "linear_int8", "CompositeExplicitAutograd")
def linear_int8(
    input: torch.Tensor,
    weight: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert bias is None, "bias != None not implemented"

    scales = scales.view(scales.shape[0], -1)
    no_groups = scales.shape[1]  # number of quantization groups per channel

    # for now, we special-case channel-wise, because we know how to
    # make that fast with Triton
    if scales.shape[1] == 1:
        return F.linear(input, weight.to(dtype=input.dtype)) * scales.view(-1)
    else:
        return F.linear(
            input,
            (
                weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1)
                * scales.view(weight.shape[0], no_groups, -1)
            ).view(weight.shape[0], -1),
        )
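
As above, a hypothetical group-wise invocation (shapes and values are illustrative only, not part of the commit):

# hypothetical smoke test for torchat.linear_int8
out_f, in_f, groups = 4, 8, 2
x = torch.randn(3, in_f)
w = torch.randint(-128, 127, (out_f, in_f), dtype=torch.int8)
s = torch.rand(out_f, groups)  # one scale per (channel, group) pair
y = torch.ops.torchat.linear_int8(x, w, s)
print(y.shape)  # torch.Size([3, 4])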

torchat_lib.define(
    "linear_int4(Tensor input, Tensor weight, Tensor scales_and_zeros, "
    "Tensor? bias = None, *, int groupsize, int origin_in_features, "
    "int in_features, int out_features, bool padding = True) -> Tensor",
)

@impl(torchat_lib, "linear_int4", "CompositeExplicitAutograd")
def linear_int4(
    input: torch.Tensor,
    weight: torch.Tensor,
    scales_and_zeros: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    *,
    groupsize: int,
    origin_in_features: int,
    in_features: int,
    out_features: int,
    padding: bool = True,
) -> torch.Tensor:
    assert bias is None, "bias != None not implemented"

    if padding:
        # pad activations out to the in_features the weight was packed for
        input = F.pad(input, pad=(0, in_features - origin_in_features))

    # the weight is in int4pack format
    # rename to remind ourselves of that
    weight_int4pack = weight

    # flatten all leading dimensions, run the int4 matmul in bfloat16,
    # then restore the batch shape with out_features as the last dimension
    origin_input_size = input.size()
    input = input.reshape(-1, origin_input_size[-1])
    c = torch.ops.aten._weight_int4pack_mm(
        input.to(dtype=torch.bfloat16),
        weight_int4pack,
        groupsize,
        scales_and_zeros.to(dtype=torch.bfloat16),
    ).to(dtype=input.dtype)
    new_shape = origin_input_size[:-1] + (out_features,)
    c = c.reshape(new_shape)
    return c
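
Because _weight_int4pack_mm expects a weight already packed by torch.ops.aten._convert_weight_to_int4pack (whose availability and layout vary by PyTorch build and device), a runnable one-liner is hard to give; the sketch below only shows the assumed calling convention, and all names in it are hypothetical.

# hypothetical wrapper showing the calling convention; `packed` must come
# from torch.ops.aten._convert_weight_to_int4pack and `scales_and_zeros`
# from the matching group-wise quantization flow
def call_linear_int4(x, packed, scales_and_zeros, groupsize,
                     origin_in_features, in_features, out_features):
    return torch.ops.torchat.linear_int4(
        x, packed, scales_and_zeros,
        groupsize=groupsize,
        origin_in_features=origin_in_features,
        in_features=in_features,
        out_features=out_features,
        padding=in_features > origin_in_features,
    )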

0 commit comments