Skip to content

Commit 0262272

Browse files
committed
Add initial lowering of aten.convolution to tosa.conv2d support
1 parent 17fee78 commit 0262272

File tree

4 files changed

+281
-32
lines changed

4 files changed

+281
-32
lines changed

backends/arm/arm_backend.py

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -476,22 +476,37 @@ def preprocess( # noqa: C901
476476
elif exir_ops.edge.aten.convolution.default == node.target:
477477
input, weight, bias, stride, pad, dilation, _, _, group = inputs
478478

479+
# Currently only int8 is supported for quantized types.
480+
actual_out_type = ts.DType.INT8 if is_quant_node else outp.dtype
481+
479482
## Transpose input tensor to NHWC_Order for TOSA
480483
NHWC_Order = [0, 2, 3, 1]
481484
input_transposed = transpose_helper(
482-
tosa_fb, input, NHWC_Order, outp.dtype
485+
tosa_fb, input, NHWC_Order, actual_out_type
483486
)
484487

485-
## CONV2DOp
488+
# Get the attributes of convolution.
486489
attr = ts.TosaSerializerAttribute()
487-
# PAD
488490
pad_attr = [val for val in pad.special for _ in (0, 1)]
489-
# Stride
490491
stride_attr = stride.special
491-
# Dilation
492492
dilation_attr = dilation.special
493493
attr.ConvAttribute(pad_attr, stride_attr, dilation_attr, 0, 0)
494494

495+
# Non-bias case.
496+
if len(node.all_input_nodes) == 2:
497+
assert (
498+
is_quant_node == False
499+
), "currently non-bias convolution is not supported yet in BI mode"
500+
# Create a zero bias tensor if not present
501+
out_channels = weight.shape[0]
502+
bias_name = "bias" + node.name.split("default", 1)[1]
503+
bias = tosa_fb.addConst(
504+
[out_channels],
505+
outp.dtype,
506+
[0] * out_channels,
507+
name=bias_name,
508+
)
509+
495510
if group.number > 1:
496511
# Transpose weight to [KH, KW, C, M]
497512
weight_HWCM_Order = [2, 3, 0, 1]
@@ -523,14 +538,17 @@ def preprocess( # noqa: C901
523538
# Transpose weight to [OC, H, W, IC]
524539
weight_CHWC_Order = [0, 2, 3, 1]
525540
weight_transposed = transpose_helper(
526-
tosa_fb, weight, weight_CHWC_Order, outp.dtype
541+
tosa_fb, weight, weight_CHWC_Order, actual_out_type
527542
)
528543

529544
## TOSA output shape is [NHWO]
530545
NHWO_Order = [0, 2, 3, 1]
531546
out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order]
547+
548+
# The output type is int32 when input type is int8.
532549
conv2d_res = tosa_fb.addIntermediate(
533-
out_shape_TOSA_CONV2D, outp.dtype
550+
out_shape_TOSA_CONV2D,
551+
ts.DType.INT32 if is_quant_node else outp.dtype,
534552
)
535553
tosa_fb.addOperator(
536554
TosaOp.Op().CONV2D,
@@ -547,6 +565,24 @@ def preprocess( # noqa: C901
547565
NOHW_Order = [0, 3, 1, 2]
548566
attr_output_transpose = ts.TosaSerializerAttribute()
549567
attr_output_transpose.TransposeAttribute(NOHW_Order)
568+
569+
# For quantized convolution, rescale the output value back to the same
570+
# integer value domain of the next op. Otherwise return float32 output.
571+
if is_quant_node:
572+
# Get scale_factor from input, weight, and output.
573+
_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
574+
_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
575+
_, output_scale, _, _, _, _ = getNodeArgs(list(node.users)[0])
576+
577+
conv2d_res = tosa_quant_utils.buildRescaleOpConvOutput(
578+
tosa_fb,
579+
conv2d_res,
580+
actual_out_type,
581+
input_scale,
582+
weight_scale,
583+
output_scale,
584+
)
585+
550586
tosa_fb.addOperator(
551587
TosaOp.Op().TRANSPOSE,
552588
[conv2d_res.name],
@@ -802,7 +838,7 @@ def preprocess( # noqa: C901
802838
p_data = edge_program.state_dict[parameter_name]
803839

804840
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
805-
weight_values = p_data.detach().numpy()
841+
parameter_values = p_data.detach().numpy()
806842

807843
# Check if they're for quantized nodes
808844
consumer_node = list(node.users)[0]
@@ -811,14 +847,14 @@ def preprocess( # noqa: C901
811847
consumer_node
812848
)
813849

814-
weight_values_quantized = (
815-
(weight_values / weight_node_scale.number)
850+
parameter_values_quantized = (
851+
(parameter_values / weight_node_scale.number)
816852
+ weight_node_zp.number
817853
).astype(np.int8)
818854
tosa_fb.addConst(
819855
inputs[0].shape,
820856
ts.DType.INT8,
821-
weight_values_quantized,
857+
parameter_values_quantized,
822858
name=out,
823859
)
824860
elif (
@@ -837,30 +873,55 @@ def preprocess( # noqa: C901
837873
weight_node
838874
)
839875

840-
weight_values_quantized = (
841-
weight_values / (input_node_scale * weight_node_scale)
876+
parameter_values_quantized = (
877+
parameter_values / (input_node_scale * weight_node_scale)
842878
).astype(np.int32)
843879

844880
tosa_fb.addConst(
845881
inputs[0].shape,
846882
ts.DType.INT32,
847-
weight_values_quantized,
883+
parameter_values_quantized,
884+
name=out,
885+
)
886+
elif (
887+
consumer_node.target == exir_ops.edge.aten.convolution.default
888+
and list(consumer_node.users)[0].target == tosa_quant_utils.q_op
889+
):
890+
(
891+
input_node,
892+
weight_node,
893+
bias_node,
894+
) = consumer_node.all_input_nodes
895+
896+
input_node_scale, _ = getQuantNodeArgs(input_node)
897+
weight_node_scale, _ = getQuantNodeArgs(weight_node)
898+
899+
bias_scales = input_node_scale * weight_node_scale
900+
parameter_values_quantized = (
901+
parameter_values / bias_scales
902+
).astype(np.int32)
903+
904+
tosa_fb.addConst(
905+
inputs[0].shape,
906+
ts.DType.INT32,
907+
parameter_values_quantized,
848908
name=out,
849909
)
850910
else:
851911
tosa_fb.addConst(
852-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
912+
inputs[0].shape, inputs[0].dtype, parameter_values, name=out
853913
)
914+
854915
elif out in edge_program.graph_signature.inputs_to_buffers:
855916
parameter_name = edge_program.graph_signature.inputs_to_buffers[
856917
node.name
857918
]
858919
p_data = edge_program.state_dict[parameter_name]
859920

860921
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
861-
weight_values = p_data.detach().numpy()
922+
parameter_values = p_data.detach().numpy()
862923
tosa_fb.addConst(
863-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
924+
inputs[0].shape, inputs[0].dtype, parameter_values, name=out
864925
)
865926
else:
866927
tensor = ts.TosaSerializerTensor(

0 commit comments

Comments
 (0)