@@ -15,22 +15,22 @@
     "llama_quantized", "DEF"
 )  # to not be confused with torch.ops.quantized.* ops.
 quantized_lib.define(
-    "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "DEPRECATED_DO_NOT_USE_embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
 )
 
 quantized_lib.define(
-    "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "DEPRECATED_DO_NOT_USE_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
 )
 
 quantized_lib.define(
-    "embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "DEPRECATED_DO_NOT_USE_embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
 )
 
 quantized_lib.define(
-    "embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "DEPRECATED_DO_NOT_USE_embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
 )
 
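For context, a minimal sketch of how such a schema is defined and given an eager implementation with torch.library. The library name, function body, and shapes below are illustrative assumptions, not the upstream kernel; the real library is "llama_quantized".

    import torch
    from torch.library import Library, impl

    # Hypothetical library namespace, used only for this sketch.
    sketch_lib = Library("sketch_quantized", "DEF")
    sketch_lib.define(
        "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
        "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor"
    )

    @impl(sketch_lib, "embedding_byte", "CompositeExplicitAutograd")
    def sketch_embedding_byte(
        weight, weight_scales, weight_zero_points, weight_quant_min, weight_quant_max, indices
    ):
        # Plausible reference semantics: dequantize per row, then gather rows.
        # (quant_min/quant_max would be range checks; omitted in this sketch.)
        w = weight.to(torch.float32)
        if weight_zero_points is not None:
            w = w - weight_zero_points.to(torch.float32).unsqueeze(-1)
        w = w * weight_scales.unsqueeze(-1)
        return torch.ops.aten.embedding.default(w, indices)

    # Once defined, the op resolves under torch.ops.<namespace>.<name>.
    weight = torch.randint(-128, 127, (10, 4), dtype=torch.int8)
    rows = torch.ops.sketch_quantized.embedding_byte(
        weight, torch.ones(10), None, -128, 127, torch.tensor([0, 3])
    )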
@@ -66,7 +66,9 @@ def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
     ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
 
 
-@impl(quantized_lib, "embedding_byte", "CompositeExplicitAutograd")
+@impl(
+    quantized_lib, "DEPRECATED_DO_NOT_USE_embedding_byte", "CompositeExplicitAutograd"
+)
 def embedding_byte(
     weight: torch.Tensor,
     weight_scales: torch.Tensor,
@@ -92,7 +94,7 @@ def embedding_byte(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
-@impl_abstract("llama_quantized::embedding_byte.out")
+@impl_abstract("llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.out")
 def embedding_byte_out_meta(
     weight: torch.Tensor,
     weight_scales: torch.Tensor,
@@ -112,7 +114,11 @@ def embedding_byte_out_meta(
     )
 
 
-@impl(quantized_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
+@impl(
+    quantized_lib,
+    "DEPRECATED_DO_NOT_USE_embedding_byte.dtype",
+    "CompositeExplicitAutograd",
+)
 def embedding_byte_dtype(
     weight: torch.Tensor,
     weight_scales: torch.Tensor,
@@ -140,7 +146,7 @@ def embedding_byte_dtype(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
-@impl_abstract("llama_quantized::embedding_byte.dtype_out")
+@impl_abstract("llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.dtype_out")
 def embedding_byte_dtype_out_meta(
     weight: torch.Tensor,
     weight_scales: torch.Tensor,
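After this rename, call sites resolve the ops under the new names. A hypothetical call against the .dtype overload defined above; the argument values are made up, while the op name and schema come from the definitions in this diff:

    # Illustrative only: weight, weight_scales, and indices are placeholders.
    out = torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
        weight, weight_scales, None, -128, 127, indices, dtype=torch.float16
    )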