 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Tuple
+
 import torch
 from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library

+from .utils import get_conv1d_output_size
+
 lib = Library("xtensa", "DEF")

 lib.define(
     ...
 )

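+# quantized_linear replaces quantized_linear_pt2: the float src/weight scales are
+# dropped, weight_zero_point becomes a Tensor, and requantization is driven by
+# the fixed-point out_multiplier/out_shift pair.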
 lib.define(
-    "quantized_linear_pt2(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
+    "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
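+# quantized_conv takes the usual convolution arguments (input, weight, bias,
+# stride, padding, dilation, groups) followed by the quantization parameters;
+# channel_last=False defaults to a channels-first (NCL/NCHW) layout.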
+lib.define(
+    "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )

 m = Library("xtensa", "IMPL", "Meta")
@@ -58,17 +69,15 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)


-@impl(m, "quantized_linear_pt2")
-def quantized_linear_pt2_meta(
+@impl(m, "quantized_linear")
+def quantized_linear_meta(
     src: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
-    in_scale: float,
     in_zero_point: int,
-    weight_scale: float,
-    weight_zero_point: int,
-    out_multiplier: int,
-    out_shift: int,
+    weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
     out_zero_point: int,
 ):
     # src comes in shape [leading_dims, in_dim]
@@ -79,3 +88,35 @@ def quantized_linear_pt2_meta(
     assert len(weight_size) == 2
     out_size[-1] = weight_size[0]
     return src.new_empty(out_size, dtype=torch.uint8)
+
+
+@impl(m, "quantized_conv")
+def quantized_conv_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int, ...],
+    padding: Tuple[int, ...],
+    dilation: Tuple[int, ...],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    bias_scale: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    channel_last: bool = False,
+):
+    out_channels, _in_channels, *kernel_size = weight.shape
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output size for the 1D case (uses dim 0 of stride/padding/dilation/kernel)
+    output_size = get_conv1d_output_size(
+        in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
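
`get_conv1d_output_size` is imported from the package's `.utils` module, but its body is not part of this diff. As a point of reference, here is a minimal sketch of the standard 1D convolution output-size computation it presumably performs (the function name is illustrative, and a 3-D NCL input is assumed):

```python
import torch

def conv1d_output_size_sketch(
    in_size: torch.Size,
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
) -> torch.Size:
    # Standard formula: L_out = floor((L_in + 2p - d*(k - 1) - 1) / s) + 1
    N, _C, L = in_size  # assumes a 3-D (batch, channels, length) input
    L_out = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return torch.Size((N, out_channels, L_out))
```

With a formula like this, the meta kernel can report output shapes without running the actual Xtensa kernel.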
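Because the schemas are registered on the `xtensa` library and the meta kernels on its `Meta` dispatch key, the ops resolve as `torch.ops.xtensa.*` and can be shape-checked on meta tensors. A small usage sketch (all values are placeholders, and it assumes the module above has been imported so the registrations have run):

```python
import torch

# Placeholder quantization parameters on the meta device; only shapes matter here.
src = torch.empty(2, 3, 16, dtype=torch.uint8, device="meta")
weight = torch.empty(8, 16, dtype=torch.uint8, device="meta")
bias = torch.empty(8, dtype=torch.int32, device="meta")
weight_zero_point = torch.empty(1, dtype=torch.int32, device="meta")
out_multiplier = torch.empty(1, dtype=torch.int32, device="meta")
out_shift = torch.empty(1, dtype=torch.int32, device="meta")

out = torch.ops.xtensa.quantized_linear(
    src, weight, bias, 0, weight_zero_point, out_multiplier, out_shift, 0
)
print(out.shape)  # expected: torch.Size([2, 3, 8]), last dim becomes weight.shape[0]
```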