
Commit d40749c

Update base for Update on "[ET-VK] Clean up shader library and introduce some new conventions"
## Context

This changeset introduces some fairly mechanical improvements to the Vulkan compute graph shader library in order to introduce some new conventions. **Note that backwards compatibility with existing shader authoring methods is preserved.**

### Only list `VALUE` in the `.yaml` files

Previously, to generate variants for a combination of values, the YAML file would contain

```
PACKING:
  - VALUE: CHANNELS_PACKED
    SUFFIX: C_packed
  - VALUE: WIDTH_PACKED
    SUFFIX: W_packed
  - VALUE: HEIGHT_PACKED
    SUFFIX: H_packed
```

However, the shader code generation script will use the `VALUE` as the `SUFFIX` if no `SUFFIX` is provided. Therefore, only the below is needed:

```
PACKING:
  - VALUE: C_packed
  - VALUE: W_packed
  - VALUE: H_packed
```

### Change indexing utility macros to lowercase

Indexing utility macros have been changed to lowercase, and the packing identifiers have been changed due to the change in the YAML files. The change to lowercase makes calls to the macros read more like functions (and indeed they are typically used as functions), which helps make the code more readable.

```
POS_TO_COORD_${PACKING} -> pos_to_coord_${PACKING}
```

### Use a convention of defining macros to reduce Python code block usage

Previously, Python code blocks were used in the GLSL code itself to vary the shader between different settings. However, Python code blocks negatively impact code readability. Therefore, this diff introduces a convention of defining macros near the top of the shader to reduce the usage of Python code blocks, i.e.

```
#define pos_to_coord pos_to_coord_${PACKING}
#define get_packed_dim get_packed_dim_${PACKING}
#define get_packed_stride get_packed_stride_${PACKING}
```

### Improve GLSL type definitions

Previously, the following Python code blocks were used to determine the appropriate vectorized and scalar types:

```
${VEC4_T[DTYPE]} texel = ...
${T[DTYPE]} scalar = ...
```

This changeset replaces that with:

```
#define BUF_T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${texel_type(DTYPE)}
#define SCALAR_T ${texel_component_type(DTYPE)}

layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
  BUF_T data[];
}
buffer_in;

VEC4_T texel = ...
SCALAR_T scalar = ...
```

The main differences are as follows:

* `buffer_scalar_type()` produces the same result as `T[DTYPE]`.
* `texel_type()` is not determined from a mapping with `DTYPE`, but is determined indirectly based on the image format that is associated with the `DTYPE`.
* `texel_component_type()` is based on the result of `texel_type(DTYPE)`.

Essentially, the mapping is more in line with what happens in code. The reason for this change is to enable FP16 support, and it is a bit complicated. Basically, we need a way to distinguish the scalar type used for buffer storage from the scalar type used to store a component of a vec4 type (hence `BUF_T` vs `SCALAR_T`). This is required because, to support half-precision tensors, the buffer representation will use a 16-bit float type, but textures will still extract to `vec4` (i.e. 4x 32-bit floats).

Differential Revision: [D56082461](https://our.internmc.facebook.com/intern/diff/D56082461/)

[ghstack-poisoned]
2 parents 74eb8b3 + 21fdc4e commit d40749c
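
To make the new conventions concrete, here is a small Python sketch of how the codegen-side helpers named in the commit message could behave. The function names come from the message itself; the dtype strings, image formats, and mapping values are illustrative assumptions, not the actual ExecuTorch codegen implementation.

```python
def variant_suffix(entry: dict) -> str:
    # The codegen script is described as falling back to VALUE when no SUFFIX
    # is given, which is why the shortened YAML form is sufficient.
    return entry.get("SUFFIX", entry["VALUE"])


def buffer_scalar_type(dtype: str) -> str:
    # Scalar type used for buffer storage; half-precision tensors use a
    # 16-bit float type in the buffer representation (assumed names).
    return {"float": "float", "half": "float16_t", "int": "int"}[dtype]


def texel_type(dtype: str) -> str:
    # Determined indirectly from the image format associated with the dtype;
    # half-precision textures still extract to a 32-bit vec4.
    image_format = {"float": "rgba32f", "half": "rgba16f", "int": "rgba32i"}[dtype]
    return {"rgba32f": "vec4", "rgba16f": "vec4", "rgba32i": "ivec4"}[image_format]


def texel_component_type(dtype: str) -> str:
    # Derived from the result of texel_type() rather than from dtype directly.
    return {"vec4": "float", "ivec4": "int", "uvec4": "uint"}[texel_type(dtype)]


if __name__ == "__main__":
    print(variant_suffix({"VALUE": "C_packed"}))           # C_packed
    print(buffer_scalar_type("half"), texel_type("half"))  # float16_t vec4
```

Under these assumptions, a template line such as `#define VEC4_T ${texel_type(DTYPE)}` would expand to `vec4` even for a half-precision tensor, while `#define BUF_T ${buffer_scalar_type(DTYPE)}` would expand to a 16-bit float type, which is exactly the buffer-versus-texel distinction the commit message describes.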

File tree

14 files changed: +257, -61 lines

.github/workflows/android.yml
Lines changed: 60 additions & 0 deletions

@@ -33,6 +33,7 @@ jobs:
      submodules: 'true'
      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
      timeout: 90
+      upload-artifact: android-apps
      script: |
        set -eux

@@ -45,3 +46,62 @@ jobs:
        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "${BUILD_TOOL}"
        # Build Android demo app
        bash build/test_android_ci.sh
+
+        mkdir -p artifacts-to-be-uploaded
+        # Copy the app and its test suite to S3
+        cp examples/demo-apps/android/LlamaDemo/app/build/outputs/apk/debug/*.apk artifacts-to-be-uploaded/
+        cp examples/demo-apps/android/LlamaDemo/app/build/outputs/apk/androidTest/debug/*.apk artifacts-to-be-uploaded/
+        # Also copy the shared libraries
+        cp cmake-out-android/lib/*.a artifacts-to-be-uploaded/
+
+  # Upload the app and its test suite to S3 so that they can be downloaded by the test job
+  upload-artifacts:
+    needs: test-demo-android
+    runs-on: linux.2xlarge
+    steps:
+      - name: Download the artifacts
+        uses: actions/download-artifact@v3
+        with:
+          # The name here needs to match the name of the upload-artifact parameter
+          name: android-apps
+          path: ${{ runner.temp }}/artifacts/
+
+      - name: Verify the artifacts
+        shell: bash
+        working-directory: ${{ runner.temp }}/artifacts/
+        run: |
+          ls -lah ./
+
+      - name: Upload the artifacts to S3
+        uses: seemethere/upload-artifact-s3@v5
+        with:
+          s3-bucket: gha-artifacts
+          s3-prefix: |
+            ${{ github.repository }}/${{ github.run_id }}/artifact
+          retention-days: 14
+          if-no-files-found: ignore
+          path: ${{ runner.temp }}/artifacts/
+
+  # Let's see how expensive this job is, we might want to tone it down by running it periodically
+  test-llama-app:
+    needs: upload-artifacts
+    permissions:
+      id-token: write
+      contents: read
+    uses: pytorch/test-infra/.github/workflows/mobile_job.yml@main
+    with:
+      device-type: android
+      runner: ubuntu-latest
+      test-infra-ref: ''
+      # This is the ARN of the ExecuTorch project on AWS
+      project-arn: arn:aws:devicefarm:us-west-2:308535385114:project:02a2cf0f-6d9b-45ee-ba1a-a086587469e6
+      # This is the custom Android device pool that only includes Samsung Galaxy S2x
+      device-pool-arn: arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa
+      # Uploaded to S3 from the previous job, the name of the app comes from the project itself
+      android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/app-debug.apk
+      android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/app-debug-androidTest.apk
+      # The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml
+      test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/414cb54d-4d83-4576-8317-93244e4dc50e
+      # The exported llama2 model and its tokenizer can be downloaded from https://ossci-assets.s3.amazonaws.com/executorch-android-llama2-7b.zip.
+      # Among the inputs, this is the biggest file and uploading it to AWS beforehand makes the test run much faster
+      extra-data: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/bd15825b-ddab-4e47-9fef-a9c8935778dd

build/test_android_ci.sh
Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ build_android_llama_demo_app() {
  pushd examples/demo-apps/android/LlamaDemo
  ANDROID_NDK=/opt/ndk ANDROID_ABI=arm64-v8a ./gradlew setup
  ANDROID_HOME=/opt/android/sdk ./gradlew build
+  ANDROID_HOME=/opt/android/sdk ./gradlew assembleAndroidTest
  popd
}

docs/source/llm/getting-started.md
Lines changed: 119 additions & 23 deletions

@@ -374,46 +374,102 @@ specific hardware (delegation), and because it is doing all of the calculations

## Delegation

-While ExecuTorch provides a portable, cross-platform implementation for all operators, it also provides specialized
-backends for a number of different targets. These include, but are not limited to, x86 and ARM CPU acceleration via
-the XNNPACK backend, Apple acceleration via the CoreML backend and Metal Performance Shader (MPS) backend, and GPU
-acceleration via the Vulkan backend.
-
-Because optimizations are specific to a given backend, each pte file is specific to the backend(s) targeted at
-export. To support multiple devices, such as XNNPACK acceleration for Android and CoreML for iOS, export a separate
-PTE file for each backend.
-
-To delegate to a backend at export time, ExecuTorch provides the `to_backend()` function, which takes a backend-
-specific partitioner object. The partitioner is responsible for finding parts of the computation graph that can
-be accelerated by the target backend. Any portions of the computation graph not delegated will be executed by the
-portable or optimized ExecuTorch implementations.
-
-To delegate to the XNNPACK backend, call `to_backend` with an instance of `XnnpackPartitioner()`.
+While ExecuTorch provides a portable, cross-platform implementation for all
+operators, it also provides specialized backends for a number of different
+targets. These include, but are not limited to, x86 and ARM CPU acceleration via
+the XNNPACK backend, Apple acceleration via the CoreML backend and Metal
+Performance Shader (MPS) backend, and GPU acceleration via the Vulkan backend.
+
+Because optimizations are specific to a given backend, each pte file is specific
+to the backend(s) targeted at export. To support multiple devices, such as
+XNNPACK acceleration for Android and CoreML for iOS, export a separate PTE file
+for each backend.
+
+To delegate to a backend at export time, ExecuTorch provides the `to_backend()`
+function on the `EdgeProgramManager` object, which takes a backend-specific
+partitioner object. The partitioner is responsible for finding parts of the
+computation graph that can be accelerated by the target backend, and the
+`to_backend()` function will delegate the matched parts to the given backend for
+acceleration and optimization. Any portions of the computation graph not
+delegated will be executed by the ExecuTorch operator implementations.
+
+To delegate the exported model to a specific backend, we first need to import
+its partitioner as well as the edge compile config from the ExecuTorch codebase,
+then call `to_backend` with an instance of the partitioner on the
+`EdgeProgramManager` object that the `to_edge` function created.
+
+Here's an example of how to delegate NanoGPT to XNNPACK (if you're deploying to an Android phone, for instance):

```python
# export_nanogpt.py

+# Load the partitioner for the Xnnpack backend
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+
+# A model to be delegated to a specific backend should use that backend's edge compile config
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
+from executorch.exir import EdgeCompileConfig, to_edge
+
+import torch
+from torch.export import export
+from torch.nn.attention import sdpa_kernel, SDPBackend
+from torch._export import capture_pre_autograd_graph
+
+from model import GPT
+
+# Load the NanoGPT model.
+model = GPT.from_pretrained('gpt2')

-#...
+# Create example inputs. This is used in the export process to provide
+# hints on the expected shape of the model input.
+example_inputs = (
+    torch.randint(0, 100, (1, 8), dtype=torch.long),
+)
+
+# Trace the model, converting it to a portable intermediate representation.
+# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
+with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+    m = capture_pre_autograd_graph(model, example_inputs)
+    traced_model = export(m, example_inputs)

+# Convert the model into a runnable ExecuTorch program.
+# To be further lowered to the Xnnpack backend, `traced_model` needs the xnnpack-specific edge compile config.
edge_config = get_xnnpack_edge_compile_config()
edge_manager = to_edge(traced_model, compile_config=edge_config)

-# Delegate to the XNNPACK backend.
+# Delegate the exported model to the Xnnpack backend by invoking `to_backend` with the Xnnpack partitioner.
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
-
et_program = edge_manager.to_executorch()

+# Save the Xnnpack-delegated ExecuTorch program to a file.
+with open("nanogpt.pte", "wb") as file:
+    file.write(et_program.buffer)
+
+
```

-Additionally, update CMakeLists.txt to build and link the XNNPACK backend.
+Additionally, update CMakeLists.txt to build and link the XNNPACK backend to
+the ExecuTorch runner.

```
-option(EXECUTORCH_BUILD_XNNPACK "" ON)
+cmake_minimum_required(VERSION 3.19)
+project(nanogpt_runner)

-# ...
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED True)
+
+# Set options for the executorch build.
+option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
+option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
+option(EXECUTORCH_BUILD_OPTIMIZED "" ON)
+option(EXECUTORCH_BUILD_XNNPACK "" ON) # Build with the Xnnpack backend
+
+# Include the executorch subdirectory.
+add_subdirectory(
+    ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
+    ${CMAKE_BINARY_DIR}/executorch)
+
+# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
@@ -423,11 +479,51 @@ target_link_libraries(
    extension_module_static # Provides the Module class
    optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
    xnnpack_backend) # Provides the XNNPACK CPU acceleration backend
+```
+
+Keep the rest of the code the same. Refer to
+[Exporting to ExecuTorch](https://pytorch.org/executorch/main/llm/getting-started.html#step-1-exporting-to-executorch)
+and
+[Invoking the Runtime](https://pytorch.org/executorch/main/llm/getting-started.html#step-2-invoking-the-runtime)
+for more details.

+At this point, the working directory should contain the following files:
+
+- CMakeLists.txt
+- main.cpp
+- basic_tokenizer.h
+- basic_sampler.h
+- managed_tensor.h
+- export_nanogpt.py
+- model.py
+- vocab.json
+
+If all of these are present, you can now export the Xnnpack-delegated pte model:
+```bash
+python export_nanogpt.py
```

-For more information, see the ExecuTorch guides for the [XNNPACK Backend](https://pytorch.org/executorch/stable/tutorial-xnnpack-delegate-lowering.html)
-and [CoreML Backend](https://pytorch.org/executorch/stable/build-run-coreml.html).
+This will generate `nanogpt.pte` in the same working directory.
+
+Then we can build and run the model with:
+```bash
+(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
+cmake --build cmake-out -j10
+./cmake-out/nanogpt_runner
+```
+
+You should see something like the following:
+
+```
+Once upon a time, there was a man who was a member of the military...
+```
+
+
+For more information regarding backend delegation, see the ExecuTorch guides
+for the
+[XNNPACK Backend](https://pytorch.org/executorch/stable/tutorial-xnnpack-delegate-lowering.html)
+and
+[CoreML Backend](https://pytorch.org/executorch/stable/build-run-coreml.html).

## Quantization

examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ public class PerfTest implements LlamaCallback {
  private static final String TOKENIZER_BIN = "tokenizer.bin";

  // From https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md
-  private static final Float EXPECTED_TPS = 7.0F;
+  private static final Float EXPECTED_TPS = 10.0F;

  private final List<String> results = new ArrayList<>();
  private final List<Float> tokensPerSecond = new ArrayList<>();

examples/models/llama2/README.md
Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ Please note that the models are subject to the [acceptable use policy](https://g
Since 7B Llama2 model needs at least 4-bit quantization to fit even within some of the highend phones, results presented here correspond to 4-bit groupwise post-training quantized model.

## Quantization:
-We employed 4-bit groupwise per token dynamic quantization of all the linear layers of the model. Dynamic quantization refers to quantizating activations dynamically, such that quantization parameters for activations are calculated, from min/max range, at runtime. Here we quantized activations with 8bits (signed integer). Furthermore, weights are statically quantized. In our case weights were per-channel groupwise quantized with 4bit signed integer. For more information refer to this [page](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html).
+We employed 4-bit groupwise per token dynamic quantization of all the linear layers of the model. Dynamic quantization refers to quantizating activations dynamically, such that quantization parameters for activations are calculated, from min/max range, at runtime. Here we quantized activations with 8bits (signed integer). Furthermore, weights are statically quantized. In our case weights were per-channel groupwise quantized with 4bit signed integer. For more information refer to this [page](https://github.com/pytorch-labs/ao/).

We evaluated WikiText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Below are the results for two different groupsizes.
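
As background for the scheme the README describes (per-token dynamic quantization of activations, with scale and zero-point computed from the runtime min/max and values mapped to signed 8-bit integers), here is a minimal PyTorch sketch of that arithmetic. It is illustrative only; it is not the quantization kernel used by ExecuTorch or torchao, and the helper name is hypothetical.

```python
import torch

def quantize_activations_per_token(x: torch.Tensor):
    # Dynamic quantization: scale and zero-point are derived from each token's
    # min/max at runtime, then values are mapped to signed int8.
    qmin, qmax = -128, 127
    x_min = x.min(dim=-1, keepdim=True).values.clamp(max=0.0)
    x_max = x.max(dim=-1, keepdim=True).values.clamp(min=0.0)
    scale = ((x_max - x_min) / (qmax - qmin)).clamp(min=1e-8)
    zero_point = (qmin - x_min / scale).round()
    q = (x / scale + zero_point).round().clamp(qmin, qmax).to(torch.int8)
    return q, scale, zero_point

# Example: a batch of two tokens with eight activation values each.
activations = torch.randn(2, 8)
q, scale, zp = quantize_activations_per_token(activations)
dequant = (q.float() - zp) * scale  # approximate reconstruction of the input
```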

examples/models/llama2/runner/runner.cpp
Lines changed: 5 additions & 4 deletions

@@ -10,6 +10,7 @@
// The module takes in a string as input and emits a string as output.

#include <executorch/examples/models/llama2/runner/runner.h>
+#include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/runner_util/managed_tensor.h>

@@ -76,7 +77,7 @@ Error Runner::load() {
  append_eos_ = getMetadataHelper("append_eos_to_prompt", false);

  // Load tokenizer
-  tokenizer_ = std::make_unique<Tokenizer>(vocab_size_, bos_id_, eos_id_);
+  tokenizer_ = std::make_unique<BPETokenizer>(vocab_size_, bos_id_, eos_id_);
  tokenizer_->load(tokenizer_path_);
  if (tokenizer_->bos_tok() != bos_id_) {
    ET_LOG(

@@ -105,7 +106,7 @@ Error Runner::load() {
}

template <typename T>
-T Runner::getMetadataHelper(std::string method_name, T default_val) {
+T Runner::getMetadataHelper(const std::string& method_name, T default_val) {
  T res = default_val;
  if (model_methods_.count(method_name)) {
    Result<std::vector<EValue>> outputs = module_->execute(method_name);

@@ -484,9 +485,9 @@ void Runner::stop() {

// explicit instantiation of template methods
template int64_t Runner::getMetadataHelper<int64_t>(
-    std::string method_name,
+    const std::string& method_name,
    int64_t default_val);
template bool Runner::getMetadataHelper<bool>(
-    std::string method_name,
+    const std::string& method_name,
    bool default_val);
} // namespace torch::executor

examples/models/llama2/runner/runner.h
Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ class Runner {
 private:
  // metadata
  template <typename T>
-  T getMetadataHelper(std::string method_name, T default_val);
+  T getMetadataHelper(const std::string& method_name, T default_val);
  template <typename T>
  int32_t
  logitsToToken(const exec_aten::Tensor& logits_tensor, int64_t pos, T _);

examples/models/llama2/tokenizer/tokenizer.cpp renamed to examples/models/llama2/tokenizer/bpe_tokenizer.cpp
Lines changed: 10 additions & 10 deletions

@@ -6,7 +6,7 @@
 * LICENSE file in the root directory of this source tree.
 */

-#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
+#include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>

#include <string>

@@ -23,11 +23,11 @@ static int compare_tokens(const void* a, const void* b) {
  return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

-Tokenizer::Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
-    : initialized_(false),
-      vocab_size_(vocab_size),
-      bos_tok_(bos_tok),
-      eos_tok_(eos_tok),
+BPETokenizer::BPETokenizer(
+    int32_t vocab_size,
+    uint64_t bos_tok,
+    uint64_t eos_tok)
+    : Tokenizer(vocab_size, bos_tok, eos_tok),
      vocab_(std::make_unique<char*[]>(vocab_size)),
      vocab_scores_(std::make_unique<float[]>(vocab_size)),
      sorted_vocab_(std::make_unique<TokenIndex[]>(vocab_size)) {

@@ -47,7 +47,7 @@ Tokenizer::Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
 * @param tokenizer_path The path to the tokenizer file.
 * @return Error
 */
-Error Tokenizer::load(const std::string& tokenizer_path) {
+Error BPETokenizer::load(const std::string& tokenizer_path) {
  if (initialized_) {
    ET_LOG(Info, "Tokenizer already initialized");
    return Error::Ok;

@@ -131,7 +131,7 @@ Error Tokenizer::load(const std::string& tokenizer_path) {
  return Error::Ok;
}

-Tokenizer::~Tokenizer() {
+BPETokenizer::~BPETokenizer() {
  for (int i = 0; i < vocab_size_; i++) {
    delete[] vocab_[i];
  }

@@ -145,7 +145,7 @@ Tokenizer::~Tokenizer() {
 * @return Result<std::string> A pointer to the string representation of the
 * token.
 */
-Result<std::string> Tokenizer::decode(uint64_t prev_token, uint64_t token) {
+Result<std::string> BPETokenizer::decode(uint64_t prev_token, uint64_t token) {
  if (!initialized_) {
    ET_LOG(Error, "Tokenizer not initialized");
    return Error::NotSupported;

@@ -187,7 +187,7 @@ str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
 * @return Result<std::vector<uint64_t>>
 */
Result<std::vector<uint64_t>>
-Tokenizer::encode(const std::string& text, int8_t bos, int8_t eos) {
+BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) {
  if (!initialized_) {
    ET_LOG(Error, "Tokenizer not initialized");
    return Error::NotSupported;
