Skip to content

Commit 44d4bac

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Fix quantized_linear cpp op schema
Summary: The cpp op schema does not match the registered one. Fix that. Reviewed By: tarun292, cccclai Differential Revision: D56594373 fbshipit-source-id: cb4853030715245e7a0177c0f193c4558f19584d
1 parent 7b3f5c6 commit 44d4bac

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

examples/cadence/ops/functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
- arg_meta: null
6161
kernel_name: impl::HiFi::quantized_layer_norm_out
6262

63-
- func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
63+
- func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
6464
kernels:
6565
- arg_meta: null
6666
kernel_name: impl::HiFi::quantized_linear_out

examples/cadence/ops/quantized_linear_out.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,12 @@ void quantized_linear_out(
2424
const Tensor& src,
2525
const Tensor& weight,
2626
const Tensor& bias,
27-
double src_scale,
2827
int64_t src_zero_point,
29-
double weight_scale,
30-
int64_t weight_zero_point,
28+
const Tensor& weight_zero_point,
3129
const Tensor& out_multiplier,
3230
const Tensor& out_shift,
3331
int64_t out_zero_point,
32+
const exec_aten::optional<Tensor>& offset,
3433
Tensor& out) {
3534
// input comes in shape [leading_dims, in_dim]
3635
// weight comes in shape [out_dim, in_dim]
@@ -58,7 +57,7 @@ void quantized_linear_out(
5857
in_dim, // vec_offset of p_mat2.
5958
out_dim, // out_offset, i.e., offset of next output element written
6059
1, // out_stride, i.e., stride to go to next output row
61-
-weight_zero_point, // mat1_zero_bias
60+
-weight_zero_point.const_data_ptr<int32_t>()[0], // mat1_zero_bias
6261
-src_zero_point, // mat2_zero_bias
6362
out_multiplier.const_data_ptr<int32_t>(), // out_multiplier
6463
out_shift.const_data_ptr<int32_t>(), // out_shift

0 commit comments

Comments
 (0)