remove transpose addmm weights hack #358
Conversation
This pull request was exported from Phabricator. Differential Revision: D49129704
This pull request has been merged in 2b7eb62.

Summary: Pull Request resolved: pytorch/executorch#358

Background
A common pattern when encountering addmm is that the weights are permuted before being given to addmm. This is because, for torch.nn.Linear, the input shape and weight shape are generally given as:
```
input: (*, in_features)
weight: (out_features, in_features)
```
while the input shape and weight shape of addmm are the following:
```
input1 (input): (*, in_features)
input2 (weight): (in_features, out_features)
```
so when decomposing nn.Linear to addmm, the weights go through a permute node to comply with addmm's shapes.
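For reference, the equivalence behind that permute can be checked directly in eager PyTorch. This is a standalone illustration, not code from this PR:
```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 3)   # (*, in_features)
w = torch.randn(5, 3)   # nn.Linear weight: (out_features, in_features)
b = torch.randn(5)

# nn.Linear / F.linear computes x @ w.T + b directly...
linear_out = F.linear(x, w, b)

# ...while addmm expects the weight as (in_features, out_features),
# hence the permute node produced by the decomposition.
addmm_out = torch.addmm(b, x, w.permute(1, 0))

assert torch.allclose(linear_out, addmm_out)
```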
XNNPACK Status
XNNPACK can handle both the transposed and the normal weight shape; however, it requires a flag indicating whether the weights are transposed. An easy optimization is therefore to skip the permute node and set the flag instead.
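As a toy illustration of the "either layout plus a flag" idea (XNNPACK's actual interface is a C API whose fully connected operators take a transpose-weights flag, XNN_FLAG_TRANSPOSE_WEIGHTS, at operator-creation time; the Python helper below is illustrative only, not the XNNPACK or ExecuTorch API):
```python
import torch

def toy_fully_connected(x, w, bias, weights_out_by_in: bool):
    # Toy stand-in for a backend that accepts either weight layout plus a
    # layout flag; NOT the XNNPACK API, just an illustration of the idea.
    if weights_out_by_in:             # (out_features, in_features): nn.Linear layout
        return torch.addmm(bias, x, w.t())
    return torch.addmm(bias, x, w)    # (in_features, out_features): addmm layout

x = torch.randn(4, 3)
w = torch.randn(5, 3)                 # nn.Linear layout
bias = torch.randn(5)

assert torch.allclose(
    toy_fully_connected(x, w, bias, weights_out_by_in=True),
    toy_fully_connected(x, w.t().contiguous(), bias, weights_out_by_in=False),
)
```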
Change and Motivation
Currently, we have hardcoded some of this optimization logic directly into serialization. I believe that serialization should not be aware of these optimizations, which is why I am removing this logic from within serialization. Instead, this logic should be performed entirely by the addmm --> linear pass, which recomposes permute + addmm into a single linear. We should no longer rely on serialization to perform this logic (right now it is erroneous and causing a bug).
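For intuition, below is a minimal torch.fx-style sketch of the kind of recomposition such a pass performs. The op targets, helper names, and matching logic are illustrative assumptions, not the actual ExecuTorch pass:
```python
import torch
from torch import fx

def _is_transpose_2d(node) -> bool:
    # permute(w, [1, 0]) and t(w) both swap the two dims of a 2-D weight.
    if not isinstance(node, fx.Node) or node.op != "call_function":
        return False
    if node.target == torch.ops.aten.t.default:
        return True
    return (
        node.target == torch.ops.aten.permute.default
        and list(node.args[1]) == [1, 0]
    )

def recompose_addmm_to_linear(gm: fx.GraphModule) -> fx.GraphModule:
    # Rewrite addmm(bias, x, permute(w, [1, 0])) into linear(x, w, bias),
    # so the weight stays in nn.Linear's (out_features, in_features) layout.
    # Assumes default alpha/beta, as produced by the nn.Linear decomposition.
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target != torch.ops.aten.addmm.default:
            continue
        bias, inp, weight = node.args[:3]
        if not _is_transpose_2d(weight):
            continue
        original_weight = weight.args[0]   # still (out_features, in_features)
        with gm.graph.inserting_before(node):
            linear = gm.graph.call_function(
                torch.ops.aten.linear.default,
                args=(inp, original_weight, bias),
            )
        node.replace_all_uses_with(linear)
    # The addmm and its now-unused permute are dead and get cleaned up here.
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
```
With the recomposition done in a graph pass like this, serialization can stay a straightforward mapping from graph nodes to the serialized format instead of second-guessing weight layouts.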
Reviewed By: kirklandsign
Differential Revision: D49129704