146
146
"quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
147
147
"Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
148
148
)
149
-
149
+ lib .define (
150
+ "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
151
+ "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
152
+ )
150
153
151
154
# ------------------------------------ #
152
155
# Migrated from custom_ops.ymal #
192
195
"quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
193
196
"Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
194
197
)
198
+ lib .define (
199
+ "quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
200
+ "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
201
+ )
195
202
lib .define (
196
203
"quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
197
204
"Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -595,6 +602,28 @@ def quantized_fully_connected_meta(
595
602
bias : torch .Tensor ,
596
603
in_zero_point : int ,
597
604
weight_zero_point : torch .Tensor ,
605
+ out_multiplier : torch .Tensor ,
606
+ out_shift : torch .Tensor ,
607
+ out_zero_point : int ,
608
+ offset : Optional [torch .Tensor ],
609
+ ) -> torch .Tensor :
610
+ # src comes in shape [leading_dims, in_dim]
611
+ # weight comes in shape [out_dim, in_dim]
612
+ # output comes in empty with shape [leading_dims, out_dim]
613
+ out_size = list (src .size ())
614
+ weight_size = list (weight .size ())
615
+ assert len (weight_size ) == 2
616
+ out_size [- 1 ] = weight_size [0 ]
617
+ return src .new_empty (out_size , dtype = src .dtype )
618
+
619
+
620
+ @register_fake ("cadence::quantized_fully_connected.per_tensor" )
621
+ def quantized_fully_connected_per_tensor_meta (
622
+ src : torch .Tensor ,
623
+ weight : torch .Tensor ,
624
+ bias : torch .Tensor ,
625
+ in_zero_point : int ,
626
+ weight_zero_point : int ,
598
627
out_multiplier : int ,
599
628
out_shift : int ,
600
629
out_zero_point : int ,
@@ -607,7 +636,7 @@ def quantized_fully_connected_meta(
607
636
weight_size = list (weight .size ())
608
637
assert len (weight_size ) == 2
609
638
out_size [- 1 ] = weight_size [0 ]
610
- return src .new_empty (out_size , dtype = torch . uint8 )
639
+ return src .new_empty (out_size , dtype = src . dtype )
611
640
612
641
613
642
@register_fake ("cadence::convolution" )
0 commit comments