|
50 | 50 | "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
|
51 | 51 | )
|
52 | 52 | lib.define(
|
53 |
| - "cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" |
| 53 | + "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" |
| 54 | +) |
| 55 | +lib.define( |
| 56 | + "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, " |
| 57 | + "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor" |
54 | 58 | )
|
55 | 59 |
|
56 | 60 | lib.define(
|
@@ -129,6 +133,28 @@ def quantized_linear_meta(
|
129 | 133 | return src.new_empty(out_size, dtype=src.dtype)
|
130 | 134 |
|
131 | 135 |
|
@register_fake("cadence::quantized_linear.per_tensor")
def quantized_linear_per_tensor_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: torch.SymInt,
    weight_zero_point: torch.SymInt,
    out_multiplier: torch.SymInt,
    out_shift: torch.SymInt,
    out_zero_point: torch.SymInt,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    """Shape-inference (fake) kernel for quantized_linear.per_tensor.

    src is [*leading_dims, in_dim] and weight is [out_dim, in_dim]; the
    result is an empty tensor of shape [*leading_dims, out_dim] with
    src's dtype.  Only shapes matter here — the quantization scalars and
    optional offset do not affect the output metadata.
    """
    # The weight must be a rank-2 matrix [out_dim, in_dim].
    assert len(list(weight.size())) == 2
    # Keep src's leading dims; swap the last dim for weight's out_dim.
    result_shape = list(src.size())
    result_shape[-1] = list(weight.size())[0]
    return src.new_empty(result_shape, dtype=src.dtype)
| 156 | + |
| 157 | + |
132 | 158 | @register_fake("cadence::quantized_conv")
|
133 | 159 | def quantized_conv_meta(
|
134 | 160 | input: torch.Tensor,
|
|
0 commit comments