
Commit 2c68280

mcr229 authored and facebook-github-bot committed
remove transpose addmm weights hack
Summary:

### Background

A common pattern we see when encountering addmm is that the weights are permuted before being given to addmm. This is because, for torch.nn.Linear, the input and weight shapes are:

```
input:  (*, in_features)
weight: (out_features, in_features)
```

while the input and weight shapes of addmm are:

```
input1 (input):  (*, in_features)
input2 (weight): (in_features, out_features)
```

so when decomposing nn.Linear to addmm, the weights go through a permute node to comply with addmm's shapes.

### XNNPACK Status

XNNPACK can handle both the transposed and normal weight shapes, but it requires a flag indicating whether or not the weights are transposed. So an easy optimization is to skip the permute node and set the flag instead.

### Change and Motivation

Currently, some of this optimization logic is hardcoded directly into serialization. Serialization should not be aware of these optimizations, which is why this diff removes that logic from serialization. Instead, this logic should be performed entirely by the addmm --> linear pass, which recomposes permute + addmm into a single linear. We should no longer rely on serialization to perform this optimization (right now it is erroneous and causing a bug).

Reviewed By: kirklandsign

Differential Revision: D49129704

fbshipit-source-id: 57a3f5770ce3ece030ae0dd311f0f97a9da1c228
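The shape mismatch described above can be checked directly in PyTorch. A minimal sketch (illustrative shapes chosen here, not taken from the diff): nn.Linear stores its weight as `(out_features, in_features)`, while `torch.addmm` expects its second matrix as `(in_features, out_features)`, so the decomposition must transpose the weight.

```python
import torch

# Hypothetical illustrative shapes.
in_features, out_features, batch = 3, 4, 2

linear = torch.nn.Linear(in_features, out_features)
x = torch.randn(batch, in_features)

# nn.Linear stores weight as (out_features, in_features).
assert linear.weight.shape == (out_features, in_features)

# addmm computes bias + x @ mat2, where mat2 must be
# (in_features, out_features), so decomposing nn.Linear
# inserts a transpose/permute on the weight.
y = torch.addmm(linear.bias, x, linear.weight.t())

assert torch.allclose(linear(x), y, atol=1e-5)
```

This transpose is exactly the permute node that the addmm --> linear recomposition pass (or XNNPACK's transpose-weights flag) lets the backend skip.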
1 parent ed1f0fd commit 2c68280

File tree

2 files changed: +1 −16 lines changed


backends/xnnpack/operators/node_visitor.py

Lines changed: 0 additions & 6 deletions

```diff
@@ -289,12 +289,6 @@ def define_tensor(
 
         # convert tensor shape must reflect memory format, default is contiguous, so
         # only permute shape if we are converting the tensor to nhwc format
-        if tensor.target in (
-            exir_ops.edge.aten.permute_copy.default,
-            exir_ops.edge.aten.t_copy.default,
-        ):
-            # We ignore transpose nodes and reverse the dims to before it
-            dims = dims[::-1]
         if swap_nc_for_depthwise_weights:
             dims = [dims[1], dims[0]] + dims[2:]
         if convert_to_nhwc:
```

backends/xnnpack/operators/op_addmm.py

Lines changed: 1 addition & 10 deletions

```diff
@@ -22,7 +22,6 @@
 from executorch.backends.xnnpack.utils.xnnpack_constants import (
     XNN_FLAG_TRANSPOSE_WEIGHTS,
 )
-from executorch.exir.dialects._ops import ops as exir_ops
 
 
 @register_node_visitor
@@ -56,15 +55,7 @@ def define_node(
         # output
         output_id = vals_to_ids[node]
 
-        flag = (
-            0
-            if get_input_node(node, 2).target
-            in (
-                exir_ops.edge.aten.permute_copy.default,
-                exir_ops.edge.aten.t_copy.default,
-            )
-            else XNN_FLAG_TRANSPOSE_WEIGHTS
-        )
+        flag = XNN_FLAG_TRANSPOSE_WEIGHTS
 
         ser_node = XNode(
             xnode_union=XNNFullyConnected(
```
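After this change, serialized weights always arrive in nn.Linear's `(out_features, in_features)` layout, so `XNN_FLAG_TRANSPOSE_WEIGHTS` is always set. What such a flag means for a fully-connected op can be sketched in pure Python (a hypothetical helper, not XNNPACK's actual API):

```python
def fully_connected(x, w, transpose_weights):
    """Compute x @ w, transposing w first when the flag is set.

    Hypothetical sketch: with transpose_weights=True, w may be
    passed in (out_features, in_features) layout, matching how
    nn.Linear stores its weight.
    """
    if transpose_weights:
        # (out_features, in_features) -> (in_features, out_features)
        w = [list(col) for col in zip(*w)]
    rows, inner, cols = len(x), len(w), len(w[0])
    return [
        [sum(x[i][k] * w[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(rows)
    ]

x = [[1.0, 2.0, 3.0]]           # (1, in_features=3)
w_out_in = [[1.0, 0.0, 0.0],    # (out_features=2, in_features=3),
            [0.0, 1.0, 0.0]]    # i.e. nn.Linear's layout

# With the flag set, the (out, in) weight is usable directly,
# with no permute node needed in the graph.
print(fully_connected(x, w_out_in, transpose_weights=True))  # -> [[1.0, 2.0]]
```

This is why the removed per-node check was unnecessary: once the addmm --> linear pass has consumed the permute node, the backend can unconditionally tell XNNPACK the weights are in transposed (out, in) layout.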
