Commit ae9227e

Update base for Update on "qnn end to end flow"
Patch a few changes, including:

- support bool tensor type
- support fp16 and fix the 8w8a quantization
- add two non-supported ops (slice_scatter and index_put) in common_defs.py

Stories model working end to end:

AOT, fp16:

```
python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json
```

AOT, quantized:

```
python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize -c stories110M.pt -p params.json
```

Runtime:

```
/llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once"
```

Output:

```
Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet.
```

The stories model is too small and sensitive to quantization.

Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/)

[ghstack-poisoned]
2 parents: f177c79 + c61ef44

13 files changed: +116 −25 lines


.ci/docker/ci_commit_pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -1 +1 @@
-0a038cf0cff2d071b7359ac0491fd2ba7798a438
+868e5ced5df34f1aef3703654f76e03f5126b534
```

backends/vulkan/runtime/api/Adapter.cpp

Lines changed: 1 addition & 2 deletions
```diff
@@ -401,8 +401,7 @@ std::string Adapter::stringize() const {
   ss << "  Memory Info {" << std::endl;
   ss << "    Memory Types [" << std::endl;
   for (size_t i = 0; i < mem_props.memoryTypeCount; ++i) {
-    ss << "      "
-       << " [Heap " << mem_props.memoryTypes[i].heapIndex << "] "
+    ss << "      " << " [Heap " << mem_props.memoryTypes[i].heapIndex << "] "
        << get_memory_properties_str(mem_props.memoryTypes[i].propertyFlags)
        << std::endl;
   }
```

backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp

Lines changed: 3 additions & 1 deletion
```diff
@@ -16,7 +16,9 @@ bool OperatorRegistry::has_op(const std::string& name) {
 
 OperatorRegistry::OpFunction& OperatorRegistry::get_op_fn(
     const std::string& name) {
-  return table_.find(name)->second;
+  const auto it = table_.find(name);
+  VK_CHECK_COND(it != table_.end(), "Could not find operator with name ", name);
+  return it->second;
 }
 
 void OperatorRegistry::register_op(const std::string& name, OpFunction& fn) {
```
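The change above replaces an unchecked `table_.find(name)->second`, which dereferences an invalid iterator when the operator is unregistered, with a lookup guarded by `VK_CHECK_COND`. A minimal Python sketch of the same checked-lookup pattern (the `OpRegistry` class below is hypothetical, for illustration only):

```python
from typing import Callable, Dict


class OpRegistry:
    """Minimal sketch of a name -> function registry with a checked lookup."""

    def __init__(self) -> None:
        self._table: Dict[str, Callable] = {}

    def register_op(self, name: str, fn: Callable) -> None:
        self._table[name] = fn

    def get_op_fn(self, name: str) -> Callable:
        # Fail loudly instead of dereferencing a missing entry,
        # mirroring the VK_CHECK_COND added in the C++ diff.
        if name not in self._table:
            raise KeyError(f"Could not find operator with name {name}")
        return self._table[name]


registry = OpRegistry()
registry.register_op("aten.add", lambda a, b: a + b)
assert registry.get_op_fn("aten.add")(1, 2) == 3
```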

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 5 additions & 5 deletions
```diff
@@ -50,8 +50,8 @@
 // describe sizes. As an example, let's say we want to swap dimensions 0,1 for a
 // tensor of shape {4,3,2,24} to obtain {3,4,2,24}. Then, x=4, y=3 and
 // plane=2*24=48.
-#define SWAP_ADJ_DIMS(cur, x, y, plane) \
-  cur + \
-      plane*( \
-          (1 - y) * ((cur % (x * y * plane)) / (y * plane)) + \
-          (x - 1) * ((cur % (y * plane)) / plane))
+#define SWAP_ADJ_DIMS(cur, x, y, plane) \
+  cur + \
+      plane * \
+          ((1 - y) * ((cur % (x * y * plane)) / (y * plane)) + \
+           (x - 1) * ((cur % (y * plane)) / plane))
```
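The macro maps a flat index in the original layout to the corresponding index after the two adjacent dimensions of sizes `x` and `y` are swapped. As a quick sanity check of the arithmetic (an illustration, not part of this commit), a NumPy sketch can compare the formula against an actual transpose using the {4,3,2,24} example from the comment:

```python
import numpy as np


def swap_adj_dims(cur: int, x: int, y: int, plane: int) -> int:
    # Python port of the SWAP_ADJ_DIMS macro; // matches GLSL integer division.
    i = (cur % (x * y * plane)) // (y * plane)  # coordinate along the dim of size x
    j = (cur % (y * plane)) // plane            # coordinate along the dim of size y
    return cur + plane * ((1 - y) * i + (x - 1) * j)


# The comment's example: swapping dims 0,1 of a {4,3,2,24} tensor,
# so x=4, y=3, plane=2*24=48.
x, y, plane = 4, 3, 48
flat = np.arange(x * y * plane)
swapped = flat.reshape(x, y, plane).transpose(1, 0, 2).reshape(-1)
assert all(swapped[swap_adj_dims(c, x, y, plane)] == flat[c] for c in range(flat.size))
```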

docs/README.md

Lines changed: 4 additions & 0 deletions
````diff
@@ -57,7 +57,11 @@ To build the documentation locally:
    ```bash
    pip3 install -r ./.ci/docker/requirements-ci.txt
    ```
+1. Update submodules
 
+   ```bash
+   git submodule sync && git submodule update --init
+   ```
 1. Run:
 
    ```bash
````

examples/models/llama2/export_llama_lib.py

Lines changed: 25 additions & 8 deletions
```diff
@@ -34,7 +34,6 @@
 from executorch.sdk.etrecord import generate_etrecord
 from executorch.util.activation_memory_profiler import generate_memory_trace
 from sentencepiece import SentencePieceProcessor
-from torch.nn import functional as F
 
 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
 from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers
@@ -174,17 +173,17 @@ def forward(
         v = v.transpose(1, 2)
 
         k, v = self.kv_cache.update(input_pos, k, v)
-        mask = mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
-        scores = F.softmax(scores.float(), dim=-1).type_as(q)
-        scores = scores + mask
-        output = torch.matmul(scores, v)  # (bs, n_local_heads, seqlen, head_dim)
+        scale_factor = 1 / math.sqrt(q.size(-1))
+        attn_weight = q @ k.transpose(-2, -1) * scale_factor
+        attn_weight += attn_mask
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        y = attn_weight @ v
 
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-        return output
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
 
 def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
@@ -200,6 +199,24 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
     return module
 
 
+def replace_causal_mask(module: torch.nn.Module):
+    for buffer_fqn_name, buffer in module.named_buffers():
+        buffer_name = buffer_fqn_name.split(".")[-1]
+        if buffer_name == "mask":
+            max_seq_len = buffer.shape[-1]
+            mask = torch.full(
+                (max_seq_len, max_seq_len),
+                float("-inf"),
+                device="cpu",
+            )
+
+            mask = torch.triu(mask, diagonal=1)
+            module.register_buffer(buffer_name, mask)
+    for _, child in module.named_children():
+        replace_causal_mask(child)
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
```
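For reference, the rewritten attention above is the standard scaled-dot-product formulation with an additive mask, and `replace_causal_mask` registers a float mask that is 0 on and below the diagonal and -inf above it. The following sketch (illustrative only, not from the commit) shows the same arithmetic as a standalone function and checks it against `torch.nn.functional.scaled_dot_product_attention`:

```python
import math

import torch
import torch.nn.functional as F


def simple_sdpa(q, k, v, attn_mask):
    # Same arithmetic as the rewritten SDPA above: explicit matmul,
    # additive mask, and softmax, with no float() upcast.
    scale_factor = 1 / math.sqrt(q.size(-1))
    attn_weight = q @ k.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ v


q = torch.randn(1, 8, 4, 16)  # (bs, n_heads, seqlen, head_dim)
k = torch.randn(1, 8, 4, 16)
v = torch.randn(1, 8, 4, 16)
# Additive causal mask, matching what replace_causal_mask registers.
mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)

ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(simple_sdpa(q, k, v, mask), ref, atol=1e-6)
```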

examples/models/llama2/runner/runner.cpp

Lines changed: 1 addition & 2 deletions
```diff
@@ -472,8 +472,7 @@ std::string statsToJsonString(const Runner::Stats& stats) {
      << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
      << "\"first_token_ms\":" << stats.first_token_ms << ","
      << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
-     << ","
-     << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
+     << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
      << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
   return ss.str();
 }
```

examples/models/llama2/tests/TARGETS

Lines changed: 15 additions & 0 deletions
```diff
@@ -0,0 +1,15 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_unittest(
+    name = "test_simple_sdpa",
+    srcs = [
+        "test_simple_sdpa.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:export_library",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
```
examples/models/llama2/tests/test_simple_sdpa.py

Lines changed: 54 additions & 0 deletions
```diff
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+from executorch.examples.models.llama2.export_llama_lib import SDPASimple
+from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+
+
+class SDPATest(unittest.TestCase):
+    def test_simple_sdpa(self):
+        # Verify the correctness between the simple SDPA and the original SDPA module defined in llama_transformer.py
+        max_batch_size = 1
+        max_seq_length = 128
+        n_heads = 8
+        head_dim = 8
+        dim = 64
+        n_rep = 1
+        bsz = 1
+        seqlen = 1
+        n_local_heads = n_heads
+        kv_cache = KVCache(
+            max_batch_size=max_batch_size,
+            max_seq_length=max_seq_length,
+            n_heads=n_heads,
+            head_dim=head_dim,
+            transpose_cache=True,
+        )
+        sdpa = SDPA(
+            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
+        )
+        input_pos = torch.tensor([0])
+        query = torch.randn(1, 1, n_local_heads, head_dim)
+        key = torch.randn(1, 1, n_local_heads, head_dim)
+        value = torch.randn(1, 1, n_local_heads, head_dim)
+        mask = torch.randn(max_seq_length, max_seq_length)
+        sdpa_output = sdpa(
+            input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
+        )
+
+        simple_sdpa = SDPASimple(
+            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
+        )
+        simple_sdpa_output = simple_sdpa(
+            input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
+        )
+
+        # Compare the outputs from the two SDPA implementations
+        self.assertTrue(torch.allclose(sdpa_output, simple_sdpa_output))
```

kernels/portable/cpu/op_cumsum.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -11,8 +11,8 @@
 #include <executorch/runtime/platform/assert.h>
 #include <cmath>
 #include <cstddef>
-//#include <cstdint>
-//#include <type_traits>
+// #include <cstdint>
+// #include <type_traits>
 
 namespace torch {
 namespace executor {
```

runtime/core/portable_type/optional.h

Lines changed: 2 additions & 2 deletions
```diff
@@ -74,8 +74,8 @@ class optional final {
   }
 
   optional& operator=(optional&& rhs) noexcept(
-      std::is_nothrow_move_assignable<T>::value&&
-          std::is_nothrow_move_constructible<T>::value) {
+      std::is_nothrow_move_assignable<T>::value &&
+      std::is_nothrow_move_constructible<T>::value) {
     if (init_ && !rhs.init_) {
       clear();
     } else if (!init_ && rhs.init_) {
```

sdk/etdump/etdump_flatcc.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -103,7 +103,8 @@ ETDumpGen::ETDumpGen(Span<uint8_t> buffer) {
     alloc.set_buffer(
         (uint8_t*)buffer_with_builder,
         buffer_size,
-        (size_t)((buffer_size / 4 > max_alloc_buf_size) ? max_alloc_buf_size : buffer_size / 4));
+        (size_t)((buffer_size / 4 > max_alloc_buf_size) ? max_alloc_buf_size
+                                                        : buffer_size / 4));
     et_flatcc_custom_init(builder, &alloc);
   } else {
     builder = (struct flatcc_builder*)malloc(sizeof(struct flatcc_builder));
```

third-party/pytorch

Submodule pytorch updated 589 files
