Skip to content

Commit 06e44f5

Browse files
committed
Add initial lowering of aten.convolution to tosa.conv2d support
1 parent 4ff9736 commit 06e44f5

File tree

4 files changed

+288
-50
lines changed

4 files changed

+288
-50
lines changed

backends/arm/arm_backend.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -553,23 +553,39 @@ def preprocess( # noqa: C901
553553
elif exir_ops.edge.aten.convolution.default == node.target:
554554
input, weight, bias, stride, pad, dilation, _, _, group = inputs
555555

556+
# Currently only int8 is supported in quantized types.
557+
actual_out_type = ts.DType.INT8 if is_quant_node else outp.dtype
558+
556559
## Transpose input tensor to NHWC_Order for TOSA
557560
NHWC_Order = [0, 2, 3, 1]
558561
input_transposed = transpose_helper(
559-
tosa_fb, input, NHWC_Order, outp.dtype
562+
tosa_fb, input, NHWC_Order, actual_out_type
560563
)
561564

562-
## CONV2DOp
565+
# Get the attributes of convolution.
563566
attr = ts.TosaSerializerAttribute()
564-
# PAD
565567
pad_attr = [val for val in pad.special for _ in (0, 1)]
566-
# Stride
567568
stride_attr = stride.special
568-
# Dilation
569569
dilation_attr = dilation.special
570570
attr.ConvAttribute(pad_attr, stride_attr, dilation_attr, 0, 0)
571571

572+
# Non-bias case.
573+
if len(node.all_input_nodes) == 2:
574+
# Create a zero bias tensor if not presented
575+
out_channels = weight.shape[0]
576+
bias_name = "bias" + node.name.split("default", 1)[1]
577+
bias = tosa_fb.addConst(
578+
[out_channels],
579+
ts.DType.INT32 if is_quant_node else outp.dtype,
580+
[0] * out_channels,
581+
name=bias_name,
582+
)
583+
572584
if group.number > 1:
585+
assert (
586+
is_quant_node is False
587+
), "quantized depthwise convolution is not supported yet in BI mode"
588+
573589
# Transpose weight to [KH, KW, C, M]
574590
weight_HWCM_Order = [2, 3, 0, 1]
575591
weight_transposed = transpose_helper(
@@ -600,14 +616,17 @@ def preprocess( # noqa: C901
600616
# Transpose weight to [OC, H, W, IC]
601617
weight_CHWC_Order = [0, 2, 3, 1]
602618
weight_transposed = transpose_helper(
603-
tosa_fb, weight, weight_CHWC_Order, outp.dtype
619+
tosa_fb, weight, weight_CHWC_Order, actual_out_type
604620
)
605621

606622
## TOSA output shape is [NHWO]
607623
NHWO_Order = [0, 2, 3, 1]
608624
out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order]
625+
626+
# The output type is int32 when input type is int8.
609627
conv2d_res = tosa_fb.addIntermediate(
610-
out_shape_TOSA_CONV2D, outp.dtype
628+
out_shape_TOSA_CONV2D,
629+
ts.DType.INT32 if is_quant_node else outp.dtype,
611630
)
612631
tosa_fb.addOperator(
613632
TosaOp.Op().CONV2D,
@@ -624,6 +643,24 @@ def preprocess( # noqa: C901
624643
NOHW_Order = [0, 3, 1, 2]
625644
attr_output_transpose = ts.TosaSerializerAttribute()
626645
attr_output_transpose.TransposeAttribute(NOHW_Order)
646+
647+
# For quantized convolution, rescale the output value back to the same
648+
# integer value domain of the next op. Otherwise return float32 output.
649+
if is_quant_node:
650+
# Get scale_factor from input, weight, and output.
651+
_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
652+
_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
653+
_, output_scale, _, _, _, _ = getNodeArgs(list(node.users)[0])
654+
655+
conv2d_res = tosa_quant_utils.buildRescaleOpConvOutput(
656+
tosa_fb,
657+
conv2d_res,
658+
actual_out_type,
659+
input_scale,
660+
weight_scale,
661+
output_scale,
662+
)
663+
627664
tosa_fb.addOperator(
628665
TosaOp.Op().TRANSPOSE,
629666
[conv2d_res.name],
@@ -879,7 +916,7 @@ def preprocess( # noqa: C901
879916
p_data = edge_program.state_dict[parameter_name]
880917

881918
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
882-
weight_values = p_data.detach().numpy()
919+
parameter_values = p_data.detach().numpy()
883920

884921
# Check if they're for quantized nodes
885922
consumer_node = list(node.users)[0]
@@ -888,14 +925,14 @@ def preprocess( # noqa: C901
888925
consumer_node
889926
)
890927

891-
weight_values_quantized = (
892-
(weight_values / weight_node_scale.number)
928+
parameter_values_quantized = (
929+
(parameter_values / weight_node_scale.number)
893930
+ weight_node_zp.number
894931
).astype(np.int8)
895932
tosa_fb.addConst(
896933
inputs[0].shape,
897934
ts.DType.INT8,
898-
weight_values_quantized,
935+
parameter_values_quantized,
899936
name=out,
900937
)
901938
elif (
@@ -914,30 +951,55 @@ def preprocess( # noqa: C901
914951
weight_node
915952
)
916953

917-
weight_values_quantized = (
918-
weight_values / (input_node_scale * weight_node_scale)
954+
parameter_values_quantized = (
955+
parameter_values / (input_node_scale * weight_node_scale)
956+
).astype(np.int32)
957+
958+
tosa_fb.addConst(
959+
inputs[0].shape,
960+
ts.DType.INT32,
961+
parameter_values_quantized,
962+
name=out,
963+
)
964+
elif (
965+
consumer_node.target == exir_ops.edge.aten.convolution.default
966+
and list(consumer_node.users)[0].target == tosa_quant_utils.q_op
967+
):
968+
(
969+
input_node,
970+
weight_node,
971+
bias_node,
972+
) = consumer_node.all_input_nodes
973+
974+
input_node_scale, _ = getQuantNodeArgs(input_node)
975+
weight_node_scale, _ = getQuantNodeArgs(weight_node)
976+
977+
bias_scales = input_node_scale * weight_node_scale
978+
parameter_values_quantized = (
979+
parameter_values / bias_scales
919980
).astype(np.int32)
920981

921982
tosa_fb.addConst(
922983
inputs[0].shape,
923984
ts.DType.INT32,
924-
weight_values_quantized,
985+
parameter_values_quantized,
925986
name=out,
926987
)
927988
else:
928989
tosa_fb.addConst(
929-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
990+
inputs[0].shape, inputs[0].dtype, parameter_values, name=out
930991
)
992+
931993
elif out in edge_program.graph_signature.inputs_to_buffers:
932994
parameter_name = edge_program.graph_signature.inputs_to_buffers[
933995
node.name
934996
]
935997
p_data = edge_program.state_dict[parameter_name]
936998

937999
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
938-
weight_values = p_data.detach().numpy()
1000+
buffer_values = p_data.detach().numpy()
9391001
tosa_fb.addConst(
940-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
1002+
inputs[0].shape, inputs[0].dtype, buffer_values, name=out
9411003
)
9421004
else:
9431005
tensor = ts.TosaSerializerTensor(

0 commit comments

Comments
 (0)