@@ -553,23 +553,39 @@ def preprocess( # noqa: C901
553
553
elif exir_ops .edge .aten .convolution .default == node .target :
554
554
input , weight , bias , stride , pad , dilation , _ , _ , group = inputs
555
555
556
+ # Currently only int8 is supported in quantized types.
557
+ actual_out_type = ts .DType .INT8 if is_quant_node else outp .dtype
558
+
556
559
## Transpose input tensor to NHWC_Order for TOSA
557
560
NHWC_Order = [0 , 2 , 3 , 1 ]
558
561
input_transposed = transpose_helper (
559
- tosa_fb , input , NHWC_Order , outp . dtype
562
+ tosa_fb , input , NHWC_Order , actual_out_type
560
563
)
561
564
562
- ## CONV2DOp
565
+ # Get the attributes of convolution.
563
566
attr = ts .TosaSerializerAttribute ()
564
- # PAD
565
567
pad_attr = [val for val in pad .special for _ in (0 , 1 )]
566
- # Stride
567
568
stride_attr = stride .special
568
- # Dilation
569
569
dilation_attr = dilation .special
570
570
attr .ConvAttribute (pad_attr , stride_attr , dilation_attr , 0 , 0 )
571
571
572
+ # Non-bias case.
573
+ if len (node .all_input_nodes ) == 2 :
574
+ # Create a zero bias tensor if none is provided
575
+ out_channels = weight .shape [0 ]
576
+ bias_name = "bias" + node .name .split ("default" , 1 )[1 ]
577
+ bias = tosa_fb .addConst (
578
+ [out_channels ],
579
+ ts .DType .INT32 if is_quant_node else outp .dtype ,
580
+ [0 ] * out_channels ,
581
+ name = bias_name ,
582
+ )
583
+
572
584
if group .number > 1 :
585
+ assert (
586
+ is_quant_node is False
587
+ ), "quantized depthwise convolution is not supported yet in BI mode"
588
+
573
589
# Transpose weight to [KH, KW, C, M]
574
590
weight_HWCM_Order = [2 , 3 , 0 , 1 ]
575
591
weight_transposed = transpose_helper (
@@ -600,14 +616,17 @@ def preprocess( # noqa: C901
600
616
# Transpose weight to [OC, H, W, IC]
601
617
weight_CHWC_Order = [0 , 2 , 3 , 1 ]
602
618
weight_transposed = transpose_helper (
603
- tosa_fb , weight , weight_CHWC_Order , outp . dtype
619
+ tosa_fb , weight , weight_CHWC_Order , actual_out_type
604
620
)
605
621
606
622
## TOSA output shape is [NHWO]
607
623
NHWO_Order = [0 , 2 , 3 , 1 ]
608
624
out_shape_TOSA_CONV2D = [outp .shape [i ] for i in NHWO_Order ]
625
+
626
+ # The output type is int32 when input type is int8.
609
627
conv2d_res = tosa_fb .addIntermediate (
610
- out_shape_TOSA_CONV2D , outp .dtype
628
+ out_shape_TOSA_CONV2D ,
629
+ ts .DType .INT32 if is_quant_node else outp .dtype ,
611
630
)
612
631
tosa_fb .addOperator (
613
632
TosaOp .Op ().CONV2D ,
@@ -624,6 +643,24 @@ def preprocess( # noqa: C901
624
643
NOHW_Order = [0 , 3 , 1 , 2 ]
625
644
attr_output_transpose = ts .TosaSerializerAttribute ()
626
645
attr_output_transpose .TransposeAttribute (NOHW_Order )
646
+
647
+ # For quantized convolution, rescale the int32 convolution result back to
648
+ # the integer value domain of the next op. Otherwise the output keeps outp.dtype.
649
+ if is_quant_node :
650
+ # Get scale_factor from input, weight, and output.
651
+ _ , input_scale , _ , _ , _ , _ = getNodeArgs (node .args [0 ])
652
+ _ , weight_scale , _ , _ , _ , _ = getNodeArgs (node .args [1 ])
653
+ _ , output_scale , _ , _ , _ , _ = getNodeArgs (list (node .users )[0 ])
654
+
655
+ conv2d_res = tosa_quant_utils .buildRescaleOpConvOutput (
656
+ tosa_fb ,
657
+ conv2d_res ,
658
+ actual_out_type ,
659
+ input_scale ,
660
+ weight_scale ,
661
+ output_scale ,
662
+ )
663
+
627
664
tosa_fb .addOperator (
628
665
TosaOp .Op ().TRANSPOSE ,
629
666
[conv2d_res .name ],
@@ -879,7 +916,7 @@ def preprocess( # noqa: C901
879
916
p_data = edge_program .state_dict [parameter_name ]
880
917
881
918
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
882
- weight_values = p_data .detach ().numpy ()
919
+ parameter_values = p_data .detach ().numpy ()
883
920
884
921
# Check if they're for quantized nodes
885
922
consumer_node = list (node .users )[0 ]
@@ -888,14 +925,14 @@ def preprocess( # noqa: C901
888
925
consumer_node
889
926
)
890
927
891
- weight_values_quantized = (
892
- (weight_values / weight_node_scale .number )
928
+ parameter_values_quantized = (
929
+ (parameter_values / weight_node_scale .number )
893
930
+ weight_node_zp .number
894
931
).astype (np .int8 )
895
932
tosa_fb .addConst (
896
933
inputs [0 ].shape ,
897
934
ts .DType .INT8 ,
898
- weight_values_quantized ,
935
+ parameter_values_quantized ,
899
936
name = out ,
900
937
)
901
938
elif (
@@ -914,30 +951,55 @@ def preprocess( # noqa: C901
914
951
weight_node
915
952
)
916
953
917
- weight_values_quantized = (
918
- weight_values / (input_node_scale * weight_node_scale )
954
+ parameter_values_quantized = (
955
+ parameter_values / (input_node_scale * weight_node_scale )
956
+ ).astype (np .int32 )
957
+
958
+ tosa_fb .addConst (
959
+ inputs [0 ].shape ,
960
+ ts .DType .INT32 ,
961
+ parameter_values_quantized ,
962
+ name = out ,
963
+ )
964
+ elif (
965
+ consumer_node .target == exir_ops .edge .aten .convolution .default
966
+ and list (consumer_node .users )[0 ].target == tosa_quant_utils .q_op
967
+ ):
968
+ (
969
+ input_node ,
970
+ weight_node ,
971
+ bias_node ,
972
+ ) = consumer_node .all_input_nodes
973
+
974
+ input_node_scale , _ = getQuantNodeArgs (input_node )
975
+ weight_node_scale , _ = getQuantNodeArgs (weight_node )
976
+
977
+ bias_scales = input_node_scale * weight_node_scale
978
+ parameter_values_quantized = (
979
+ parameter_values / bias_scales
919
980
).astype (np .int32 )
920
981
921
982
tosa_fb .addConst (
922
983
inputs [0 ].shape ,
923
984
ts .DType .INT32 ,
924
- weight_values_quantized ,
985
+ parameter_values_quantized ,
925
986
name = out ,
926
987
)
927
988
else :
928
989
tosa_fb .addConst (
929
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
990
+ inputs [0 ].shape , inputs [0 ].dtype , parameter_values , name = out
930
991
)
992
+
931
993
elif out in edge_program .graph_signature .inputs_to_buffers :
932
994
parameter_name = edge_program .graph_signature .inputs_to_buffers [
933
995
node .name
934
996
]
935
997
p_data = edge_program .state_dict [parameter_name ]
936
998
937
999
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
938
- weight_values = p_data .detach ().numpy ()
1000
+ buffer_values = p_data .detach ().numpy ()
939
1001
tosa_fb .addConst (
940
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
1002
+ inputs [0 ].shape , inputs [0 ].dtype , buffer_values , name = out
941
1003
)
942
1004
else :
943
1005
tensor = ts .TosaSerializerTensor (
0 commit comments