Skip to content

Commit 99f912a

Browse files
Jerry-Ge authored and facebook-github-bot committed
Replace Linear lowering of using Matmul with Conv2d (#1336)
Summary:
- The existing Vela compiler doesn't support Matmul, and TOSA.Fully_Connected will be deprecated.
- Also add support for linear layers with rank > 2, plus tests.

Pull Request resolved: #1336
Reviewed By: cccclai
Differential Revision: D51922063
Pulled By: digantdesai
fbshipit-source-id: a8ad11f170911c1543fe88fd2ef1a1356ac859c3
1 parent cb76c51 commit 99f912a

File tree

10 files changed

+182
-75
lines changed

10 files changed

+182
-75
lines changed

backends/arm/arm_backend.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from executorch.backends.arm.operators.op_placeholder import process_placeholder
2020
from executorch.backends.arm.tosa_mapping import TosaArg
2121
from executorch.backends.arm.tosa_quant_utils import isQuantNode
22-
from executorch.backends.arm.tosa_utils import dbg_fail, dbg_tosa_dump
22+
from executorch.backends.arm.tosa_utils import (
23+
dbg_fail,
24+
dbg_tosa_dump,
25+
is_permute_node_before_addmm,
26+
)
2327
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2428
from executorch.exir.backend.compile_spec_schema import CompileSpec
2529
from torch._export.exported_program import ExportedProgram
@@ -74,7 +78,9 @@ def preprocess( # noqa: C901
7478
# Add output to TOSA graph
7579
tosa_graph.currRegion.currBasicBlock.addTensor(
7680
output.name,
77-
output.shape,
81+
inputs[0].shape
82+
if is_permute_node_before_addmm(node)
83+
else output.shape,
7884
ts.DType.INT8 if is_quant_node else output.dtype,
7985
)
8086

backends/arm/operators/op_addmm.py

Lines changed: 79 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
computeMultiplierAndShift,
1717
getQuantNodeArgs,
1818
)
19-
from executorch.backends.arm.tosa_utils import promote_shape
19+
20+
from executorch.backends.arm.tosa_utils import buildReshape
21+
from executorch.exir.dialects._ops import ops as exir_ops
2022
from serializer.tosa_serializer import TosaOp
2123

2224

@@ -37,69 +39,89 @@ def define_node(
3739
) -> None:
3840
bias, input, weight = inputs
3941

40-
output_dtype = ts.DType.INT8 if is_quant_node else output.dtype
42+
N = input.shape[0]
43+
input_channels = input.shape[1]
44+
output_channels = weight.shape[1]
4145

42-
# Reshape input, weight, bias tensors
43-
input_reshape_res = promote_shape(
44-
tosa_graph, input, (1,) + input.shape, output_dtype
45-
)
46-
weight_reshape_res = promote_shape(
47-
tosa_graph, weight, (1,) + weight.shape, output_dtype
46+
input_new_shape = (N, 1, 1, input_channels)
47+
input_reshaped = tosa_graph.addIntermediate(
48+
input_new_shape,
49+
ts.DType.INT8 if is_quant_node else input.dtype,
4850
)
4951

50-
bias_dtype = ts.DType.INT32 if is_quant_node else output.dtype
51-
bias_reshape_res = promote_shape(
52-
tosa_graph,
53-
bias,
54-
(
55-
1,
56-
1,
57-
)
58-
+ bias.shape,
59-
bias_dtype,
60-
)
52+
buildReshape(tosa_graph, input.name, input_new_shape, input_reshaped.name)
6153

62-
# Add dummy batch 1 to mm_shape
63-
mm_shape = (1, input.shape[0], weight.shape[1])
64-
# Define Intermediate tensor for MatMul res
65-
mm_res = tosa_graph.addIntermediate(
66-
mm_shape, ts.DType.INT32 if is_quant_node else output_dtype
54+
weight_new_shape = (output_channels, 1, 1, input_channels)
55+
weight_reshaped = tosa_graph.addIntermediate(
56+
weight_new_shape,
57+
ts.DType.INT8 if is_quant_node else weight.dtype,
6758
)
6859

69-
# Add MatMulOp
70-
attr_matmul = ts.TosaSerializerAttribute()
71-
a_zp, b_zp = (-128, 0) if is_quant_node else (0, 0)
72-
attr_matmul.MatMulAttribute(a_zp, b_zp)
73-
tosa_graph.addOperator(
74-
TosaOp.Op().MATMUL,
75-
[input_reshape_res.name, weight_reshape_res.name],
76-
[mm_res.name],
77-
attr_matmul,
60+
buildReshape(tosa_graph, weight.name, weight_new_shape, weight_reshaped.name)
61+
62+
# Get the attributes of convolution.
63+
attr = ts.TosaSerializerAttribute()
64+
pad_attr = [0, 0, 0, 0]
65+
stride_attr = [1, 1]
66+
dilation_attr = [1, 1]
67+
68+
input_zp = -128 if is_quant_node else 0
69+
attr.ConvAttribute(
70+
pad=pad_attr,
71+
stride=stride_attr,
72+
dilation=dilation_attr,
73+
input_zp=input_zp,
74+
weight_zp=0,
75+
local_bound=False,
7876
)
7977

80-
# Add AddOp
81-
add_res = tosa_graph.addIntermediate(
82-
mm_shape, ts.DType.INT32 if is_quant_node else output_dtype
78+
conv2d_output_shape = (N, 1, 1, output_channels)
79+
conv2d_res = tosa_graph.addIntermediate(
80+
conv2d_output_shape,
81+
ts.DType.INT32 if is_quant_node else output.dtype,
8382
)
8483

84+
# U55 doesn't support tosa.matmul and tosa.fully_connected will be deprecated
85+
# TOSA Conv2d input is NHWC and weights are in OHWI
8586
tosa_graph.addOperator(
86-
TosaOp.Op().ADD,
87-
[bias_reshape_res.name, mm_res.name],
88-
[add_res.name],
89-
None,
87+
TosaOp.Op().CONV2D,
88+
[
89+
input_reshaped.name,
90+
weight_reshaped.name,
91+
bias.name,
92+
],
93+
[conv2d_res.name],
94+
attr,
9095
)
9196

97+
result_shape = (N, output_channels)
98+
9299
if is_quant_node:
93100
# Read inputs' parent nodes
94-
#
95101
_, input_node, weight_node = node.all_input_nodes
96-
input_scale, _ = getQuantNodeArgs(input_node)
102+
103+
# rank > 2 linear layer
104+
if input_node.target == exir_ops.edge.aten.view_copy.default:
105+
quant_node = input_node.all_input_nodes[0]
106+
input_scale, _ = getQuantNodeArgs(quant_node)
107+
consumer_node = list(node.users)[0]
108+
consumer_consumer_node = list(consumer_node.users)[0]
109+
(
110+
consumer_node_scale,
111+
consumer_node_node_zp,
112+
) = getQuantNodeArgs(consumer_consumer_node)
113+
114+
else:
115+
input_scale, _ = getQuantNodeArgs(input_node)
116+
consumer_node = list(node.users)[0]
117+
(
118+
consumer_node_scale,
119+
consumer_node_node_zp,
120+
) = getQuantNodeArgs(consumer_node)
121+
97122
weight_node_q_node = weight_node.all_input_nodes[0]
98123
weight_scale, _ = getQuantNodeArgs(weight_node_q_node)
99124

100-
consumer_node = list(node.users)[0]
101-
consumer_node_scale, consumer_node_node_zp = getQuantNodeArgs(consumer_node)
102-
103125
output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale
104126
(
105127
multiplier_output,
@@ -115,20 +137,20 @@ def define_node(
115137
scale32=True,
116138
double_round=True,
117139
per_channel=False,
140+
input_unsigned=False,
141+
output_unsigned=False,
118142
)
119-
add_res_int8 = tosa_graph.addIntermediate(mm_shape, ts.DType.INT8)
143+
144+
reshaped_res = tosa_graph.addIntermediate(result_shape, ts.DType.INT32)
145+
buildReshape(tosa_graph, conv2d_res.name, result_shape, reshaped_res.name)
146+
120147
tosa_graph.addOperator(
121148
TosaOp.Op().RESCALE,
122-
[add_res.name],
123-
[add_res_int8.name],
149+
[reshaped_res.name],
150+
[output.name],
124151
attr_rescale_output,
125152
)
126-
# Reshape final result to original shape
127-
attr_out = ts.TosaSerializerAttribute()
128-
attr_out.ReshapeAttribute(output.shape)
129-
tosa_graph.addOperator(
130-
TosaOp.Op().RESHAPE,
131-
[add_res_int8.name if is_quant_node else add_res.name],
132-
[output.name],
133-
attr_out,
134-
)
153+
154+
else:
155+
# non-quantized case
156+
buildReshape(tosa_graph, conv2d_res.name, result_shape, output.name)

backends/arm/operators/op_conv2d.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,14 @@ def define_node(
5252
pad_attr = [val for val in pad.special for _ in (0, 1)]
5353
stride_attr = stride.special
5454
dilation_attr = dilation.special
55-
attr.ConvAttribute(pad_attr, stride_attr, dilation_attr, 0, 0)
55+
attr.ConvAttribute(
56+
pad=pad_attr,
57+
stride=stride_attr,
58+
dilation=dilation_attr,
59+
input_zp=0,
60+
weight_zp=0,
61+
local_bound=False,
62+
)
5663

5764
# Non-bias case.
5865
if len(node.all_input_nodes) == 2:

backends/arm/operators/op_permute.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
register_node_visitor,
1313
)
1414
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_utils import is_permute_node_before_addmm
1516
from serializer.tosa_serializer import TosaOp
1617

1718

@@ -30,6 +31,13 @@ def define_node(
3031
output: TosaArg,
3132
is_quant_node: bool,
3233
) -> None:
34+
if is_permute_node_before_addmm(node):
35+
## Simply add an identityOp
36+
tosa_graph.addOperator(
37+
TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]
38+
)
39+
return
40+
3341
attr = ts.TosaSerializerAttribute()
3442
attr.TransposeAttribute(inputs[1].special)
3543
tosa_graph.addOperator(

backends/arm/operators/op_placeholder.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
isQuantArg,
99
q_op,
1010
)
11-
from executorch.backends.arm.tosa_utils import getNodeArgs
11+
from executorch.backends.arm.tosa_utils import getNodeArgs, is_bias_node_for_addmm
1212
from executorch.exir.dialects._ops import ops as exir_ops
1313
from torch._export.exported_program import ExportedProgram
1414

@@ -42,28 +42,30 @@ def process_placeholder(
4242
parameter_values_quantized,
4343
name=out,
4444
)
45-
elif (
46-
consumer_node.target == exir_ops.edge.aten.addmm.default
47-
and list(consumer_node.users)[0].target == q_op
48-
):
45+
elif is_bias_node_for_addmm(node):
4946
(
5047
_,
5148
input_node,
5249
weight_node_permuted,
5350
) = consumer_node.all_input_nodes
5451
weight_node = weight_node_permuted.all_input_nodes[0]
5552

56-
input_node_scale, _ = getQuantNodeArgs(input_node)
53+
# input_node_scale, _ = getQuantNodeArgs(input_node)
54+
if input_node.target == exir_ops.edge.aten.view_copy.default:
55+
input_node_scale, _ = getQuantNodeArgs(input_node.all_input_nodes[0])
56+
else:
57+
input_node_scale, _ = getQuantNodeArgs(input_node)
58+
5759
weight_node_scale, weight_node_zp = getQuantNodeArgs(weight_node)
5860

59-
parameter_values_quantized = (
61+
bias_values_quantized = (
6062
parameter_values / (input_node_scale * weight_node_scale)
6163
).astype(np.int32)
6264

6365
tosa_graph.addConst(
6466
inputs[0].shape,
6567
ts.DType.INT32,
66-
parameter_values_quantized,
68+
bias_values_quantized,
6769
name=out,
6870
)
6971
elif (

backends/arm/test/test_models.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,29 @@ def forward(self, x, y):
128128
@register_test
129129
class simple_linear(torch.nn.Module):
130130
inputs = {
131-
TosaProfile.BI: (torch.ones(100, 20),),
132-
TosaProfile.MI: (torch.ones(100, 20),),
131+
TosaProfile.BI: (torch.rand(1, 2),),
132+
TosaProfile.MI: (torch.rand(1, 2),),
133133
}
134134

135135
def __init__(self):
136136
super().__init__()
137137
torch.manual_seed(seed)
138+
self.fc = torch.nn.Linear(2, 3)
139+
140+
def forward(self, x):
141+
x = self.fc(x)
142+
return x
143+
144+
@register_test
145+
class simple_linear_rank4(torch.nn.Module):
146+
inputs = {
147+
TosaProfile.BI: (torch.rand(5, 10, 25, 20),),
148+
TosaProfile.MI: (torch.rand(5, 10, 25, 20),),
149+
}
150+
151+
def __init__(self):
152+
super().__init__()
153+
torch.manual_seed(42)
138154
self.fc = torch.nn.Linear(20, 30)
139155

140156
def forward(self, x):
Submodule serialization_lib updated from 9601cbd to 92358fc

backends/arm/tosa_quant_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@
2020
def isQuantNode(node):
2121
consumer_node = list(node.users)[0]
2222
input = node.all_input_nodes[0]
23+
24+
# For Rank > 2 Linear layers, the quant node is after the view_copy
25+
if (
26+
node.target == exir_ops.edge.aten.addmm.default
27+
and list(node.users)[0].target == exir_ops.edge.aten.view_copy.default
28+
):
29+
consumer_consumer_node = list(consumer_node.users)[0]
30+
return True if consumer_consumer_node.target == q_op else False
31+
2332
return (
2433
consumer_node.target == q_op
2534
or node.target in dq_q_ops
@@ -106,6 +115,8 @@ def buildRescale(
106115
scale32=is_scale32,
107116
double_round=is_double_round,
108117
per_channel=False,
118+
input_unsigned=False,
119+
output_unsigned=False,
109120
)
110121

111122
rescale_out = tosa_fb.addIntermediate(output_shape, output_type)
@@ -129,6 +140,8 @@ def buildRescaleToInt32(
129140
scale32=is_scale32,
130141
double_round=is_double_round,
131142
per_channel=False,
143+
input_unsigned=False,
144+
output_unsigned=False,
132145
)
133146
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
134147
tosa_fb.addOperator(
@@ -160,6 +173,8 @@ def buildRescaleFromInt32(
160173
scale32=is_scale32,
161174
double_round=is_double_round,
162175
per_channel=False,
176+
input_unsigned=False,
177+
output_unsigned=False,
163178
)
164179

165180
tosa_fb.addOperator(

0 commit comments

Comments
 (0)