
Add direct copy fast path for portable copy op #10487


Merged: 1 commit merged into pytorch:main on Apr 27, 2025

Conversation

GregoryComer (Member) commented on Apr 25, 2025:

Summary:
This PR adds a direct memcpy fast path to the portable copy and copy_ ops, which speeds up the copy significantly when no broadcasting is needed. Because the copy op always checks that dim order and dtype match, the fast path is sound in every case where the shapes match exactly (i.e., no broadcasting).
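
To make the mechanism concrete, here is a minimal, self-contained C++ sketch of the fast-path idea. It is illustrative only and not the actual ExecuTorch portable kernel: `FakeTensor` and the `copy_` helper below are hypothetical stand-ins, and the real op additionally validates dtype and dim order before choosing a path.

```
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

// Stand-in for a dense, contiguous float tensor (hypothetical, for illustration only).
struct FakeTensor {
  std::vector<long> sizes;  // shape
  std::vector<float> data;  // contiguous storage
  std::size_t nbytes() const { return data.size() * sizeof(float); }
};

// Hypothetical copy helper. Assumes dtype and dim order already match,
// mirroring the checks the portable copy op performs before picking a path.
void copy_(const FakeTensor& src, FakeTensor& out) {
  if (src.sizes == out.sizes) {
    // Fast path: identical shapes mean no broadcasting, so the whole
    // contiguous buffer can be copied with a single memcpy.
    std::memcpy(out.data.data(), src.data.data(), src.nbytes());
    return;
  }
  // Slow path: a broadcast-aware element-wise copy would go here (omitted).
}

int main() {
  FakeTensor a{{2, 3}, {1, 2, 3, 4, 5, 6}};
  FakeTensor b{{2, 3}, std::vector<float>(6, 0.0f)};
  copy_(a, b);
  std::printf("b[5] = %g\n", b.data[5]);  // prints 6
  return 0;
}
```

The point of the fast path is that when shapes match exactly, the whole buffer is copied in one shot instead of walking a broadcast-aware element-wise loop, which is where the speedup reported below comes from.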

This is most noticeable when copying buffer mutations back, such as writing a transformer KV cache back when the cache is managed as a mutable buffer. Prior to this change, an encoder/decoder model spent roughly 25% of its total runtime copying the KV cache back after permuting it; with this change, that copy becomes significantly cheaper.

I benchmarked a simple model on S23 and Pixel 5:

```
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.zeros((2, 10, 1024, 1024)))

    def forward(self, x):
        self.buffer.add_(x)
        return self.buffer


model = TestModel()
inputs = (torch.randn(2, 10, 1024, 1024),)

lowered = to_edge_transform_and_lower(
    torch.export.export(model, inputs),
    partitioner=[XnnpackPartitioner()],
).to_executorch()
```

S23, average of 50 runs, time in copy_:
4.1 ms with the fast path vs. 22.3 ms without

Pixel 5, average of 50 runs, time in copy_:
12.1 ms with the fast path vs. 66.6 ms without

This is approximately a 5.5x speedup of the copy operator (22.3 / 4.1 ≈ 5.4 on the S23, 66.6 / 12.1 ≈ 5.5 on the Pixel 5).

Differential Revision: D73656456

pytorch-bot commented Apr 25, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/10487

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 60bb4ef with merge base 12079fe:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Apr 25, 2025
facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D73656456

GregoryComer added the release notes: ops & kernels label on Apr 25, 2025
facebook-github-bot (Contributor) commented:

@GregoryComer has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

GregoryComer added a commit to GregoryComer/executorch that referenced this pull request Apr 26, 2025
Summary:
The PR adds a direct memcpy fast-path for portable copy and copy_ ops. This speeds up copy significantly in cases where no broadcasting is needed.

This is most noticeable when copying buffer mutations back, such as transformer KV cache when managing the cache as a mutable buffer. Prior to this change, an encoder/decoder model was taking ~25% of the total runtime copying KV cache back after permuting. With this change, the copy becomes significantly cheaper.

I benchmarked a simple model on S23 and Pixel 5:
```
class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.zeros((2, 10, 1024, 1024)))

    def forward(self, x):
        self.buffer.add_(x)
        return self.buffer


model = TestModel()
inputs = (torch.randn(2, 10, 1024, 1024),)

lowered = to_edge_transform_and_lower(
    torch.export.export(model, inputs),
    partitioner=[XnnpackPartitioner()],
).to_executorch()
```

S23, average of 50 runs, time in copy_:
4.1ms vs 22.3ms

Pixel 5, average of 50 runs, time in copy_:
12.1ms vs 66.6ms

This is approximately a 5.5x speedup of the copy operator.


Reviewed By: swolchok

Differential Revision: D73656456

Pulled By: GregoryComer

GregoryComer (Member, Author) commented:

Overriding the lint failure to land; the failure is due to a broken trunk.

facebook-github-bot merged commit 9ea9313 into pytorch:main on Apr 27, 2025
84 of 86 checks passed
Labels: CLA Signed, fb-exported, release notes: ops & kernels