Skip to content

Commit 3277f6d

Browse files
committed
Update base for Update on "[ET-VK] Add coop shader for int8 linear"
Title says it all! ## Changes * Apply co-operative shader for vector * matrix computations. Differential Revision: [D73279548](https://our.internmc.facebook.com/intern/diff/D73279548/) [ghstack-poisoned]
1 parent 4d1618e commit 3277f6d

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void main() {
5353
$if SCALES_STORAGE == "buffer":
5454
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
5555
$else:
56-
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
56+
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
5757

5858
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
5959
c[i] = VEC4_T(0.0);

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ q_8w_linear_tiled:
1010
IN_STORAGE: texture3d
1111
OUT_STORAGE: texture3d
1212
WEIGHT_STORAGE: texture2d
13-
SCALES_STORAGE: buffer
13+
SCALES_STORAGE: texture2d
1414
TILE_ROWS: 4
1515
generate_variant_forall:
1616
TILE_ROWS:
@@ -21,11 +21,12 @@ q_8w_linear_tiled:
2121
- VALUE: 6
2222
SUFFIX: o4x6
2323
shader_variants:
24-
- NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_float
25-
- NAME: q_8w_linear_tiled_buffer_buffer_texture2d_float
24+
- NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float
25+
- NAME: q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float
2626
IN_STORAGE: buffer
2727
OUT_STORAGE: buffer
28-
- NAME: q_8w_linear_tiled_buffer_buffer_buffer_float
28+
- NAME: q_8w_linear_tiled_buffer_buffer_buffer_buffer_float
2929
IN_STORAGE: buffer
3030
OUT_STORAGE: buffer
3131
WEIGHT_STORAGE: buffer
32+
SCALES_STORAGE: buffer

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,19 @@ void add_q_8w_linear_tiled_node(
161161
ValueRef q_mat2 = prepack_standard_hw_transposed(
162162
graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked);
163163

164+
utils::StorageType scales_storage = utils::kTexture2D;
165+
if (N > max_extent) {
166+
scales_storage = utils::kBuffer;
167+
}
164168
ValueRef scales =
165-
prepack_standard(graph, scales_data, utils::kBuffer, utils::kWidthPacked);
169+
prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked);
166170

167171
std::string kernel_name = "q_8w_linear_tiled";
168172
kernel_name.reserve(kShaderNameReserve);
169173
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
170174
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
171175
add_storage_type_suffix(kernel_name, graph.storage_type_of(q_mat2));
176+
add_storage_type_suffix(kernel_name, graph.storage_type_of(scales));
172177
add_dtype_suffix(kernel_name, graph.dtype_of(out));
173178

174179
std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
@@ -177,6 +182,9 @@ void add_q_8w_linear_tiled_node(
177182
if (M % 6 == 0) {
178183
kernel_name += "_o4x6";
179184
out_tile_nrows = 6;
185+
} else if (M % 4 == 0) {
186+
kernel_name += "_o4x4";
187+
out_tile_nrows = 4;
180188
} else if (M % 1 == 0) {
181189
kernel_name += "_o4x1";
182190
out_tile_nrows = 1;

0 commit comments

Comments
 (0)