remove transpose addmm weights hack #358


Closed · wants to merge 3 commits

Conversation

@mcr229 (Contributor) commented Sep 15, 2023

Summary:

### Background

A common pattern we see when encountering addmm is that the weights are permuted before being passed to addmm. This is because torch.nn.Linear takes its input and weight shapes as:

```
input: (*, in_features)
weight: (out_features, in_features)
```

while addmm takes the following input and weight shapes:

```
input1 (input): (*, in_features)
input2 (weight): (in_features, out_features)
```

So when decomposing nn.Linear to addmm, the weights go through a permute node to comply with addmm's expected shapes.
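As a concrete illustration (a minimal sketch, not code from this PR), the two call conventions agree once the weight is transposed:

```python
import torch

x = torch.randn(2, 3)  # input: (*, in_features)
w = torch.randn(4, 3)  # nn.Linear weight: (out_features, in_features)
b = torch.randn(4)

# nn.Linear computes x @ w.T + b ...
y_linear = torch.nn.functional.linear(x, w, b)

# ... so decomposing it to addmm requires permuting w to
# (in_features, out_features) first.
y_addmm = torch.addmm(b, x, w.t())

assert torch.allclose(y_linear, y_addmm)
```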

### XNNPACK Status

XNNPACK can handle both the transposed and the normal weight layout; however, it requires a flag indicating whether the weights are transposed. An easy optimization, then, is to skip the permute node and set the flag instead.
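Conceptually (a Python emulation of the behavior, not XNNPACK's actual C API; it assumes the default fully-connected weight layout matches nn.Linear's (out_features, in_features)), the flag just selects how the kernel reads the weight:

```python
import torch

def fully_connected(x, w, b, transposed_weights: bool):
    # With the flag set, w is taken to be (in_features, out_features),
    # i.e. already in addmm layout, so no permute node is needed;
    # otherwise w is in the default (out_features, in_features) layout.
    return b + (x @ w if transposed_weights else x @ w.t())
```

Either way, the delegate can consume whichever layout the graph provides and drop the permute.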

### Change and Motivation

Currently, we have hardcoded some of this optimization logic directly into serialization. Serialization should not be aware of these optimizations, which is why this diff removes that logic from serialization. Instead, the optimization should be performed entirely by the addmm --> linear pass, which recomposes permute + addmm into a single linear. We should no longer rely on serialization to perform it (right now it is erroneous and causing a bug).
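For reference, a recomposition pass of this shape might look like the sketch below (hedged: `recompose_addmm_to_linear` is an illustrative name, not the actual ExecuTorch pass, and real graphs may use `aten.t` rather than `aten.permute`):

```python
import torch
import torch.fx as fx

def recompose_addmm_to_linear(gm: fx.GraphModule) -> fx.GraphModule:
    """Rewrite permute + addmm back into a single linear node."""
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target != torch.ops.aten.addmm.default:
            continue
        bias, inp, mat2 = node.args
        # Only match addmm whose second operand permutes the weight
        # from (out_features, in_features) to (in_features, out_features).
        if (
            not isinstance(mat2, fx.Node)
            or mat2.op != "call_function"
            or mat2.target != torch.ops.aten.permute.default
            or list(mat2.args[1]) != [1, 0]
        ):
            continue
        weight = mat2.args[0]
        with gm.graph.inserting_before(node):
            linear = gm.graph.call_function(
                torch.ops.aten.linear.default, (inp, weight, bias)
            )
        node.replace_all_uses_with(linear)
        gm.graph.erase_node(node)
        if not mat2.users:
            gm.graph.erase_node(mat2)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
```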

Reviewed By: kirklandsign

Differential Revision: D49129704

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D49129704

mcr229 and others added 3 commits September 15, 2023 18:43
Differential Revision: D49068312

fbshipit-source-id: 6709b38f3be9da89f51db78740e9533e220a4a04
Differential Revision: D49301134

fbshipit-source-id: ccd56ffbef578989d2b3436adc0fe1e92a0431ce
Pull Request resolved: pytorch/executorch#358

fbshipit-source-id: 9c41a2b4433c62b06e703b98ed8e0442275c3501
@facebook-github-bot (Contributor)

This pull request has been merged in 2b7eb62.
