@@ -217,6 +217,12 @@ def getNodeArgs(node):
217
217
return [tosa_mapping .TosaArg (arg ) for arg in node .args ]
218
218
219
219
220
+ def getQuantNodeArgs (node ):
221
+ quant_args = [tosa_mapping .TosaArg (arg ) for arg in node .args ]
222
+ # Return the scale and zp
223
+ return quant_args [1 ].number , quant_args [2 ].number
224
+
225
+
220
226
@final
221
227
class ArmBackend (BackendDetails ):
222
228
@staticmethod
@@ -253,6 +259,7 @@ def preprocess( # noqa: C901
253
259
outp = tosa_mapping .TosaArg (node )
254
260
255
261
is_quant_node = tosa_quant_utils .isQuantNode (node )
262
+
256
263
if is_quant_node :
257
264
tosa_fb .currRegion .currBasicBlock .addTensor (
258
265
outp .name , outp .shape , ts .DType .INT8
@@ -345,13 +352,17 @@ def preprocess( # noqa: C901
345
352
elif exir_ops .edge .aten .addmm .default == node .target :
346
353
bias , input , weight = inputs
347
354
355
+ output_dtype = ts .DType .INT8 if is_quant_node else outp .dtype
356
+
348
357
# Reshape input, weight, bias tensors
349
358
input_reshape_res = promote_shape (
350
- tosa_fb , input , (1 ,) + input .shape , outp . dtype
359
+ tosa_fb , input , (1 ,) + input .shape , output_dtype
351
360
)
352
361
weight_reshape_res = promote_shape (
353
- tosa_fb , weight , (1 ,) + weight .shape , outp . dtype
362
+ tosa_fb , weight , (1 ,) + weight .shape , output_dtype
354
363
)
364
+
365
+ bias_dtype = ts .DType .INT32 if is_quant_node else outp .dtype
355
366
bias_reshape_res = promote_shape (
356
367
tosa_fb ,
357
368
bias ,
@@ -360,36 +371,87 @@ def preprocess( # noqa: C901
360
371
1 ,
361
372
)
362
373
+ bias .shape ,
363
- outp . dtype ,
374
+ bias_dtype ,
364
375
)
365
376
366
377
# Add dummy batch 1 to mm_shape
367
378
mm_shape = (1 , input .shape [0 ], weight .shape [1 ])
368
379
# Define Intermediate tensor for MatMul res
369
- mm_res = tosa_fb .addIntermediate (mm_shape , outp .dtype )
380
+ mm_res = tosa_fb .addIntermediate (
381
+ mm_shape , ts .DType .INT32 if is_quant_node else output_dtype
382
+ )
370
383
371
384
# Add MatMulOp
385
+ attr_matmul = ts .TosaSerializerAttribute ()
386
+ a_zp , b_zp = (- 128 , 0 ) if is_quant_node else (0 , 0 )
387
+ attr_matmul .MatMulAttribute (a_zp , b_zp )
372
388
tosa_fb .addOperator (
373
389
TosaOp .Op ().MATMUL ,
374
390
[input_reshape_res .name , weight_reshape_res .name ],
375
391
[mm_res .name ],
376
- attr_torch_to_tosa ( TosaOp . Op (). MATMUL , node ) ,
392
+ attr_matmul ,
377
393
)
378
394
379
395
# Add AddOp
380
- add_res = tosa_fb .addIntermediate (mm_shape , outp .dtype )
396
+ add_res = tosa_fb .addIntermediate (
397
+ mm_shape , ts .DType .INT32 if is_quant_node else output_dtype
398
+ )
399
+
381
400
tosa_fb .addOperator (
382
401
TosaOp .Op ().ADD ,
383
402
[bias_reshape_res .name , mm_res .name ],
384
403
[add_res .name ],
385
404
None ,
386
405
)
387
406
407
+ if is_quant_node :
408
+ # Read inputs' parent nodes
409
+ #
410
+ _ , input_node , weight_node = node .all_input_nodes
411
+ input_scale , _ = getQuantNodeArgs (input_node )
412
+ weight_node_q_node = weight_node .all_input_nodes [0 ]
413
+ weight_scale , _ = getQuantNodeArgs (weight_node_q_node )
414
+
415
+ consumer_node = list (node .users )[0 ]
416
+ consumer_node_scale , consumer_node_node_zp = getQuantNodeArgs (
417
+ consumer_node
418
+ )
419
+
420
+ output_rescale_scale = (
421
+ input_scale * weight_scale
422
+ ) / consumer_node_scale
423
+ (
424
+ multiplier_output ,
425
+ shift_output ,
426
+ ) = tosa_quant_utils .computeMultiplierAndShift (
427
+ output_rescale_scale
428
+ )
429
+
430
+ attr_rescale_output = ts .TosaSerializerAttribute ()
431
+ attr_rescale_output .RescaleAttribute (
432
+ input_zp = 0 ,
433
+ output_zp = consumer_node_node_zp ,
434
+ multiplier = [multiplier_output ],
435
+ shift = [shift_output ],
436
+ scale32 = True ,
437
+ double_round = True ,
438
+ per_channel = False ,
439
+ )
440
+ add_res_int8 = tosa_fb .addIntermediate (mm_shape , ts .DType .INT8 )
441
+ tosa_fb .addOperator (
442
+ TosaOp .Op ().RESCALE ,
443
+ [add_res .name ],
444
+ [add_res_int8 .name ],
445
+ attr_rescale_output ,
446
+ )
388
447
# Reshape final result to original shape
389
448
attr_out = ts .TosaSerializerAttribute ()
390
449
attr_out .ReshapeAttribute (outp .shape )
391
450
tosa_fb .addOperator (
392
- TosaOp .Op ().RESHAPE , [add_res .name ], [outp .name ], attr_out
451
+ TosaOp .Op ().RESHAPE ,
452
+ [add_res_int8 .name if is_quant_node else add_res .name ],
453
+ [outp .name ],
454
+ attr_out ,
393
455
)
394
456
elif exir_ops .edge .aten .permute_copy .default == node .target :
395
457
attr = ts .TosaSerializerAttribute ()
@@ -700,20 +762,11 @@ def preprocess( # noqa: C901
700
762
[outp .name ],
701
763
attr_mul ,
702
764
)
703
- elif operator .getitem == node .target :
704
- item_name = inputs [0 ].name
705
- ## Simply add an identityOp
706
- tosa_fb .addOperator (TosaOp .Op ().IDENTITY , [item_name ], [outp .name ])
707
- elif (
708
- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
709
- == node .target
710
- ):
711
- item_name = inputs [0 ].name
712
- tosa_fb .addOperator (TosaOp .Op ().IDENTITY , [item_name ], [outp .name ])
713
- elif (
714
- exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
715
- == node .target
716
- ):
765
+ elif node .target in [
766
+ operator .getitem ,
767
+ tosa_quant_utils .q_op ,
768
+ tosa_quant_utils .dq_op ,
769
+ ]:
717
770
item_name = inputs [0 ].name
718
771
## Simply add an identityOp
719
772
tosa_fb .addOperator (TosaOp .Op ().IDENTITY , [item_name ], [outp .name ])
@@ -740,9 +793,54 @@ def preprocess( # noqa: C901
740
793
741
794
assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
742
795
weight_values = p_data .detach ().numpy ()
743
- tosa_fb .addConst (
744
- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
745
- )
796
+
797
+ # Check if they're for quantized nodes
798
+ consumer_node = list (node .users )[0 ]
799
+ if consumer_node .target in tosa_quant_utils .dq_q_ops :
800
+ _ , weight_node_scale , weight_node_zp , _ , _ , _ = getNodeArgs (
801
+ consumer_node
802
+ )
803
+
804
+ weight_values_quantized = (
805
+ (weight_values / weight_node_scale .number )
806
+ + weight_node_zp .number
807
+ ).astype (np .int8 )
808
+ tosa_fb .addConst (
809
+ inputs [0 ].shape ,
810
+ ts .DType .INT8 ,
811
+ weight_values_quantized ,
812
+ name = out ,
813
+ )
814
+ elif (
815
+ consumer_node .target == exir_ops .edge .aten .addmm .default
816
+ and list (consumer_node .users )[0 ].target == tosa_quant_utils .q_op
817
+ ):
818
+ (
819
+ _ ,
820
+ input_node ,
821
+ weight_node_permuted ,
822
+ ) = consumer_node .all_input_nodes
823
+ weight_node = weight_node_permuted .all_input_nodes [0 ]
824
+
825
+ input_node_scale , _ = getQuantNodeArgs (input_node )
826
+ weight_node_scale , weight_node_zp = getQuantNodeArgs (
827
+ weight_node
828
+ )
829
+
830
+ weight_values_quantized = (
831
+ weight_values / (input_node_scale * weight_node_scale )
832
+ ).astype (np .int32 )
833
+
834
+ tosa_fb .addConst (
835
+ inputs [0 ].shape ,
836
+ ts .DType .INT32 ,
837
+ weight_values_quantized ,
838
+ name = out ,
839
+ )
840
+ else :
841
+ tosa_fb .addConst (
842
+ inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
843
+ )
746
844
elif out in edge_program .graph_signature .inputs_to_buffers :
747
845
parameter_name = edge_program .graph_signature .inputs_to_buffers [
748
846
node .name
0 commit comments