@@ -66,6 +66,12 @@
 lib.define(
     "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 lib.define(
     "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
@@ -171,6 +177,54 @@ def quantized_conv_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_conv.per_tensor")
+def quantized_conv_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    if channel_last:
+        out_channels, *kernel_size, _ = weight.shape
+    else:
+        out_channels, _, *kernel_size = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            channel_last,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
 @register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,
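Note on the new meta kernel: quantized_conv_per_tensor_meta does shape inference only. It never reads the quantization parameters; it derives the output channel count and kernel size from the weight layout, then delegates the spatial size computation to get_conv1d_output_size or get_conv2d_output_size. As a minimal sketch of what those helpers are assumed to compute (the standard convolution output-size formula; conv_out_dim below is a hypothetical stand-in for illustration, not part of this diff):

def conv_out_dim(in_dim: int, kernel: int, stride: int, padding: int, dilation: int) -> int:
    # floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    return (in_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Example: a (1, 8, 32, 32) NCHW input through a 3x3 kernel with stride 2,
# padding 1, and dilation 1 has spatial output size 16x16, so with 16 output
# channels the meta kernel would allocate a (1, 16, 16, 16) empty tensor.
assert conv_out_dim(32, 3, 2, 1, 1) == 16

The per_tensor overloads otherwise mirror the existing quantized_conv schemas, with one deliberate change: weight_zero_point, bias_scale, out_multiplier, and out_shift become scalars (int/float) instead of Tensors, since per-tensor quantization needs only a single value for each of these parameters.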