
Commit dbc324d

Update base for Update on "qnn end to end flow"
Patch a few changes, including:

- support bool tensor type
- support fp16 and fix the 8w8a quantization
- add two non-supported ops (slice_scatter and index_put) in common_defs.py

The stories model works end to end:

AOT:

fp16:
```
python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json
```

quantized:
```
python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize -c stories110M.pt -p params.json
```

Runtime:
```
/llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once"
```

Output:
```
Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet.
```

The stories model is too small and sensitive to quantization.

Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/)

[ghstack-poisoned]
2 parents ac1d66f + 1eed125 commit dbc324d

File tree

7 files changed (+210, -14 lines)


backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 33 additions & 3 deletions
```diff
@@ -8,9 +8,21 @@
 
 #define divup4(x) ((x + 3) / 4)
 
-#define to_buffer_i(idx, sizes) \
-  idx.x + idx.y* sizes.x + idx.z* sizes.y* sizes.x + \
-  idx.w* sizes.z* sizes.y* sizes.x;
+// Input: idx is an ivec4 user-level coordinate, sizes is the tensor shape
+// Output: buffer_idx in the contiguous nchw-buffer.
+#define to_buffer_i(idx, sizes) \
+  (idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \
+   idx.w * sizes.z * sizes.y * sizes.x)
+
+// Inverse of to_buffer_i
+// Input: buffer_idx in the contiguous nchw-buffer, sizes is the tensor shape
+// Output: ivec4 user-level coordinate
+#define from_buffer_i(buf_i, sizes) \
+  ivec4( \
+      buf_i % sizes.x, \
+      (buf_i / (sizes.x)) % sizes.y, \
+      (buf_i / (sizes.x * sizes.y)) % sizes.z, \
+      (buf_i / (sizes.x * sizes.y * sizes.z)))
 
 #define get_packed_dim_C_packed(vec) vec.z
 #define get_packed_dim_W_packed(vec) vec.x
@@ -20,6 +32,8 @@
 #define get_packed_stride_W_packed(vec) (1)
 #define get_packed_stride_H_packed(vec) (vec.x)
 
+// Input: pos is a texture position, sizes is a pack-aligned size.
+// Output: a user-level (w, h, c, n) coordinate
 #define to_tensor_idx_C_packed(pos, sizes) \
   ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z)
 
@@ -29,6 +43,9 @@
 #define to_tensor_idx_H_packed(pos, sizes) \
   ivec4(pos.x, (pos.y * 4), pos.z % sizes.z, pos.z / sizes.z)
 
+// Input: idx is a user-level (w, h, c, n) coordinate. sizes is a pack-aligned
+// size.
+// Output: texture position
 #define to_texture_pos_C_packed(idx, sizes) \
   ivec3(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4)
 
@@ -38,6 +55,19 @@
 #define to_texture_pos_H_packed(idx, sizes) \
   ivec3(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z))
 
+// Input: idx is a user-level (w, h, c, n) coordinate. sizes is a pack-aligned
+// size.
+// Output: ivec4, where xyz is the texture position and w is the element index
+// within the texel.
+#define to_texture_pos_elem_C_packed(idx, sizes) \
+  ivec4(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4, idx.z % 4)
+
+#define to_texture_pos_elem_W_packed(idx, sizes) \
+  ivec4(idx.x / 4, idx.y, (idx.z + idx.w * sizes.z), idx.x % 4)
+
+#define to_texture_pos_elem_H_packed(idx, sizes) \
+  ivec4(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z), idx.y % 4)
+
 // Given a buffer(1-D) index cur, compute a new index where the corresponding
 // tensor(N-D)'s adjacent dimensions are swapped. The parameters x,y and plane
 // describe sizes. As an example, let's say we want to swap dimensions 0,1 for a
```
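
To make the new index arithmetic concrete, here is a minimal Python sketch (illustrative only, not part of the commit) of what `to_buffer_i` and `from_buffer_i` compute. `sizes` is ordered (W, H, C, N) to match the GLSL `ivec4` components:

```python
def to_buffer_i(idx, sizes):
    # Flatten a (w, h, c, n) coordinate into a contiguous nchw buffer index,
    # with W the fastest-varying dimension, mirroring the GLSL macro.
    w, h, c, n = idx
    W, H, C, _ = sizes
    return w + h * W + c * H * W + n * C * H * W


def from_buffer_i(buf_i, sizes):
    # Inverse mapping: recover the (w, h, c, n) coordinate from a buffer index.
    W, H, C, _ = sizes
    return (
        buf_i % W,
        (buf_i // W) % H,
        (buf_i // (W * H)) % C,
        buf_i // (W * H * C),
    )


# Round-trip check over a small shape.
sizes = (5, 4, 3, 2)  # (W, H, C, N)
for i in range(5 * 4 * 3 * 2):
    assert to_buffer_i(from_buffer_i(i, sizes), sizes) == i
```

Note the added parentheses in the new `to_buffer_i`: the old macro expanded to a bare sum with a trailing semicolon, which would break when the macro is used inside a larger expression.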
backends/vulkan/runtime/graph/ops/glsl/view.glsl (new file)

Lines changed: 76 additions & 0 deletions
```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;

#define VEC4_T ${texel_type(DTYPE)}

#define to_tensor_idx to_tensor_idx_${PACKING}
#define to_texture_pos_elem to_texture_pos_elem_${PACKING}
#define get_packed_stride get_packed_stride_${PACKING}

layout(set = 0, binding = 2) uniform PRECISION restrict OutGpuSizes {
  uvec4 out_gpu_sizes;
};

layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes {
  uvec4 out_cpu_sizes;
};

layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes {
  uvec4 in_gpu_sizes;
};

layout(set = 0, binding = 5) uniform PRECISION restrict InCpuSizes {
  uvec4 in_cpu_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;


void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
  const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes);

  if (all(greaterThanEqual(out_tensor_idx, out_gpu_sizes))) {
    return;
  }

  // Assume there is a virtual contiguous buffer in nchw format. From the
  // output pos, we first calculate the index in the virtual buffer, and then
  // calculate the input position from that index.
  const uint base_index = to_buffer_i(out_tensor_idx, out_cpu_sizes);
  const uvec4 buf_indices =
      base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes);

  VEC4_T value;
  // Need to look up the 4 values in the output texel separately.
  for (int i = 0; i < 4; i++) {
    ivec4 user_coor = from_buffer_i(buf_indices[i], in_cpu_sizes);

    ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes);

    VEC4_T intex = VEC4_T(texelFetch(image_in, in_pos_elem.xyz, 0));

    value[i] = intex[in_pos_elem.w];
  }

  imageStore(image_out, out_pos, value);
}
```
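
The shader is a pure gather: each output element's position is flattened into an index into a virtual contiguous nchw buffer, and that index is then decoded against the input shape. A hedged CPU emulation of the same idea (not part of the commit; `np.ravel_multi_index` and `np.unravel_index` stand in for `to_buffer_i` and `from_buffer_i`):

```python
import numpy as np


def view_by_reindexing(x, out_shape):
    # For every output element: position -> flat nchw buffer index -> input
    # position. This is the per-element body of the shader's main() loop.
    out = np.empty(out_shape, dtype=x.dtype)
    for out_idx in np.ndindex(*out_shape):
        buf_i = np.ravel_multi_index(out_idx, out_shape)  # to_buffer_i
        in_idx = np.unravel_index(buf_i, x.shape)         # from_buffer_i
        out[out_idx] = x[in_idx]
    return out


x = np.arange(8 * 7 * 2 * 3).reshape(8, 7, 2, 3)
assert np.array_equal(view_by_reindexing(x, (4, 3, 7, 4)), x.reshape(4, 3, 7, 4))
```

The texture-specific work in the shader (packed texels, `to_texture_pos_elem`, picking one of four values per texel) is bookkeeping layered on top of this mapping.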
backends/vulkan/runtime/graph/ops/glsl/view.yaml (new file)

Lines changed: 14 additions & 0 deletions
```yaml
view:
  parameter_names_with_default_values:
    DTYPE: float
    NDIM: 3
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
    PACKING:
      - VALUE: C_packed
      - VALUE: W_packed
      - VALUE: H_packed
  shader_variants:
    - NAME: view
```
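
If `generate_variant_forall` expands the Cartesian product of `DTYPE` and `PACKING`, this yaml yields six shader variants. A sketch of the expected kernel names, assuming the codegen appends the dtype suffix before the layout suffix (which is what the `add_dtype_suffix` / `add_memory_layout_suffix` calls in the C++ below suggest):

```python
from itertools import product

# Hypothetical enumeration; assumes names are built as view_{DTYPE}_{PACKING}.
for dtype, packing in product(["half", "float"], ["C_packed", "W_packed", "H_packed"]):
    print(f"view_{dtype}_{packing}")
# view_half_C_packed, view_half_W_packed, ..., view_float_H_packed
```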
backends/vulkan/runtime/graph/ops/impl/View.cpp (new file)

Lines changed: 51 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  std::string kernel_name = "view";
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);
  add_memory_layout_suffix(kernel_name, *t_out);

  api::utils::uvec3 global_size = t_out->extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
      {t_out->gpu_sizes_ubo(),
       t_out->cpu_sizes_ubo(),
       t_in->gpu_sizes_ubo(),
       t_in->cpu_sizes_ubo()}));
}

void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  // Note: the second argument size_ref is not used here, since the output
  // tensor's sizes were already determined during compilation.
  return add_view_node(graph, args[0], args[2]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.view_copy.default, view);
}

} // namespace vkcompute
```
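
For reference, the ATen op this file registers can be exercised directly from Python; exporting a reshape lowers to `aten.view_copy.default` calls of this form (a quick check, not part of the commit):

```python
import torch

x = torch.randn(8, 7, 2, 3)
# Same (input sizes, target sizes) pair as one of the test cases below.
y = torch.ops.aten.view_copy.default(x, [7, -1, 2, 1])
assert y.shape == (7, 24, 2, 1)
assert torch.equal(y, x.reshape(7, -1, 2, 1))
```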

backends/vulkan/test/op_tests/cases.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -194,6 +194,33 @@ def get_permute_inputs():
     return test_suite
 
 
+def get_view_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((3, 4, 5), [1, 1, -1]),
+            ((3, 4, 5), [1, -1, 1]),
+            ((3, 4, 5), [-1, 1, 1]),
+            ((8, 7, 2, 3), [4, 3, 7, 4]),
+            ((8, 7, 2, 3), [7, -1, 2, 1]),
+            ((8, 7, 2, 3), [1, 1, 1, -1]),
+            ((8, 7, 2, 3), [-1]),
+            ((2, 3, 3, 7), [2, -1, 1, 1]),
+            ((3, 5, 2, 7), [7, -1, 2, 1]),
+            ((2, 2, 8, 6), [2, 6, -1, 1]),
+            ((2, 2, 8, 6), [6, -1, 1]),
+            ((S1, S2, S1, S2), [S2, -1, 1, S1]),
+            ((S1, S2, S1, S2), [S1, 1, -1, S2]),
+            ((S1, S2, S1, S2), [-1, 1, S1, S2]),
+        ]
+    )
+    test_suite.layouts = [
+        "api::kWidthPacked",
+        "api::kHeightPacked",
+        "api::kChannelsPacked",
+    ]
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -208,4 +235,5 @@ def get_permute_inputs():
     "aten.select_copy.int": get_select_int_inputs(),
     "aten.permute.default": get_permute_inputs(),
     "aten.permute_copy.default": get_permute_inputs(),
+    "aten.view_copy.default": get_view_inputs(),
 }
```
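
Many of the target shapes above use `-1`, which `view` resolves by inference: the single `-1` dimension receives the total element count divided by the product of the explicit dimensions. A small sketch of that rule (illustrative, not part of the commit):

```python
import math


def infer_view_shape(in_shape, target):
    # Resolve a single -1 entry the way aten.view does.
    numel = math.prod(in_shape)
    known = math.prod(d for d in target if d != -1)
    assert numel % known == 0, "shapes are incompatible"
    return [numel // known if d == -1 else d for d in target]


assert infer_view_shape((8, 7, 2, 3), [7, -1, 2, 1]) == [7, 24, 2, 1]
assert infer_view_shape((8, 7, 2, 3), [-1]) == [336]
```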

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -105,10 +105,16 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str:
             for size in arg_sizes_or_val:
                 name_str += str(size) + "x"
             name_str = name_str[:-1]
+            # The minus sign is an invalid char in a test case name; change it to "n".
+            name_str = name_str.replace("-", "n")
+
         elif isinstance(arg_sizes_or_val, list):
             for size in arg_sizes_or_val:
                 name_str += str(size) + "c"
             name_str = name_str[:-1]
+            # The minus sign is an invalid char in a test case name; change it to "n".
+            name_str = name_str.replace("-", "n")
+
         else:
             name_str += str(arg_sizes_or_val).replace(".", "p")
         return name_str
@@ -234,7 +240,7 @@ def generate_suite_cpp(self) -> str:
 
 // from_blob doesn't take ownership of data. Hence must create a copy as
 // "values" will go out of scope.
-return at::from_blob(values.data(), sizes, dtype).detach().clone();
+return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
 }}
 
 {test_suites_cpp}
```
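
The `from_blob` change fixes a real bug for non-float dtypes: `values` presumably holds float data (the fix implies as much), so constructing the tensor directly with, say, `at::kHalf` would reinterpret the raw float32 bytes instead of converting the values. Building as float and then calling `toType` performs a value-preserving cast. The same pattern in Python, for illustration only:

```python
import torch

values = [1.0, -2.0, 3.5]
# Build as float32 first, then convert; mirrors the C++ fix of
# from_blob(..., at::kFloat).toType(dtype) rather than from_blob(..., dtype).
t = torch.tensor(values, dtype=torch.float32).to(torch.float16)
print(t)  # tensor([ 1.0000, -2.0000,  3.5000], dtype=torch.float16)
```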

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 10 deletions
```diff
@@ -145,10 +145,7 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 
 
 class SDPASimple(torch.nn.Module):
-    """
-    This is a simpler implementation of SDPA module defined in llama_transformer.py. Notice that it's
-    an implementation including both some preprocessing logic and F.scaled_dot_product_attention.
-    """
+
     def __init__(
         self,
         kv_cache: KVCache,
@@ -172,7 +169,6 @@ def forward(
         seqlen,
         mask,
     ):
-        # The first few lines are the same as the original SDPA module.
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -182,11 +178,6 @@ def forward(
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-
-        # Following is the different part. Instead of calling F.scaled_dot_product_attention,
-        # we use the following implementation to avoid the decomposition from F.scaled_dot_product_attention,
-        # as the decompostion is too expensive. The following will get rid of aten.full_like, aten.logical_not,
-        # aten.scalar_tensor, aten.where and 2 extra aten.mul.
         scale_factor = 1 / math.sqrt(q.size(-1))
         attn_weight = q @ k.transpose(-2, -1) * scale_factor
         attn_weight += attn_mask
```