Skip to content

Commit 6051b2f

Browse files
authored
Add per_tensor overload for quantized_conv
Differential Revision: D65306801 Pull Request resolved: #6648
1 parent 785ebf3 commit 6051b2f

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@
6666
# Out-variant of quantized_conv: per-channel quantization params (weight_zero_point,
# bias_scale, out_multiplier, out_shift are Tensors); writes into a caller-provided `out`.
lib.define(
    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
69+
# Per-tensor overload of quantized_conv: the quantization parameters
# (weight_zero_point, bias_scale, out_multiplier, out_shift) are scalars
# instead of Tensors; returns a fresh output tensor Z.
lib.define(
    "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)"
)
72+
# Out-variant of the per-tensor overload: same scalar quantization parameters,
# but writes the result into a caller-provided `out` tensor.
lib.define(
    "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
6975

7076
lib.define(
7177
"quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
@@ -171,6 +177,54 @@ def quantized_conv_meta(
171177
return input.new_empty(output_size, dtype=input.dtype)
172178

173179

180+
@register_fake("cadence::quantized_conv.per_tensor")
def quantized_conv_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
    channel_last: bool = False,
) -> torch.Tensor:
    """Fake (meta) kernel for the per-tensor quantized_conv overload.

    Computes only the output *shape* — no convolution is performed; the
    quantization parameters are accepted to match the op schema but do not
    affect the result. Returns an uninitialized tensor with the input's dtype.
    """
    # The weight layout determines where out-channels and the spatial kernel
    # dims live: NHWC puts channels last, NCHW puts them right after N.
    if channel_last:
        out_channels, *kernel_size, _ = weight.shape
    else:
        out_channels, _, *kernel_size = weight.shape

    input_shape = input.shape
    # Only 1D (rank-3) through rank-5 inputs are supported.
    assert 2 < len(input_shape) < 6

    # Rank 3 means a 1D convolution; index [1] picks the single spatial
    # stride/padding/dilation out of the 2-element schema arguments.
    if len(input_shape) == 3:
        out_shape = get_conv1d_output_size(
            input_shape,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            channel_last,
        )
    else:
        out_shape = get_conv2d_output_size(
            input_shape, out_channels, stride, padding, dilation, kernel_size, channel_last
        )

    return input.new_empty(out_shape, dtype=input.dtype)
226+
227+
174228
@register_fake("cadence::quantized_layer_norm")
175229
def quantized_layer_norm_meta(
176230
input: torch.Tensor,

0 commit comments

Comments
 (0)