@@ -246,6 +246,10 @@ def preprocess( # noqa: C901
246
246
if path is None :
247
247
path = tempfile .mkdtemp (prefix = "arm_tosa_" )
248
248
249
+ # Check up front whether this is a quantized model so that the tensor data
250
+ # type of TOSA operations can be determined more easily during lowering.
251
+ is_quantized_model = tosa_quant_utils .isQuantizedModel (edge_program .graph )
252
+
249
253
# Converted output for this subgraph, serializer needs path early as it emits
250
254
# const data directly. Path created and data written only in debug builds.
251
255
tosa_fb = ts .TosaSerializer (path )
@@ -476,10 +480,15 @@ def preprocess( # noqa: C901
476
480
elif exir_ops .edge .aten .convolution .default == node .target :
477
481
input , weight , bias , stride , pad , dilation , _ , _ , group = inputs
478
482
483
+ # Currently only int8 is supported in quantized types.
484
+ actual_out_type = (
485
+ ts .DType .INT8 if is_quantized_model else outp .dtype
486
+ )
487
+
479
488
## Transpose input tensor to NHWC_Order for TOSA
480
489
NHWC_Order = [0 , 2 , 3 , 1 ]
481
490
input_transposed = transpose_helper (
482
- tosa_fb , input , NHWC_Order , outp . dtype
491
+ tosa_fb , input , NHWC_Order , actual_out_type
483
492
)
484
493
485
494
## CONV2DOp
@@ -493,6 +502,11 @@ def preprocess( # noqa: C901
493
502
attr .ConvAttribute (pad_attr , stride_attr , dilation_attr , 0 , 0 )
494
503
495
504
if group .number > 1 :
505
+ if is_quant_node :
506
+ raise AssertionError (
507
+ "quantized depthwise conv2d is not supported for now"
508
+ )
509
+
496
510
# Transpose weight to [KH, KW, C, M]
497
511
weight_HWCM_Order = [2 , 3 , 0 , 1 ]
498
512
weight_transposed = transpose_helper (
@@ -523,14 +537,17 @@ def preprocess( # noqa: C901
523
537
# Transpose weight to [OC, H, W, IC]
524
538
weight_CHWC_Order = [0 , 2 , 3 , 1 ]
525
539
weight_transposed = transpose_helper (
526
- tosa_fb , weight , weight_CHWC_Order , outp . dtype
540
+ tosa_fb , weight , weight_CHWC_Order , actual_out_type
527
541
)
528
542
529
543
## TOSA output shape is [NHWO]
530
544
NHWO_Order = [0 , 2 , 3 , 1 ]
531
545
out_shape_TOSA_CONV2D = [outp .shape [i ] for i in NHWO_Order ]
546
+
547
+ # The output type is int32 when input type is int8.
532
548
conv2d_res = tosa_fb .addIntermediate (
533
- out_shape_TOSA_CONV2D , outp .dtype
549
+ out_shape_TOSA_CONV2D ,
550
+ ts .DType .INT32 if is_quant_node else outp .dtype ,
534
551
)
535
552
tosa_fb .addOperator (
536
553
TosaOp .Op ().CONV2D ,
@@ -547,12 +564,45 @@ def preprocess( # noqa: C901
547
564
NOHW_Order = [0 , 3 , 1 , 2 ]
548
565
attr_output_transpose = ts .TosaSerializerAttribute ()
549
566
attr_output_transpose .TransposeAttribute (NOHW_Order )
550
- tosa_fb .addOperator (
551
- TosaOp .Op ().TRANSPOSE ,
552
- [conv2d_res .name ],
553
- [outp .name ],
554
- attr_output_transpose ,
555
- )
567
+
568
+ if len (node .all_input_nodes ) == 3 :
569
+ input_node , weight_node , bias_node = node .all_input_nodes
570
+ else :
571
+ raise AssertionError (
572
+ "non-biased conv2d is not supported for now"
573
+ )
574
+
575
+ output_node = list (node .users )[0 ]
576
+
577
+ # For quantized convolution, rescale the output value back to the same
578
+ # integer value domain of the next op. Otherwise return float32 output.
579
+ if is_quant_node :
580
+ # Get scale_factor from input, weight, and output.
581
+ _ , input_scale , _ , _ , _ , _ = getNodeArgs (input_node )
582
+ _ , weight_scale , _ , _ , _ , _ = getNodeArgs (weight_node )
583
+ _ , output_scale , _ , _ , _ , _ = getNodeArgs (output_node )
584
+ rescaled_conv2d_res = tosa_quant_utils .buildRescaleOpConvOutput (
585
+ tosa_fb ,
586
+ conv2d_res ,
587
+ actual_out_type ,
588
+ input_scale ,
589
+ weight_scale ,
590
+ output_scale ,
591
+ )
592
+ tosa_fb .addOperator (
593
+ TosaOp .Op ().TRANSPOSE ,
594
+ [rescaled_conv2d_res .name ],
595
+ [outp .name ],
596
+ attr_output_transpose ,
597
+ )
598
+ else :
599
+ tosa_fb .addOperator (
600
+ TosaOp .Op ().TRANSPOSE ,
601
+ [conv2d_res .name ],
602
+ [outp .name ],
603
+ attr_output_transpose ,
604
+ )
605
+
556
606
elif exir_ops .edge .aten .div .Tensor == node .target :
557
607
# Div is implemented as x/y = x*1/y
558
608
recip = tosa_fb .addIntermediate (inputs [1 ].shape , inputs [1 ].dtype )
@@ -802,7 +852,7 @@ def preprocess( # noqa: C901
802
852
p_data = edge_program .state_dict [parameter_name ]
803
853
804
854
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
805
- weight_values = p_data .detach ().numpy ()
855
+ ph_values = p_data .detach ().numpy ()
806
856
807
857
# Check if they're for quantized nodes
808
858
consumer_node = list (node .users )[0 ]
@@ -811,14 +861,14 @@ def preprocess( # noqa: C901
811
861
consumer_node
812
862
)
813
863
814
- weight_values_quantized = (
815
- (weight_values / weight_node_scale .number )
864
+ ph_values_quantized = (
865
+ (ph_values / weight_node_scale .number )
816
866
+ weight_node_zp .number
817
867
).astype (np .int8 )
818
868
tosa_fb .addConst (
819
869
inputs [0 ].shape ,
820
870
ts .DType .INT8 ,
821
- weight_values_quantized ,
871
+ ph_values_quantized ,
822
872
name = out ,
823
873
)
824
874
elif (
@@ -837,30 +887,53 @@ def preprocess( # noqa: C901
837
887
weight_node
838
888
)
839
889
840
- weight_values_quantized = (
841
- weight_values / (input_node_scale * weight_node_scale )
890
+ ph_values_quantized = (
891
+ ph_values / (input_node_scale * weight_node_scale )
842
892
).astype (np .int32 )
843
893
844
894
tosa_fb .addConst (
845
895
inputs [0 ].shape ,
846
896
ts .DType .INT32 ,
847
- weight_values_quantized ,
897
+ ph_values_quantized ,
898
+ name = out ,
899
+ )
900
+ elif (
901
+ consumer_node .target == exir_ops .edge .aten .convolution .default
902
+ and list (consumer_node .users )[0 ].target == tosa_quant_utils .q_op
903
+ ):
904
+ (
905
+ input_node ,
906
+ weight_node ,
907
+ bias_node ,
908
+ ) = consumer_node .all_input_nodes
909
+
910
+ input_node_scale , _ = getQuantNodeArgs (input_node )
911
+ weight_node_scale , _ = getQuantNodeArgs (weight_node )
912
+
913
+ bias_scales = input_node_scale * weight_node_scale
914
+ ph_values_quantized = (ph_values / bias_scales ).astype (np .int32 )
915
+
916
+ tosa_fb .addConst (
917
+ inputs [0 ].shape ,
918
+ ts .DType .INT32 ,
919
+ ph_values_quantized ,
848
920
name = out ,
849
921
)
850
922
else :
851
923
tosa_fb .addConst (
852
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
924
+ inputs [0 ].shape , inputs [0 ].dtype , ph_values , name = out
853
925
)
926
+
854
927
elif out in edge_program .graph_signature .inputs_to_buffers :
855
928
parameter_name = edge_program .graph_signature .inputs_to_buffers [
856
929
node .name
857
930
]
858
931
p_data = edge_program .state_dict [parameter_name ]
859
932
860
933
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
861
- weight_values = p_data .detach ().numpy ()
934
+ ph_values = p_data .detach ().numpy ()
862
935
tosa_fb .addConst (
863
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
936
+ inputs [0 ].shape , inputs [0 ].dtype , ph_values , name = out
864
937
)
865
938
else :
866
939
tensor = ts .TosaSerializerTensor (
0 commit comments