
Commit 38cfb8d

Update on "Add quantized op support to llama runner"
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D56197863](https://our.internmc.facebook.com/intern/diff/D56197863)
[ghstack-poisoned]
2 parents: 020dc4e + ac60837

3 files changed (+17, -11 lines)

examples/models/llama2/ops/quantized.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,10 +1,10 @@
-- func: llama_quantized::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
+- func: llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::quantized_embedding_byte_out

-- func: llama_quantized::embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+- func: llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null

examples/models/llama2/ops/quantized_ops.py

Lines changed: 14 additions & 8 deletions
@@ -15,22 +15,22 @@
     "llama_quantized", "DEF"
 )  # to not be confused with torch.ops.quantized.* ops.
 quantized_lib.define(
-    "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "DEPRECATED_DO_NOT_USE_embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
 )

 quantized_lib.define(
-    "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "DEPRECATED_DO_NOT_USE_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
 )

 quantized_lib.define(
-    "embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "DEPRECATED_DO_NOT_USE_embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
 )

 quantized_lib.define(
-    "embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "DEPRECATED_DO_NOT_USE_embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
 )
@@ -66,7 +66,9 @@ def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
     ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"


-@impl(quantized_lib, "embedding_byte", "CompositeExplicitAutograd")
+@impl(
+    quantized_lib, "DEPRECATED_DO_NOT_USE_embedding_byte", "CompositeExplicitAutograd"
+)
 def embedding_byte(
     weight: torch.Tensor,
     weight_scales: torch.Tensor,
@@ -92,7 +94,7 @@ def embedding_byte(
     return torch.ops.aten.embedding.default(weight, indices)


-@impl_abstract("llama_quantized::embedding_byte.out")
+@impl_abstract("llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.out")
 def embedding_byte_out_meta(
     weight: torch.Tensor,
     weight_scales: torch.Tensor,
@@ -112,7 +114,11 @@ def embedding_byte_out_meta(
     )


-@impl(quantized_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
+@impl(
+    quantized_lib,
+    "DEPRECATED_DO_NOT_USE_embedding_byte.dtype",
+    "CompositeExplicitAutograd",
+)
 def embedding_byte_dtype(
     weight: torch.Tensor,
     weight_scales: torch.Tensor,
@@ -140,7 +146,7 @@ def embedding_byte_dtype(
     return torch.ops.aten.embedding.default(weight, indices)


-@impl_abstract("llama_quantized::embedding_byte.dtype_out")
+@impl_abstract("llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.dtype_out")
 def embedding_byte_dtype_out_meta(
     weight: torch.Tensor,
     weight_scales: torch.Tensor,
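
Taken together, these hunks preserve the existing split: the functional and .dtype overloads keep eager CompositeExplicitAutograd implementations (both of which, per the context lines, end in torch.ops.aten.embedding.default(weight, indices)), while the .out and .dtype_out overloads receive only abstract meta implementations here, with their concrete kernels bound by the quantized.yaml entries above. As a minimal sketch of calling the renamed functional overload in eager mode; the import path and the tensor shapes/dtypes are illustrative assumptions, not taken from this diff:

import torch

# Assumption: run from the repository root so this package path resolves; importing
# the module executes the quantized_lib.define(...) and @impl registrations above.
from examples.models.llama2.ops import quantized_ops  # noqa: F401

# Illustrative data only: an 8-bit weight table, per-row scales, and lookup indices.
weight = torch.randint(-128, 127, (16, 8), dtype=torch.int8)
weight_scales = torch.rand(16)
indices = torch.tensor([0, 3, 7])

# Functional overload under its new name; zero points omitted and quant_min/quant_max
# passed as 0/0, mirroring the argument pattern quantize.py uses below.
embedded = torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte(
    weight, weight_scales, None, 0, 0, indices
)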

examples/models/llama2/quantize.py

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ def __init__(
 
     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        return torch.ops.llama_quantized.embedding_byte.dtype(
+        return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
             self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
         )
 
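
This forward is the only call site the commit touches; it reaches the op through the torch.ops namespace, so only the qualified name changes. A standalone sketch of the same .dtype call pattern, with self.weight, self.scales, and self.dtype replaced by illustrative stand-ins (same import assumption as the previous sketch):

import torch
from examples.models.llama2.ops import quantized_ops  # noqa: F401  # assumed import path

weight = torch.randint(-128, 127, (32, 4), dtype=torch.int8)  # stand-in for self.weight
scales = torch.rand(32)                                        # stand-in for self.scales
indices = torch.tensor([1, 5, 9])

out = torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
    weight, scales, None, 0, 0, indices, dtype=torch.float32  # dtype stands in for self.dtype
)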
