4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ from typing import Tuple
8
+
7
9
import torch
10
+ from .utils import get_conv1d_output_size
8
11
from executorch .exir .scalar_type import ScalarType
9
12
from torch .library import impl , Library
10
13
25
28
)
26
29
27
30
# Operator schema registrations for the "xtensa" custom-op library.
# Functional and out-variant schemas are registered in pairs; the Meta
# (shape inference) implementations are attached further below via @impl.
lib.define(
    "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
)
lib.define(
    "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
33
43
34
44
m = Library ("xtensa" , "IMPL" , "Meta" )
@@ -58,17 +68,15 @@ def dequantize_per_tensor_meta(
58
68
return input .new_empty (input .size (), dtype = torch .float )
59
69
60
70
61
- @impl (m , "quantized_linear_pt2 " )
62
- def quantized_linear_pt2_meta (
71
+ @impl (m , "quantized_linear " )
72
+ def quantized_linear_meta (
63
73
src : torch .Tensor ,
64
74
weight : torch .Tensor ,
65
75
bias : torch .Tensor ,
66
- in_scale : float ,
67
76
in_zero_point : int ,
68
- weight_scale : float ,
69
- weight_zero_point : int ,
70
- out_multiplier : int ,
71
- out_shift : int ,
77
+ weight_zero_point : torch .Tensor ,
78
+ out_multiplier : torch .Tensor ,
79
+ out_shift : torch .Tensor ,
72
80
out_zero_point : int ,
73
81
):
74
82
# src comes in shape [leading_dims, in_dim]
@@ -79,3 +87,35 @@ def quantized_linear_pt2_meta(
79
87
assert len (weight_size ) == 2
80
88
out_size [- 1 ] = weight_size [0 ]
81
89
return src .new_empty (out_size , dtype = torch .uint8 )
90
+
91
@impl(m, "quantized_conv")
def quantized_conv_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    dilation: Tuple[int, ...],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """Meta (shape inference) kernel for xtensa::quantized_conv.

    Computes only the output tensor's size and allocates an uninitialized
    tensor of that size; no convolution values are computed. The
    quantization parameters (zero points, scales, multiplier/shift) do not
    affect the output shape and are therefore unused here.
    """
    # Weight layout presumably [out_channels, in_channels/groups, *kernel_size]
    # (it is destructured that way below) — TODO confirm against the
    # reference kernel.
    out_channels, _in_channels, *kernel_size = weight.shape
    in_size = input.shape
    # Input rank must be more than 2 and fewer than 6.
    # NOTE(review): an earlier comment said "at most 6", but the assert
    # enforces len(in_size) < 6, i.e. at most 5 — code kept as-is.
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Only the first spatial component of stride/padding/dilation and the
    # first kernel dimension are used, so the size computation is 1-d.
    # How get_conv1d_output_size treats higher-rank inputs is defined in
    # .utils — verify there if extending beyond conv1d.
    output_size = get_conv1d_output_size(
        in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
    )

    return input.new_empty(output_size, dtype=input.dtype)
0 commit comments