Summary:
Pull Request resolved: #358
### Background
A common pattern we see when encountering addmm is that the weights are permuted before being given to addmm. This is because, for torch.nn.Linear, the input shape and weight shape are generally given as:
```
input: (*, in_features)
weight: (out_features, in_features)
```
while the input shape and weight shape of addmm are the following:
```
mat1 (input): (*, in_features)
mat2 (weight): (in_features, out_features)
```
So when decomposing nn.Linear to addmm, the weights go through a permute node to comply with addmm's expected shapes, as sketched below.
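To make the shape mismatch concrete, here is a minimal PyTorch sketch (not the actual decomposition code) of how nn.Linear ends up as permute + addmm:
```
import torch

x = torch.randn(4, 8)             # (*, in_features)
linear = torch.nn.Linear(8, 16)   # weight: (16, 8) = (out_features, in_features)

# Decomposed form: permute the weight to (in_features, out_features)
# so it matches the layout addmm expects for its second matrix.
y = torch.addmm(linear.bias, x, linear.weight.permute(1, 0))

assert torch.allclose(y, linear(x), atol=1e-6)
```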
### XNNPACK Status
XNNPACK can handle both the transposed and normal weight shapes; however, it requires a flag indicating whether or not the weights are transposed. An easy optimization is therefore to skip the permute node and set the flag instead.
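As a conceptual illustration (in plain PyTorch, not XNNPACK's C API) of what such a flag buys: the same fully-connected op can consume either weight layout, as long as the flag says which one it was given:
```
import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)   # nn.Linear layout: (out_features, in_features)
b = torch.randn(16)

def fully_connected(x, w, b, transposed: bool):
    # transposed=True: the weight arrives in the (out, in) nn.Linear
    # layout and is transposed internally; transposed=False: the weight
    # is already in the (in, out) layout a plain matmul expects.
    return x @ (w.t() if transposed else w) + b

# Both layouts give the same result, so no permute node is needed.
assert torch.allclose(
    fully_connected(x, w, b, transposed=True),
    fully_connected(x, w.t(), b, transposed=False),
)
```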
### Change and Motivation
Currently, we have hardcoded some of this optimization logic directly into serialization. I believe serialization should not be aware of these optimizations, which is why I am removing this logic from serialization. Instead, this optimization should be performed entirely by the addmm --> linear pass, which recomposes permute + addmm into a single linear. We should no longer rely on serialization to perform it (right now it is erroneous and causing a bug).
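A hypothetical sketch of the recomposition idea (names and structure are illustrative, not the actual ExecuTorch pass): find addmm nodes whose second matrix comes from a (1, 0) permute, and rewrite them as a single linear on the original (out_features, in_features) weight:
```
import torch
import torch.fx

def recompose_addmm_to_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.addmm:
            bias, x, mat2 = node.args
            # Only recompose when mat2 is a (1, 0) permute of the weight.
            if (
                isinstance(mat2, torch.fx.Node)
                and mat2.op == "call_function"
                and mat2.target == torch.permute
                and mat2.args[1] == (1, 0)
            ):
                weight = mat2.args[0]
                with gm.graph.inserting_after(node):
                    linear = gm.graph.call_function(
                        torch.nn.functional.linear, (x, weight, bias)
                    )
                node.replace_all_uses_with(linear)
                gm.graph.erase_node(node)
                if not mat2.users:
                    gm.graph.erase_node(mat2)
    gm.graph.lint()
    gm.recompile()
    return gm

# Usage: trace a decomposed linear, then recompose it.
def decomposed(x, w, b):
    return torch.addmm(b, x, torch.permute(w, (1, 0)))

gm = recompose_addmm_to_linear(torch.fx.symbolic_trace(decomposed))
x, w, b = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16)
assert torch.allclose(gm(x, w, b), torch.nn.functional.linear(x, w, b))
```
Keeping this rewrite in a graph pass means serialization only ever sees a plain linear and never needs to reason about permuted weights.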
Reviewed By: kirklandsign
Differential Revision: D49129704
fbshipit-source-id: 1134c33f76eb27ac05a90b29c6dc057c8c647b58