@@ -476,22 +476,37 @@ def preprocess( # noqa: C901
476
476
elif exir_ops .edge .aten .convolution .default == node .target :
477
477
input , weight , bias , stride , pad , dilation , _ , _ , group = inputs
478
478
479
+ # Currently only int8 is supported in quantized types.
480
+ actual_out_type = ts .DType .INT8 if is_quant_node else outp .dtype
481
+
479
482
## Transpose input tensor to NHWC_Order for TOSA
480
483
NHWC_Order = [0 , 2 , 3 , 1 ]
481
484
input_transposed = transpose_helper (
482
- tosa_fb , input , NHWC_Order , outp . dtype
485
+ tosa_fb , input , NHWC_Order , actual_out_type
483
486
)
484
487
485
- ## CONV2DOp
488
+ # Get the attributes of convolution.
486
489
attr = ts .TosaSerializerAttribute ()
487
- # PAD
488
490
pad_attr = [val for val in pad .special for _ in (0 , 1 )]
489
- # Stride
490
491
stride_attr = stride .special
491
- # Dilation
492
492
dilation_attr = dilation .special
493
493
attr .ConvAttribute (pad_attr , stride_attr , dilation_attr , 0 , 0 )
494
494
495
+ # Non-bias case.
496
+ if len (node .all_input_nodes ) == 2 :
497
+ assert (
498
+ is_quant_node == False
499
+ ), "currently non-bias convolution is not supported yet in BI mode"
500
+ # Create a zero bias tensor if not presented
501
+ out_channels = weight .shape [0 ]
502
+ bias_name = "bias" + node .name .split ("default" , 1 )[1 ]
503
+ bias = tosa_fb .addConst (
504
+ [out_channels ],
505
+ outp .dtype ,
506
+ [0 ] * out_channels ,
507
+ name = bias_name ,
508
+ )
509
+
495
510
if group .number > 1 :
496
511
# Transpose weight to [KH, KW, C, M]
497
512
weight_HWCM_Order = [2 , 3 , 0 , 1 ]
@@ -523,14 +538,17 @@ def preprocess( # noqa: C901
523
538
# Transpose weight to [OC, H, W, IC]
524
539
weight_CHWC_Order = [0 , 2 , 3 , 1 ]
525
540
weight_transposed = transpose_helper (
526
- tosa_fb , weight , weight_CHWC_Order , outp . dtype
541
+ tosa_fb , weight , weight_CHWC_Order , actual_out_type
527
542
)
528
543
529
544
## TOSA output shape is [NHWO]
530
545
NHWO_Order = [0 , 2 , 3 , 1 ]
531
546
out_shape_TOSA_CONV2D = [outp .shape [i ] for i in NHWO_Order ]
547
+
548
+ # The output type is int32 when input type is int8.
532
549
conv2d_res = tosa_fb .addIntermediate (
533
- out_shape_TOSA_CONV2D , outp .dtype
550
+ out_shape_TOSA_CONV2D ,
551
+ ts .DType .INT32 if is_quant_node else outp .dtype ,
534
552
)
535
553
tosa_fb .addOperator (
536
554
TosaOp .Op ().CONV2D ,
@@ -547,6 +565,24 @@ def preprocess( # noqa: C901
547
565
NOHW_Order = [0 , 3 , 1 , 2 ]
548
566
attr_output_transpose = ts .TosaSerializerAttribute ()
549
567
attr_output_transpose .TransposeAttribute (NOHW_Order )
568
+
569
+ # For quantized convolution, rescale the output value back to the same
570
+ # integer value domain of the next op. Otherwise return float32 output.
571
+ if is_quant_node :
572
+ # Get scale_factor from input, weight, and output.
573
+ _ , input_scale , _ , _ , _ , _ = getNodeArgs (node .args [0 ])
574
+ _ , weight_scale , _ , _ , _ , _ = getNodeArgs (node .args [1 ])
575
+ _ , output_scale , _ , _ , _ , _ = getNodeArgs (list (node .users )[0 ])
576
+
577
+ conv2d_res = tosa_quant_utils .buildRescaleOpConvOutput (
578
+ tosa_fb ,
579
+ conv2d_res ,
580
+ actual_out_type ,
581
+ input_scale ,
582
+ weight_scale ,
583
+ output_scale ,
584
+ )
585
+
550
586
tosa_fb .addOperator (
551
587
TosaOp .Op ().TRANSPOSE ,
552
588
[conv2d_res .name ],
@@ -802,7 +838,7 @@ def preprocess( # noqa: C901
802
838
p_data = edge_program .state_dict [parameter_name ]
803
839
804
840
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
805
- weight_values = p_data .detach ().numpy ()
841
+ parameter_values = p_data .detach ().numpy ()
806
842
807
843
# Check if they're for quantized nodes
808
844
consumer_node = list (node .users )[0 ]
@@ -811,14 +847,14 @@ def preprocess( # noqa: C901
811
847
consumer_node
812
848
)
813
849
814
- weight_values_quantized = (
815
- (weight_values / weight_node_scale .number )
850
+ parameter_values_quantized = (
851
+ (parameter_values / weight_node_scale .number )
816
852
+ weight_node_zp .number
817
853
).astype (np .int8 )
818
854
tosa_fb .addConst (
819
855
inputs [0 ].shape ,
820
856
ts .DType .INT8 ,
821
- weight_values_quantized ,
857
+ parameter_values_quantized ,
822
858
name = out ,
823
859
)
824
860
elif (
@@ -837,30 +873,55 @@ def preprocess( # noqa: C901
837
873
weight_node
838
874
)
839
875
840
- weight_values_quantized = (
841
- weight_values / (input_node_scale * weight_node_scale )
876
+ parameter_values_quantized = (
877
+ parameter_values / (input_node_scale * weight_node_scale )
842
878
).astype (np .int32 )
843
879
844
880
tosa_fb .addConst (
845
881
inputs [0 ].shape ,
846
882
ts .DType .INT32 ,
847
- weight_values_quantized ,
883
+ parameter_values_quantized ,
884
+ name = out ,
885
+ )
886
+ elif (
887
+ consumer_node .target == exir_ops .edge .aten .convolution .default
888
+ and list (consumer_node .users )[0 ].target == tosa_quant_utils .q_op
889
+ ):
890
+ (
891
+ input_node ,
892
+ weight_node ,
893
+ bias_node ,
894
+ ) = consumer_node .all_input_nodes
895
+
896
+ input_node_scale , _ = getQuantNodeArgs (input_node )
897
+ weight_node_scale , _ = getQuantNodeArgs (weight_node )
898
+
899
+ bias_scales = input_node_scale * weight_node_scale
900
+ parameter_values_quantized = (
901
+ parameter_values / bias_scales
902
+ ).astype (np .int32 )
903
+
904
+ tosa_fb .addConst (
905
+ inputs [0 ].shape ,
906
+ ts .DType .INT32 ,
907
+ parameter_values_quantized ,
848
908
name = out ,
849
909
)
850
910
else :
851
911
tosa_fb .addConst (
852
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
912
+ inputs [0 ].shape , inputs [0 ].dtype , parameter_values , name = out
853
913
)
914
+
854
915
elif out in edge_program .graph_signature .inputs_to_buffers :
855
916
parameter_name = edge_program .graph_signature .inputs_to_buffers [
856
917
node .name
857
918
]
858
919
p_data = edge_program .state_dict [parameter_name ]
859
920
860
921
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
861
- weight_values = p_data .detach ().numpy ()
922
+ parameter_values = p_data .detach ().numpy ()
862
923
tosa_fb .addConst (
863
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
924
+ inputs [0 ].shape , inputs [0 ].dtype , parameter_values , name = out
864
925
)
865
926
else :
866
927
tensor = ts .TosaSerializerTensor (
0 commit comments