
Commit f9ffc55

manuelcandales authored and facebook-github-bot committed
Define embedding_4bit (#3121)
Summary: Pull Request resolved: #3121

Differential Revision: D56246346
1 parent e69a662 commit f9ffc55

1 file changed (+131, -3)

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 131 additions & 3 deletions
@@ -46,7 +46,7 @@
 )
 
 
-def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
+def embedding_weight_checks(weight, weight_scales, weight_zero_points):
     assert weight.dtype in [
         torch.int8,
         torch.uint8,
@@ -86,7 +86,7 @@ def embedding_byte(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
-    embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = weight.size(1) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
     )
@@ -133,7 +133,7 @@ def embedding_byte_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
-    embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = weight.size(1) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
     )
@@ -172,6 +172,134 @@ def embedding_byte_dtype_out_meta(
 )
 
 
+quantized_decomposed_lib.define(
+    "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+
+@impl(quantized_decomposed_lib, "embedding_4bit", "CompositeExplicitAutograd")
+def embedding_4bit(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = (2 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight_even = weight.div(16, rounding_mode="trunc")
+    weight_odd = weight.remainder(16)
+    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    weight = weight.view(torch.int8).add(-8)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        weight_scales.dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_4bit.out")
+def embedding_4bit_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_4bit(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+    )
+
+
+@impl(quantized_decomposed_lib, "embedding_4bit.dtype", "CompositeExplicitAutograd")
+def embedding_4bit_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = (2 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight_even = weight.div(16, rounding_mode="trunc")
+    weight_odd = weight.remainder(16)
+    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    weight = weight.view(torch.int8).add(-8)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_4bit.dtype_out")
+def embedding_4bit_dtype_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_4bit_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        dtype,
+    )
+
+
 quantized_decomposed_lib.define(
     "mixed_mm(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points) -> Tensor",
 )
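Note on the unpacking arithmetic above: each stored byte carries two 4-bit values, with the even-indexed logical element in the high nibble and the odd-indexed one in the low nibble, each offset by 8 so that view(torch.int8).add(-8) recovers the signed range [-8, 7]. This is also why group_size is computed from 2 * weight.size(1): every packed column holds two logical columns. A minimal round-trip sketch of that convention follows; pack_4bit is a hypothetical helper written for illustration, not part of this diff.

import torch

# Hypothetical packer matching embedding_4bit's unpack logic: signed
# 4-bit values q in [-8, 7] are shifted to [0, 15]; the even-indexed
# element goes in the high nibble, the odd-indexed one in the low nibble.
def pack_4bit(q: torch.Tensor) -> torch.Tensor:
    assert q.dtype == torch.int8 and q.size(-1) % 2 == 0
    u = (q + 8).to(torch.uint8)
    return u[..., 0::2] * 16 + u[..., 1::2]

q = torch.randint(-8, 8, (4, 8), dtype=torch.int8)  # 4 rows, 8 logical cols
packed = pack_4bit(q)                                # shape (4, 4), uint8

# Same arithmetic as the unpack path in embedding_4bit:
weight_even = packed.div(16, rounding_mode="trunc")
weight_odd = packed.remainder(16)
unpacked = torch.stack((weight_even, weight_odd), dim=-1).view(packed.shape[0], -1)
unpacked = unpacked.view(torch.int8).add(-8)

assert torch.equal(unpacked, q)  # round-trip recovers the original values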

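For reference, a hedged usage sketch: once this module is imported so the quantized_decomposed_lib registrations run, the new op is callable through torch.ops. The shapes below are illustrative, and the call assumes an ExecuTorch build where torch.ops.quantized_decomposed.dequantize_per_channel_group is available.

import torch

# Assumption: executorch is installed; the import executes the library
# registrations defined in the diff above.
from executorch.exir.passes import _quant_patterns_and_replacements  # noqa: F401

packed = torch.randint(0, 256, (10, 32), dtype=torch.uint8)  # 10 rows, 64 logical cols
scales = torch.rand(10, 4, dtype=torch.float32)              # group_size = (2 * 32) // 4 = 16
indices = torch.tensor([0, 3, 7], dtype=torch.int64)

# quant_min/quant_max describe the signed 4-bit range after add(-8).
out = torch.ops.quantized_decomposed.embedding_4bit(
    packed, scales, None, -8, 7, indices
)
print(out.shape)  # torch.Size([3, 64])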