@@ -227,6 +227,14 @@ def getQuantNodeArgs(node):
227
227
228
228
@final
229
229
class ArmBackend (BackendDetails ):
230
+ # Class variable initialization
231
+ ssa_num = - 1
232
+
233
+ @staticmethod
234
+ def getSSAnum ():
235
+ ArmBackend .ssa_num += 1
236
+ return ArmBackend .ssa_num
237
+
230
238
@staticmethod
231
239
def preprocess ( # noqa: C901
232
240
edge_program : ExportedProgram ,
@@ -476,10 +484,13 @@ def preprocess( # noqa: C901
476
484
elif exir_ops .edge .aten .convolution .default == node .target :
477
485
input , weight , bias , stride , pad , dilation , _ , _ , group = inputs
478
486
487
+ # Currently only int8 is supported in quantized types.
488
+ actual_out_type = ts .DType .INT8 if is_quant_node else outp .dtype
489
+
479
490
## Transpose input tensor to NHWC_Order for TOSA
480
491
NHWC_Order = [0 , 2 , 3 , 1 ]
481
492
input_transposed = transpose_helper (
482
- tosa_fb , input , NHWC_Order , outp . dtype
493
+ tosa_fb , input , NHWC_Order , actual_out_type
483
494
)
484
495
485
496
## CONV2DOp
@@ -492,6 +503,21 @@ def preprocess( # noqa: C901
492
503
dilation_attr = dilation .special
493
504
attr .ConvAttribute (pad_attr , stride_attr , dilation_attr , 0 , 0 )
494
505
506
+ if len (node .all_input_nodes ) == 3 :
507
+ input_node , weight_node , _ = node .all_input_nodes
508
+ else :
509
+ input_node , weight_node = node .all_input_nodes
510
+
511
+ # Create a zero bias tensor if not presented
512
+ out_channels = weight .shape [0 ]
513
+ bias_name = "const_bias_" + str (ArmBackend .getSSAnum ())
514
+ bias = tosa_fb .addConst (
515
+ [out_channels ],
516
+ ts .DType .INT32 if is_quant_node else outp .dtype ,
517
+ [0 ] * out_channels ,
518
+ name = bias_name ,
519
+ )
520
+
495
521
if group .number > 1 :
496
522
# Transpose weight to [KH, KW, C, M]
497
523
weight_HWCM_Order = [2 , 3 , 0 , 1 ]
@@ -523,14 +549,17 @@ def preprocess( # noqa: C901
523
549
# Transpose weight to [OC, H, W, IC]
524
550
weight_CHWC_Order = [0 , 2 , 3 , 1 ]
525
551
weight_transposed = transpose_helper (
526
- tosa_fb , weight , weight_CHWC_Order , outp . dtype
552
+ tosa_fb , weight , weight_CHWC_Order , actual_out_type
527
553
)
528
554
529
555
## TOSA output shape is [NHWO]
530
556
NHWO_Order = [0 , 2 , 3 , 1 ]
531
557
out_shape_TOSA_CONV2D = [outp .shape [i ] for i in NHWO_Order ]
558
+
559
+ # The output type is int32 when input type is int8.
532
560
conv2d_res = tosa_fb .addIntermediate (
533
- out_shape_TOSA_CONV2D , outp .dtype
561
+ out_shape_TOSA_CONV2D ,
562
+ ts .DType .INT32 if is_quant_node else outp .dtype ,
534
563
)
535
564
tosa_fb .addOperator (
536
565
TosaOp .Op ().CONV2D ,
@@ -547,12 +576,32 @@ def preprocess( # noqa: C901
547
576
NOHW_Order = [0 , 3 , 1 , 2 ]
548
577
attr_output_transpose = ts .TosaSerializerAttribute ()
549
578
attr_output_transpose .TransposeAttribute (NOHW_Order )
579
+
580
+ # For quantized convolution, rescale the output value back to the same
581
+ # integer value domain of the next op. Otherwise return float32 output.
582
+ if is_quant_node :
583
+ # Get scale_factor from input, weight, and output.
584
+ output_node = list (node .users )[0 ]
585
+ _ , input_scale , _ , _ , _ , _ = getNodeArgs (input_node )
586
+ _ , weight_scale , _ , _ , _ , _ = getNodeArgs (weight_node )
587
+ _ , output_scale , _ , _ , _ , _ = getNodeArgs (output_node )
588
+
589
+ conv2d_res = tosa_quant_utils .buildRescaleOpConvOutput (
590
+ tosa_fb ,
591
+ conv2d_res ,
592
+ actual_out_type ,
593
+ input_scale ,
594
+ weight_scale ,
595
+ output_scale ,
596
+ )
597
+
550
598
tosa_fb .addOperator (
551
599
TosaOp .Op ().TRANSPOSE ,
552
600
[conv2d_res .name ],
553
601
[outp .name ],
554
602
attr_output_transpose ,
555
603
)
604
+
556
605
elif exir_ops .edge .aten .div .Tensor == node .target :
557
606
# Div is implemented as x/y = x*1/y
558
607
recip = tosa_fb .addIntermediate (inputs [1 ].shape , inputs [1 ].dtype )
@@ -802,7 +851,7 @@ def preprocess( # noqa: C901
802
851
p_data = edge_program .state_dict [parameter_name ]
803
852
804
853
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
805
- weight_values = p_data .detach ().numpy ()
854
+ parameter_values = p_data .detach ().numpy ()
806
855
807
856
# Check if they're for quantized nodes
808
857
consumer_node = list (node .users )[0 ]
@@ -811,14 +860,14 @@ def preprocess( # noqa: C901
811
860
consumer_node
812
861
)
813
862
814
- weight_values_quantized = (
815
- (weight_values / weight_node_scale .number )
863
+ parameter_values_quantized = (
864
+ (parameter_values / weight_node_scale .number )
816
865
+ weight_node_zp .number
817
866
).astype (np .int8 )
818
867
tosa_fb .addConst (
819
868
inputs [0 ].shape ,
820
869
ts .DType .INT8 ,
821
- weight_values_quantized ,
870
+ parameter_values_quantized ,
822
871
name = out ,
823
872
)
824
873
elif (
@@ -837,30 +886,55 @@ def preprocess( # noqa: C901
837
886
weight_node
838
887
)
839
888
840
- weight_values_quantized = (
841
- weight_values / (input_node_scale * weight_node_scale )
889
+ parameter_values_quantized = (
890
+ parameter_values / (input_node_scale * weight_node_scale )
842
891
).astype (np .int32 )
843
892
844
893
tosa_fb .addConst (
845
894
inputs [0 ].shape ,
846
895
ts .DType .INT32 ,
847
- weight_values_quantized ,
896
+ parameter_values_quantized ,
897
+ name = out ,
898
+ )
899
+ elif (
900
+ consumer_node .target == exir_ops .edge .aten .convolution .default
901
+ and list (consumer_node .users )[0 ].target == tosa_quant_utils .q_op
902
+ ):
903
+ (
904
+ input_node ,
905
+ weight_node ,
906
+ bias_node ,
907
+ ) = consumer_node .all_input_nodes
908
+
909
+ input_node_scale , _ = getQuantNodeArgs (input_node )
910
+ weight_node_scale , _ = getQuantNodeArgs (weight_node )
911
+
912
+ bias_scales = input_node_scale * weight_node_scale
913
+ parameter_values_quantized = (
914
+ parameter_values / bias_scales
915
+ ).astype (np .int32 )
916
+
917
+ tosa_fb .addConst (
918
+ inputs [0 ].shape ,
919
+ ts .DType .INT32 ,
920
+ parameter_values_quantized ,
848
921
name = out ,
849
922
)
850
923
else :
851
924
tosa_fb .addConst (
852
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
925
+ inputs [0 ].shape , inputs [0 ].dtype , parameter_values , name = out
853
926
)
927
+
854
928
elif out in edge_program .graph_signature .inputs_to_buffers :
855
929
parameter_name = edge_program .graph_signature .inputs_to_buffers [
856
930
node .name
857
931
]
858
932
p_data = edge_program .state_dict [parameter_name ]
859
933
860
934
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
861
- weight_values = p_data .detach ().numpy ()
935
+ parameter_values = p_data .detach ().numpy ()
862
936
tosa_fb .addConst (
863
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
937
+ inputs [0 ].shape , inputs [0 ].dtype , parameter_values , name = out
864
938
)
865
939
else :
866
940
tensor = ts .TosaSerializerTensor (
0 commit comments