Skip to content

Initial Implementation of MediaTek Backend for Executorch #3571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 42 commits into from
Aug 14, 2024

Conversation

neuropilot-captain
Copy link
Collaborator

@neuropilot-captain neuropilot-captain commented May 10, 2024

This pull request introduces the initial implementation of the MediaTek backend for Executorch. Below are the key points and instructions related to this contribution:

  • "Ahead of Time" and "Runtime" are included in this implementation.
  • The build process and execution phase rely on libraries provided by the MediaTek NeuroPilot SDK. For the purpose of verification, these libraries will be made available offline.

Build Instructions:

To build the MediaTek backend for Executorch, please adhere to the guidelines provided in the backends/mediatek/scripts/README.md file within this repository.

Ahead-Of-Time Compilation

To compile models for MediaTek devices, please follow the instruction on the ExecuTorch documentation.
For Llama compilation, please refer to examples/mediatek/README.md for the usage of the export script shell_scripts/export_llama.sh

On-device Execution:

For instructions on how to run an example on MediaTek devices, refer to the documentation in examples/mediatek/README.md.

Llama Example:

A sample Llama runner is built together with MediaTek backend.

- Add backends/mediatek & examples/mediatek
- Add runtime: Neuron ExecuTorch Backend
- Add example: mtk_executor_runner
Copy link

pytorch-bot bot commented May 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/3571

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7786412 with merge base ba3448c (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 10, 2024
Download Android NDK and make sure that $ANDROID_NDK is specified as the NDK Path.

3. Download **MediaTek ExercuTorch libraries**
Download related libraries.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are these and where do you get it from? I would like to repro

@@ -0,0 +1,29 @@
# ExecuTorch Neuron Backend examples

## Directory structure
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add pre-requisite here as to which phones support MTK chipset? ALso do add "validated on" to indicate which phone or chipset this is validated on

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestion! We have added the list of supported chips that the examples are validated on. Also, we have refactor this README for better readability.

Please follow the build step in backends/mediatek/scripts.
The mtk_executor_runner will be automatically generated.

## Execute
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you generate the pte file?

@kimishpatel
Copy link
Contributor

I dont see any instructions/code on the ahead of time component

Copy link

pytorch-bot bot commented May 10, 2024

Please seek CI approval before scheduling CIFlow labels

@mergennachin
Copy link
Contributor

Thanks for your contribution and PR.

Will take a look at it deeply soon. I just kicked off OSS CI jobs. One thing, we have a lint rule. Here's how you can invoke:

https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#coding-style

@mergennachin
Copy link
Contributor

The build process and execution phase rely on libraries provided by the MediaTek NeuroPilot SDK. For the purpose of verification, these libraries will be made available offline.

Is there a public link that we can download from MediaTek website? If not, is there a plan to publish it publicly?

Copy link
Contributor

@swolchok swolchok left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just dropping in with a bunch of small C++ suggestions; feel free to apply them or not

ET_CHECK_MSG(
GetModelInstance() == nullptr,
"Model is already initialized before calling LoadModels.");
void* instance = CreateModelInstance(mModelPathMap[id]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/mModelPathMap[id]/modelPath/ , no need to look it up again

}

template <typename IdType>
std::string MultiModelLoader<IdType>::GetModelPath() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this return a const reference to avoid copying?

size_t getRotEmbedLength() const;

private:
char* mMasterLut; // byte flatten array
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use std::unique_ptr<char[]>?

Comment on lines 48 to 50
if (mLutBuffer != nullptr) {
delete mLutBuffer;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. deleting nullptr is fine
  2. this should be delete[]
  3. but why not use std::unique_ptr<uint8_t[]> and avoid the whole thing?

Comment on lines +127 to +128
size_t num_memory_planned_buffers = method_meta->num_memory_planned_buffers();
for (size_t id = 0; id < num_memory_planned_buffers; ++id) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it's probably not performance-critical, but reserving these vectors before filling them wouldn't hurt

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As answered in the conversation below.
The current implementation of mtk_executor_runner is based on examples/portable/executor_runner, and we'll look into executorch/extension/module to determine if we need to modify the code.

Copy link
Contributor

@cccclai cccclai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you pushing the first PR! Looks great :)

Overall I think NeuronBackend and mtk_executor_runner are reasonable. It will be great if we can share the expected flow for users to generate the .pte file and run it.

Regarding the llama runner logic, I can see why we need the multimodel loader and the caching logic, but it's not quite clear to me what logic from .sois needed here. Maybe we can walk through it.

return Error::InvalidState;
};

auto& allocator = GET_NEURON_ALLOCATOR;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can set up temp allocator pointing to neuron allocator, which will be passed to NeuronExecuTorchDelegate::execute

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I apologize if I misunderstood your question. The NEURON_ALLOCATOR is a global instance that allows users to set up shared memory, which can then be utilized by the NeuronBackend.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I get that, I just meant we can customize the runtime allcator and having temp allocator pointing to GET_NEURON_ALLOCATOR, then we don't need to use the NEURON_ALLOCATOR as a global instance. The code will look like this

// In the executor_runner.cpp
auto& neuron_allocator = GET_NEURON_ALLOCATOR;
MemoryManager memory_manager(&method_allocator, &planned_memory,  neuron_allocator);

// In the NeuronBackend.cpp

Error NeuronBackend::execute( context,
                             DelegateHandle* handle,
                             EValue** args){
context.allocate(size=*, alignment=*); // see https://github.com/pytorch/executorch/blob/main/runtime/backend/backend_execution_context.h#L40-L42
}

If the api from the BackendExecutionContext isn't sufficient, let me know and we can expose more from there.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your response and suggestions.

We understand that we can customize the runtime allocator and point the temporary allocator to GET_NEURON_ALLOCATOR, so we don't need to use NEURON_ALLOCATOR as a global instance. This is indeed a good solution.

However, in our implementation, NeuronBackend needs to determine which device's buffer an address belongs to. Therefore, we use an additional API Find from NEURON_ALLOCATOR to query the backend. The current API in BackendExecutionContext does not allow us to achieve this functionality.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it's reverting to the old implementation, is it expected?

LogInfo("NeuronBackend", "version %u, input %u, output %u, length %u, payload size: %zu",
Payload.Header.Version, Payload.Header.InputCount, Payload.Header.OutputCount, Payload.Header.DataLen, processed->size());

auto delegate = std::unique_ptr<NeuronExecuTorchDelegate>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any specific reason we have a NeuronExecuTorchDelegate wrapper but not applying the logic inside NeuronBackend directly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In our design, a NeuronExecuTorchDelegate handles a partition of the model on the NPU. This separation ensures modularity and flexibility, making the codebase cleaner and easier to maintain.

return mExecutor.Compute() == NEURON_NO_ERROR ? Error::Ok : Error::InvalidState;
};

int NeuronExecuTorchDelegate::HintNeuronBackend(EValue** args) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to understand what HintNeuronBackend mean - is there some import forever issue?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'import forever' is an optimization hint designed for specific models including Llama. We prefer not to automatically apply it to every model because it could leads unknow effects.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see - hmm probably add more comments about it (or maybe just copy paste your response)

Comment on lines 22 to 28
size_t prompt_token_batch_size = 1;
size_t cache_size = 1024;
size_t hidden_size = 4096;
size_t num_head = 32;
size_t num_layer = 32;
size_t max_token_length = 2048;
double rot_emb_base = 10000.0f;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this data ideally don't need to be hard-coded but retrieved from the model. If they're hardcoded, something might go wrong if users export a llama-based model but with different model arch

return indexes;
}

LlamaModelChunk::LlamaModelChunk(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does user export 4 .pte file or 1 .pte file?

neuropilot-captain and others added 4 commits May 17, 2024 15:08
- Fix typos in NeuronExecutor.cpp.
- Remove duplicate APUWareUtilsLib.h in CMakeLists.txt.
- Update NEURON_BUFFER_ALLOCATOR to NEURON_BUFFER_ALLOCATOR_LIB.
- Clean up mtk_build.sh and set BUCK_PATH.
- Modify ReadSystemProperty signature to take string reference.
- Add MediaTek libraries introduction
- Improve readability
- Add the list of supported devices
- Improve readability
neuropilot-captain and others added 6 commits May 23, 2024 09:38
- minor wording change
- Fix Buck2 download link
- Add command example
Changed kHighAddrKey and kImportForeverKey from constexpr to extern in NeuronBackend.h.
Defined kHighAddrKey and kImportForeverKey in NeuronBackend.cpp.
Moved backend registration logic to an anonymous namespace in NeuronBackend.cpp.
Copy link
Contributor

@Riandy Riandy Jul 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Add that only linux environment is supported (this is due to the extra libs mtk_neuron and mtk_converter whl files are only supported in linux)
  • Add hyperlink to ET environment setup https://pytorch.org/executorch/stable/getting-started-setup.html
  • Mention that requirements.txt file has two dependencies that cannot be installed via pip.
    • mtk_converter
    • mtk-neuron
    • Another option is to split this up by removing those two from requirements.txt, so devs can just pip3 install -r requirements.txt and then explicitly mention that they need to install the other two libs.
// Ensure that you are inside executorch/examples/mediatek directory
pip3 install -r requirements.txt

// Download the two whl files from (some_location/URL)
pip3 install mtk_neuron-8.2.2-py3-none-linux_x86_64.whl
pip3 install mtk_converter-8.8.0.dev20240723+public.d1467db9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
  • Explain that dev should expect 8 pte files (4 for prompt and 4 for generation) + 1 token embedding file and provide location of the token embedding file location
    models/llm_models/weights/<model_name>/<embedding_file_name>.bin
  • mtk_neuron seems to require a localhost connection. So verify that localhost connection is available (not blocked) before running the AoT script
  • AoT flow will take approximately 2.5hr (with 114 GB ram. Results may vary by devices/ hardware configurations).
  • For llama models, need to use mtk_llama_executor_runner, and ask them to check examples/mediatek/executor_runner/run_llama3_sample.sh for reference

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mtk_neuron seems to require a localhost connection. So verify that localhost connection is available (not blocked) before running the AoT script

Do you mean to guide users check their network condition before running this example?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Give them some warning that a working localhost connection is required in AOT flow. When I tested the flow in Windows + Linux WSL (localhost is blocked) and I wasted 1.5 hours on the initial AoT flow without knowing that localhost is required.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Buck2 Build Tool
    • This seems to be not needed? Looks like ET will auto download the required version it needs? (See: executorch/build/resolve_buck.py)
  • Android NDK
    • Any specific version that is needed? Or tested upon? Maybe we should specify the version tested?
  • “Build Script” section
    • Mentioned the output of the build script. Which is in the “cmake-android-out” folder.
      In this case, it looks like we have 3 runners generated (executorch_runner, mtk_executor_runner and mtk_llama_executor_runner). mtk_llama_executor_runner is located inside /cmake-android-out/examples/mediatek
    • Question: should executorch_runner ever be used? If not, maybe not generate it?
      Mention that they should use mtk_llama_executor_runner for llama models

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we should remove in the future

@cccclai
Copy link
Contributor

cccclai commented Aug 7, 2024

The changes in root cmake and the buffer allocator looks good! I marked them as resolved. The only one left is the destroy function and we're good then.

@facebook-github-bot
Copy link
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@cccclai
Copy link
Contributor

cccclai commented Aug 8, 2024

Many thanks for MTK team for making the contribution. I'll go ahead to merge it. Need to wait for the CI signal a bit.

@cccclai
Copy link
Contributor

cccclai commented Aug 8, 2024

Looks like there are some lint errors. Could you address them? Here is the failing job https://github.com/pytorch/executorch/actions/runs/10296603326/job/28529110807?pr=3571

@facebook-github-bot
Copy link
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


### 3. MediaTek ExercuTorch Libraries

Download the following libraries from MediaTek's NeuroPilot portal (link to be added):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are the links ready?

Comment on lines +36 to +44
2. Setup MTK AoT Environment
```bash
// Ensure that you are inside executorch/examples/mediatek directory
pip3 install -r requirements.txt

// Download the two whl files from NeuroPilot Portal
pip3 install mtk_neuron-8.2.2-py3-none-linux_x86_64.whl
pip3 install mtk_converter-8.8.0.dev20240723+public.d1467db9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
```
Copy link
Contributor

@mergennachin mergennachin Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should these instructions be inside backends/mediatek/scripts/?

we need to make sure these two scripts and readme need to have separate roles:

(1) backends/mediatek contains scripts and readme that is for setting up generic and reusable mediatek setup instructions.
(2) examples/mediatek contains scripts and readme that is for setting up specifically what is needed to run that examples

(2) can be refer to (1) as well as additional stuff specifically for that example.

an application developer will need to call the setup scripts inside backends/mediatek.

@facebook-github-bot
Copy link
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot facebook-github-bot merged commit 1cb97e0 into pytorch:main Aug 14, 2024
35 checks passed
kirklandsign pushed a commit to kirklandsign/executorch that referenced this pull request Aug 15, 2024
Differential Revision: D60970271

Pull Request resolved: pytorch#3571
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants