Skip to content

Add initial lowering of aten.convolution to tosa.conv2d support #615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 79 additions & 17 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,23 +553,39 @@ def preprocess( # noqa: C901
elif exir_ops.edge.aten.convolution.default == node.target:
input, weight, bias, stride, pad, dilation, _, _, group = inputs

# Currently only int8 is supported in quantized types.
actual_out_type = ts.DType.INT8 if is_quant_node else outp.dtype

## Transpose input tensor to NHWC_Order for TOSA
NHWC_Order = [0, 2, 3, 1]
input_transposed = transpose_helper(
tosa_fb, input, NHWC_Order, outp.dtype
tosa_fb, input, NHWC_Order, actual_out_type
)

## CONV2DOp
# Get the attributes of convolution.
attr = ts.TosaSerializerAttribute()
# PAD
pad_attr = [val for val in pad.special for _ in (0, 1)]
# Stride
stride_attr = stride.special
# Dilation
dilation_attr = dilation.special
attr.ConvAttribute(pad_attr, stride_attr, dilation_attr, 0, 0)

# Non-bias case.
if len(node.all_input_nodes) == 2:
# Create a zero bias tensor if not presented
out_channels = weight.shape[0]
bias_name = "bias" + node.name.split("default", 1)[1]
bias = tosa_fb.addConst(
[out_channels],
ts.DType.INT32 if is_quant_node else outp.dtype,
[0] * out_channels,
name=bias_name,
)

if group.number > 1:
assert (
is_quant_node is False
), "quantized depthwise convolution is not supported yet in BI mode"

# Transpose weight to [KH, KW, C, M]
weight_HWCM_Order = [2, 3, 0, 1]
weight_transposed = transpose_helper(
Expand Down Expand Up @@ -600,14 +616,17 @@ def preprocess( # noqa: C901
# Transpose weight to [OC, H, W, IC]
weight_CHWC_Order = [0, 2, 3, 1]
weight_transposed = transpose_helper(
tosa_fb, weight, weight_CHWC_Order, outp.dtype
tosa_fb, weight, weight_CHWC_Order, actual_out_type
)

## TOSA output shape is [NHWO]
NHWO_Order = [0, 2, 3, 1]
out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order]

# The output type is int32 when input type is int8.
conv2d_res = tosa_fb.addIntermediate(
out_shape_TOSA_CONV2D, outp.dtype
out_shape_TOSA_CONV2D,
ts.DType.INT32 if is_quant_node else outp.dtype,
)
tosa_fb.addOperator(
TosaOp.Op().CONV2D,
Expand All @@ -624,6 +643,24 @@ def preprocess( # noqa: C901
NOHW_Order = [0, 3, 1, 2]
attr_output_transpose = ts.TosaSerializerAttribute()
attr_output_transpose.TransposeAttribute(NOHW_Order)

# For quantized convolution, rescale the output value back to the same
# integer value domain of the next op. Otherwise return float32 output.
if is_quant_node:
# Get scale_factor from input, weight, and output.
_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
_, output_scale, _, _, _, _ = getNodeArgs(list(node.users)[0])

conv2d_res = tosa_quant_utils.buildRescaleOpConvOutput(
tosa_fb,
conv2d_res,
actual_out_type,
input_scale,
weight_scale,
output_scale,
)

tosa_fb.addOperator(
TosaOp.Op().TRANSPOSE,
[conv2d_res.name],
Expand Down Expand Up @@ -879,7 +916,7 @@ def preprocess( # noqa: C901
p_data = edge_program.state_dict[parameter_name]

assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
weight_values = p_data.detach().numpy()
parameter_values = p_data.detach().numpy()

# Check if they're for quantized nodes
consumer_node = list(node.users)[0]
Expand All @@ -888,14 +925,14 @@ def preprocess( # noqa: C901
consumer_node
)

weight_values_quantized = (
(weight_values / weight_node_scale.number)
parameter_values_quantized = (
(parameter_values / weight_node_scale.number)
+ weight_node_zp.number
).astype(np.int8)
tosa_fb.addConst(
inputs[0].shape,
ts.DType.INT8,
weight_values_quantized,
parameter_values_quantized,
name=out,
)
elif (
Expand All @@ -914,30 +951,55 @@ def preprocess( # noqa: C901
weight_node
)

weight_values_quantized = (
weight_values / (input_node_scale * weight_node_scale)
parameter_values_quantized = (
parameter_values / (input_node_scale * weight_node_scale)
).astype(np.int32)

tosa_fb.addConst(
inputs[0].shape,
ts.DType.INT32,
parameter_values_quantized,
name=out,
)
elif (
consumer_node.target == exir_ops.edge.aten.convolution.default
and list(consumer_node.users)[0].target == tosa_quant_utils.q_op
):
(
input_node,
weight_node,
bias_node,
) = consumer_node.all_input_nodes

input_node_scale, _ = getQuantNodeArgs(input_node)
weight_node_scale, _ = getQuantNodeArgs(weight_node)

bias_scales = input_node_scale * weight_node_scale
parameter_values_quantized = (
parameter_values / bias_scales
).astype(np.int32)

tosa_fb.addConst(
inputs[0].shape,
ts.DType.INT32,
weight_values_quantized,
parameter_values_quantized,
name=out,
)
else:
tosa_fb.addConst(
inputs[0].shape, inputs[0].dtype, weight_values, name=out
inputs[0].shape, inputs[0].dtype, parameter_values, name=out
)

elif out in edge_program.graph_signature.inputs_to_buffers:
parameter_name = edge_program.graph_signature.inputs_to_buffers[
node.name
]
p_data = edge_program.state_dict[parameter_name]

assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
weight_values = p_data.detach().numpy()
buffer_values = p_data.detach().numpy()
tosa_fb.addConst(
inputs[0].shape, inputs[0].dtype, weight_values, name=out
inputs[0].shape, inputs[0].dtype, buffer_values, name=out
)
else:
tensor = ts.TosaSerializerTensor(
Expand Down
Loading