Skip to content

Commit 06708ef

Browse files
committed
Update base for Update on "add 16a4w_hqq quant mode"
Prerequistie: install hqq following https://github.com/mobiusml/hqq Step 1: use hqq to quantize weight to 4bit Step 2: use static quant to quantize activation to 16bit Currently the graph calibration is too slow, so adding the the quant oberserver to the eager model for faster iteration command: ``` python -m examples.models.llama2.eval_llama -t /data/users/chenlai/models/llama2/tokenizer.model -p /data/users/chenlai/models/llama2/params.json -c /data/users/chenlai/models/llama2/consolidated.00.pth --max_seq_len 129 -qmode 16a4w-hqq --limit 5 2>&1 | tee hqq_16a4w.log ``` Differential Revision: [D57849772](https://our.internmc.facebook.com/intern/diff/D57849772/) [ghstack-poisoned]
2 parents 17d117c + c665c17 commit 06708ef

27 files changed

+1058
-112
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __contains__(self, op):
9999
]
100100

101101
INDEXING_OPS = [
102+
exir_ops.edge.aten.index_select.default,
102103
exir_ops.edge.aten.select_copy.int,
103104
exir_ops.edge.aten.slice_copy.Tensor,
104105
]

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,14 @@
1919

2020
layout(std430) buffer;
2121

22-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
23-
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
24-
layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other;
25-
26-
layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes {
27-
ivec4 out_sizes;
28-
};
29-
30-
layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
31-
ivec4 in_sizes;
32-
};
33-
34-
layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
35-
ivec4 other_sizes;
36-
};
37-
38-
layout(set = 0, binding = 6) uniform PRECISION restrict BroadcastParams {
39-
ivec2 broadcast_params;
40-
};
41-
42-
layout(set = 0, binding = 7) uniform PRECISION restrict Alpha {
43-
float alpha;
44-
};
22+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
23+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
24+
${layout_declare_tensor(2, "r", "t_other", DTYPE, STORAGE)}
25+
${layout_declare_ubo(3, "ivec4", "out_sizes")}
26+
${layout_declare_ubo(4, "ivec4", "in_sizes")}
27+
${layout_declare_ubo(5, "ivec4", "other_sizes")}
28+
${layout_declare_ubo(6, "ivec2", "broadcast_params")}
29+
${layout_declare_ubo(7, "float", "alpha")}
4530

4631
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4732

@@ -57,13 +42,13 @@ void main() {
5742

5843
ivec4 in_idx = broadcast_indices(idx, in_sizes);
5944
VEC4_T in_texel = VEC4_T(texelFetch(
60-
image_in,
45+
t_in,
6146
to_texture_pos(in_idx, in_sizes, packed_dim),
6247
0));
6348

6449
ivec4 other_idx = broadcast_indices(idx, other_sizes);
6550
VEC4_T other_texel = VEC4_T(texelFetch(
66-
image_other,
51+
t_other,
6752
to_texture_pos(other_idx, other_sizes, packed_dim),
6853
0));
6954

@@ -75,5 +60,5 @@ void main() {
7560
other_texel = other_texel.xxxx;
7661
}
7762

78-
imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha)));
63+
imageStore(t_out, pos, VEC4_T(op(in_texel, other_texel, alpha)));
7964
}

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ binary_op:
1010
NDIM: 3
1111
DTYPE: float
1212
PACKING: C_packed
13+
STORAGE: texture3d
1314
generate_variant_forall:
1415
DTYPE:
1516
- VALUE: half
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
20+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
21+
${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)}
22+
${layout_declare_ubo(3, "ivec4", "sizes")}
23+
${layout_declare_ubo(4, "int", "gpu_dim", "int", "stride")}
24+
25+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26+
27+
layout(constant_id = 3) const int packed_dim = C_DIM;
28+
29+
void main() {
30+
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
31+
32+
if (pos_out_of_bounds(out_pos, sizes, packed_dim)) {
33+
return;
34+
}
35+
36+
const int out_idx = out_pos[gpu_dim] / stride;
37+
const int within_stride = out_pos[gpu_dim] % stride;
38+
const int in_idx = texelFetch(t_idx, ivec3(out_idx, 0, 0), 0).x;
39+
40+
ivec3 in_pos = out_pos;
41+
in_pos[gpu_dim] = in_idx * stride + within_stride;
42+
43+
imageStore(t_out, out_pos, texelFetch(t_in, in_pos, 0));
44+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
index_select:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
STORAGE: texture3d
6+
generate_variant_forall:
7+
DTYPE:
8+
- VALUE: half
9+
- VALUE: float
10+
- VALUE: int
11+
shader_variants:
12+
- NAME: index_select
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
20+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
21+
${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)}
22+
${layout_declare_ubo(3, "ivec4", "out_sizes")}
23+
${layout_declare_ubo(4, "ivec4", "in_sizes")}
24+
25+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26+
27+
layout(constant_id = 3) const int packed_dim = C_DIM;
28+
29+
void main() {
30+
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
31+
32+
if (pos_out_of_bounds(out_pos, out_sizes, packed_dim)) {
33+
return;
34+
}
35+
36+
const ivec4 idx = to_tensor_idx(out_pos, out_sizes, packed_dim);
37+
const ivec4 buffer_ixs = get_texel_nchw_buffer_ixs(idx, out_sizes, packed_dim);
38+
39+
VEC4_T out_texel;
40+
for (int i = 0; i < 4; ++i) {
41+
const ivec4 out_idx = from_nchw_buffer_i(buffer_ixs[i], out_sizes);
42+
int out_channel = out_idx.z;
43+
int in_channel = texelFetch(t_idx, ivec3(out_channel, 0, 0), 0).x;
44+
45+
ivec4 in_idx = out_idx;
46+
in_idx.z = in_channel;
47+
48+
ivec4 in_elem_pos = to_texture_elem_pos(in_idx, in_sizes, packed_dim);
49+
50+
VEC4_T in_texel = texelFetch(t_in, in_elem_pos.xyz, 0);
51+
52+
out_texel[i] = in_texel[in_elem_pos.w];
53+
}
54+
imageStore(t_out, out_pos, out_texel);
55+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
index_select_channel:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
STORAGE: texture3d
6+
generate_variant_forall:
7+
DTYPE:
8+
- VALUE: half
9+
- VALUE: float
10+
- VALUE: int
11+
shader_variants:
12+
- NAME: index_select_channel

backends/vulkan/runtime/graph/ops/glsl/upsample.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.glsl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,15 @@
1313

1414
#define PRECISION ${PRECISION}
1515

16-
#define VEC4_T ${texel_type(DTYPE)}
16+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
1717

1818
layout(std430) buffer;
1919

20-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
21-
22-
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
23-
24-
layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
25-
ivec3 out_limits;
26-
};
27-
28-
layout(set = 0, binding = 3) uniform PRECISION restrict Sizes {
29-
ivec4 sizes;
30-
};
20+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
21+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
22+
${layout_declare_ubo(2, "ivec3", "out_limits")}
23+
${layout_declare_ubo(3, "ivec2", "input_size")}
24+
${layout_declare_ubo(4, "vec2", "rev_scales")}
3125

3226
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3327

@@ -38,6 +32,8 @@ void main() {
3832
return;
3933
}
4034

41-
VEC4_T in_texel = texelFetch(image_in, pos, 0);
42-
imageStore(image_out, pos, in_texel);
35+
const ivec2 ipos = clamp(ivec2(pos.xy * rev_scales), ivec2(0), input_size);
36+
37+
VEC4_T in_texel = texelFetch(t_in, ivec3(ipos, pos.z), 0);
38+
imageStore(t_out, pos, in_texel);
4339
}

backends/vulkan/runtime/graph/ops/glsl/upsample.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
upsample:
7+
upsample_nearest2d:
88
parameter_names_with_default_values:
99
NDIM: 3
1010
DTYPE: float
1111
PACKING: C_packed
12+
STORAGE: texture3d
1213
generate_variant_forall:
1314
DTYPE:
1415
- VALUE: half
1516
- VALUE: float
1617
shader_variants:
17-
- NAME: upsample
18+
- NAME: upsample_nearest2d
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
namespace vkcompute {
19+
20+
void check_index_select_args(
21+
const vTensor& in,
22+
const vTensor& idx,
23+
const vTensor& out) {
24+
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
25+
VK_CHECK_COND(check_memory_layout_is(idx, api::kChannelsPacked));
26+
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
27+
}
28+
29+
void add_index_select_channel_node(
30+
ComputeGraph& graph,
31+
ValueRef in,
32+
ValueRef idx,
33+
ValueRef out) {
34+
vTensorPtr t_in = graph.get_tensor(in);
35+
vTensorPtr t_idx = graph.get_tensor(idx);
36+
vTensorPtr t_out = graph.get_tensor(out);
37+
38+
check_index_select_args(*t_in, *t_idx, *t_out);
39+
40+
std::string kernel_name = "index_select_channel";
41+
kernel_name.reserve(kShaderNameReserve);
42+
add_dtype_suffix(kernel_name, *t_out);
43+
44+
api::utils::uvec3 global_size = t_out->image_extents();
45+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
46+
47+
graph.execute_nodes().emplace_back(new ExecuteNode(
48+
graph,
49+
VK_KERNEL_FROM_STR(kernel_name),
50+
global_size,
51+
local_size,
52+
{{out, api::MemoryAccessType::WRITE},
53+
{{in, idx}, api::MemoryAccessType::READ}},
54+
{t_out->sizes_ubo(), t_in->sizes_ubo()}));
55+
}
56+
57+
struct IndexSelectParams final {
58+
int32_t gpu_dim;
59+
int32_t stride;
60+
};
61+
62+
IndexSelectParams create_index_select_params(
63+
const int64_t dim_idx,
64+
const vTensor& in) {
65+
if (dim_idx == kWidth4D) {
66+
return {0, 1};
67+
} else if (dim_idx == kHeight4D) {
68+
return {1, 1};
69+
} else if (dim_idx == kBatch4D) {
70+
int64_t n_channels = dim_at(in.sizes(), kChannel4D);
71+
int64_t stride = api::utils::div_up_4(n_channels);
72+
return {2, static_cast<int32_t>(stride)};
73+
} else {
74+
VK_THROW("Unexpected dim_idx!");
75+
}
76+
}
77+
78+
void add_index_select_node(
79+
ComputeGraph& graph,
80+
ValueRef in,
81+
const int64_t dim_idx,
82+
ValueRef idx,
83+
ValueRef out) {
84+
vTensorPtr t_in = graph.get_tensor(in);
85+
vTensorPtr t_idx = graph.get_tensor(idx);
86+
vTensorPtr t_out = graph.get_tensor(out);
87+
88+
check_index_select_args(*t_in, *t_idx, *t_out);
89+
90+
IndexSelectParams params = create_index_select_params(dim_idx, *t_in);
91+
92+
std::string kernel_name = "index_select";
93+
kernel_name.reserve(kShaderNameReserve);
94+
add_dtype_suffix(kernel_name, *t_out);
95+
96+
api::utils::uvec3 global_size = t_out->image_extents();
97+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
98+
99+
graph.execute_nodes().emplace_back(new ExecuteNode(
100+
graph,
101+
VK_KERNEL_FROM_STR(kernel_name),
102+
global_size,
103+
local_size,
104+
{{out, api::MemoryAccessType::WRITE},
105+
{{in, idx}, api::MemoryAccessType::READ}},
106+
{t_out->sizes_ubo(), graph.create_params_buffer(params)}));
107+
}
108+
109+
int64_t get_dim_idx(ComputeGraph& graph, ValueRef in, ValueRef dim_ref) {
110+
vTensorPtr t_in = graph.get_tensor(in);
111+
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
112+
dim = normalize(dim, t_in->dim());
113+
return normalize_to_dim_index(*t_in, dim);
114+
}
115+
116+
void index_select(ComputeGraph& graph, const std::vector<ValueRef>& args) {
117+
ValueRef in = prepack_if_tensor_ref(graph, args[0]);
118+
ValueRef dim_ref = args[1];
119+
ValueRef idx = prepack_if_tensor_ref(graph, args[2]);
120+
ValueRef out = args[3];
121+
122+
const int64_t dim_idx = get_dim_idx(graph, in, dim_ref);
123+
if (dim_idx == kChannel4D) {
124+
add_index_select_channel_node(graph, in, idx, out);
125+
} else {
126+
add_index_select_node(graph, in, dim_idx, idx, out);
127+
}
128+
}
129+
130+
REGISTER_OPERATORS {
131+
VK_REGISTER_OP(aten.index_select.default, index_select);
132+
}
133+
134+
} // namespace vkcompute

0 commit comments

Comments
 (0)