Skip to content

Commit 6d6630e

Browse files
authored
register quantized_linear.per_tensor in lib
Differential Revision: D65104400 Pull Request resolved: #6563
1 parent 1cd8a06 commit 6d6630e

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@
5050
"quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
5151
)
5252
# Register the out-variant of the per-tensor quantized linear op: all
# quantization parameters are scalar SymInts (one value per tensor) rather
# than Tensors, and the result is written into the caller-provided `out`.
lib.define(
    "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
55+
# Register the functional (non-out) per-tensor quantized linear overload.
# Same schema as per_tensor_out minus the `out` argument; it returns a new
# Tensor instead of writing in place.
lib.define(
    "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, "
    "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor"
)
5559

5660
lib.define(
@@ -129,6 +133,28 @@ def quantized_linear_meta(
129133
return src.new_empty(out_size, dtype=src.dtype)
130134

131135

136+
@register_fake("cadence::quantized_linear.per_tensor")
def quantized_linear_per_tensor_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: torch.SymInt,
    weight_zero_point: torch.SymInt,
    out_multiplier: torch.SymInt,
    out_shift: torch.SymInt,
    out_zero_point: torch.SymInt,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fake (meta) kernel for cadence::quantized_linear.per_tensor.

    Propagates shape and dtype only: given src of shape
    [*leading_dims, in_dim] and weight of shape [out_dim, in_dim],
    returns an uninitialized tensor of shape [*leading_dims, out_dim]
    with src's dtype. The quantization scalars and bias do not affect
    the output shape.
    """
    # weight must be a 2-D [out_dim, in_dim] matrix.
    assert weight.dim() == 2
    # Keep src's leading dims, swap the innermost (in_dim) for out_dim.
    out_shape = list(src.size())
    out_shape[-1] = weight.size(0)
    return src.new_empty(out_shape, dtype=src.dtype)
156+
157+
132158
@register_fake("cadence::quantized_conv")
133159
def quantized_conv_meta(
134160
input: torch.Tensor,

0 commit comments

Comments
 (0)