
[llm] Use new API to register custom ops for llama model #2840


Closed
larryliu0820 wants to merge 1 commit

Conversation

larryliu0820 (Contributor)

Summary: Using the following 2 APIs:

  • EXECUTORCH_LIBRARY to replace the need for a YAML file. With this macro we can directly register a custom kernel into the ExecuTorch runtime.
  • WRAP_TO_ATEN allows custom op authors to use the same kernel for ExecuTorch and PyTorch. This can be helpful during debugging. A sketch of both registration paths follows this list.
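
Roughly, the two registration paths look like this. This is a minimal sketch: the op my_op, namespace my_ns, and signatures are hypothetical stand-ins for the real sdpa_with_kv_cache op (whose out tensor sits at argument index 11), not code from this PR.

```cpp
// Assumed headers, matching the ones touched in this diff.
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <torch/library.h>

namespace my_ns {

// ExecuTorch out-variant kernel: runtime context first, out tensor last.
// (Definition omitted; this is a sketch.)
exec_aten::Tensor& my_op_out(
    exec_aten::RuntimeContext& ctx,
    const exec_aten::Tensor& in,
    exec_aten::Tensor& out);

// Context-free wrapper so the same kernel can be handed to PyTorch.
exec_aten::Tensor& my_op_out_no_context(
    const exec_aten::Tensor& in,
    exec_aten::Tensor& out) {
  exec_aten::RuntimeContext context{};
  return my_op_out(context, in, out);
}

} // namespace my_ns

// 1) Register the kernel directly into the ExecuTorch runtime,
//    replacing the YAML-based registration flow.
EXECUTORCH_LIBRARY(my_ns, "my_op.out", my_ns::my_op_out);

// 2) Expose the same kernel to PyTorch for eager-mode debugging.
//    WRAP_TO_ATEN's second argument is the index of the out tensor
//    in the C++ argument list (1 here: arg 0 is in, arg 1 is out).
TORCH_LIBRARY_FRAGMENT(my_ns, m) {
  m.def(
      "my_op.out(Tensor in, *, Tensor(a!) out) -> Tensor(a!)",
      WRAP_TO_ATEN(my_ns::my_op_out_no_context, 1));
}
```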

Test Plan: Rely on the new CI job test_llama with the xnnpack+kv+custom option.



pytorch-bot (bot) commented Apr 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/2840

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit d1612c8 with merge base 081c849:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@larryliu0820 requested a review from kimishpatel April 3, 2024 21:48
@facebook-github-bot added the CLA Signed label Apr 3, 2024
@facebook-github-bot (Contributor)

@larryliu0820 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -72,6 +72,7 @@ target_include_directories(
xnnpack_schema INTERFACE ${_xnnpack_schema__include_dir}
${EXECUTORCH_ROOT}/third-party/flatbuffers/include)

target_compile_options(pthreadpool PUBLIC ${_common_compile_options})
Contributor

Why is this needed? pthreadpool-related stuff has moved to the root CMakeLists.

Contributor Author

My local build failed, complaining about missing -fPIC.

if(ANDROID)
list(APPEND link_libraries log)
endif()

target_compile_options(llama_main PUBLIC ${_common_compile_options}
-DET_USE_THREADPOOL)
-DET_USE_THREADPOOL)
Contributor

typo?

Contributor Author

Reformatted CMakeLists.txt.

@@ -22,6 +21,7 @@
#include <executorch/backends/xnnpack/threadpool/threadpool.h>
#include <executorch/extension/parallel/thread_parallel.h>
#endif
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
Contributor

We should rename this header to something else

// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const c10::optional<double> scale) {
auto output = at::empty_like(q_projected);
WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
Contributor

What's this 11?

Contributor Author

The 11th argument is `out`.

kimishpatel (Contributor) commented Apr 3, 2024

Number of args? I think there is some template magic that allows you to count the number of args, right?

Contributor Author

Yeah, this is telling the template we need to return the 11th object, since it is `out`. See this code: https://github.com/pytorch/executorch/blob/main/extension/aten_util/make_aten_functor_from_et_functor.h#L268-L277
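
As an aside, here is a tiny sketch of the kind of template counting mentioned above. It is illustrative only and not what this PR does (the index is passed explicitly):

```cpp
#include <cstddef>

// Count the parameters of a plain function pointer at compile time.
template <typename F>
struct arg_count;

template <typename R, typename... Args>
struct arg_count<R (*)(Args...)> {
  static constexpr std::size_t value = sizeof...(Args);
};

// If the out tensor were always the last parameter, a call site could
// derive the index instead of hard-coding 11, e.g.
//   WRAP_TO_ATEN(fn, arg_count<decltype(&fn)>::value - 1)
```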

"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor",
&torch::executor::native::sdpa_with_kv_cache_aten);
m.def(
Contributor

Why is this one needed? So that we can generate the out variant?

Contributor Author

We need both `sdpa_with_kv_cache` and `sdpa_with_kv_cache.out` in ATen so that exir is happy. A sketch of what that looks like is below.
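
Roughly, with the same hypothetical my_op from the sketch in the summary (schemas abbreviated; the real ones are the sdpa_with_kv_cache schemas in this diff):

```cpp
// exir looks the op up in the ATen dispatcher, so both the functional
// overload and the .out overload need to be registered there.
TORCH_LIBRARY_FRAGMENT(my_ns, m) {
  // Functional variant, backed by an assumed ATen wrapper my_op_aten
  // (analogous to sdpa_with_kv_cache_aten in this PR).
  m.def("my_op(Tensor in) -> Tensor", &my_ns::my_op_aten);
  // Out variant, reusing the ExecuTorch kernel via WRAP_TO_ATEN.
  m.def(
      "my_op.out(Tensor in, *, Tensor(a!) out) -> Tensor(a!)",
      WRAP_TO_ATEN(my_ns::my_op_out_no_context, 1));
}
```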

@@ -32,7 +32,7 @@ exec_aten::Tensor op_sdpa_with_kv_cache(
exec_aten::optional<double> scale,
exec_aten::Tensor& out) {
exec_aten::RuntimeContext context{};
return torch::executor::llama::sdpa_with_kv_cache_outf(
return torch::executor::native::sdpa_with_kv_cache_out(
Contributor

Why change the namespace?

Contributor Author

The llama namespace was generated by FunctionHeaderWrapper.h. Here I added a header, op_sdpa.h, which uses the same namespace as op_sdpa.cpp.

Comment on lines 19 to 31
# assuming we only hit this in OSS, find the default install path
prefix = os.environ.get("CMAKE_INSTALL_PREFIX", "../../../../cmake-out")
lib_path = os.path.join(prefix, "lib/libcustom_ops_aot_lib.so")
torch.ops.load_library(lib_path)
op = torch.ops.llama.sdpa_with_kv_cache.default
assert op is not None
Contributor

This is a little bit clunky but I don't know how to improve it.

runtime.cxx_library(
name = "sdpa",
name = "custom_ops",
Contributor

Why rename this?

Contributor Author

A lot of places refer to the examples/models/llama2/custom_ops:custom_ops library name. I'm just too lazy to change all of them to sdpa.

kimishpatel (Contributor) left a comment

Left some comments


facebook-github-bot pushed a commit that referenced this pull request Apr 5, 2024
Summary:
Using the following 2 APIs:

* `EXECUTORCH_LIBRARY` to replace the need for a YAML file. With this macro we can directly register a custom kernel into the ExecuTorch runtime.
* `WRAP_TO_ATEN` allows custom op authors to use the same kernel for ExecuTorch and PyTorch. This can be helpful during debugging.


Test Plan: Rely on the new CI job `test_llama` with the `xnnpack+kv+custom` option.

Reviewed By: kimishpatel

Differential Revision: D55713944

Pulled By: larryliu0820
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D55713944


@facebook-github-bot (Contributor)

@larryliu0820 merged this pull request in 020d8be.

larryliu0820 added a commit that referenced this pull request Apr 7, 2024
facebook-github-bot pushed a commit that referenced this pull request Apr 7, 2024
…2912)

Summary:
This reverts commit 020d8be.

Pull Request resolved: #2912

Reviewed By: shoumikhin

Differential Revision: D55852547

Pulled By: larryliu0820

fbshipit-source-id: c8528041c03196239d6daef7e2843ee5cf8a8f3d
Labels: ciflow/trunk, CLA Signed, fb-exported, Merged