Skip to content

Commit b39ff56

Browse files
committed
Update on "[ET-VK] Add coop shader for int8 linear"
Title says it all!

## Changes

* Apply co-operative shader for vector * matrix computations.

Differential Revision: [D73279548](https://our.internmc.facebook.com/intern/diff/D73279548/)

[ghstack-poisoned]
2 parents ba31c57 + 834e80d commit b39ff56

File tree

3 files changed

+24
-25
lines changed

3 files changed

+24
-25
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,15 @@ void main() {
5757
VEC4_T b[4];
5858
VEC4_T local_c[TILE_ROWS];
5959

60+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
61+
local_c[i] = VEC4_T(0.0);
62+
}
63+
6064
$if SCALES_STORAGE == "buffer":
6165
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
6266
$else:
6367
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
6468

65-
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
66-
partial_c[gid][wid][i] = VEC4_T(0.0);
67-
}
68-
6969
for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) {
7070
// Preload t_weight
7171
[[unroll]] for (int i = 0; i < 4; i++) {
@@ -77,17 +77,17 @@ void main() {
7777
// Preload t_in
7878
for (int i = 0; i < TILE_ROWS; i++) {
7979
$if IN_STORAGE == "buffer":
80-
a[i] = t_in[((out_row + i) * in_sizes.x + ((pos)) >> 2)];
80+
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
8181
$else:
8282
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
8383
}
8484

85-
// Compute t_out...?
85+
// Accumulate partial output
8686
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
87-
local_c[i] += a[i].x * b[0]
88-
+ a[i].y * b[1]
89-
+ a[i].z * b[2]
90-
+ a[i].w * b[3];
87+
local_c[i] += a[i].x * b[0] +
88+
a[i].y * b[1] +
89+
a[i].z * b[2] +
90+
a[i].w * b[3];
9191
}
9292
}
9393

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,18 @@ void main() {
7171
// Preload input tensor
7272
[[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
7373
$if IN_STORAGE == "buffer":
74-
a[i] = t_in[((out_row + i) * in_sizes.x + (pos)) >> 2];
74+
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
7575
$else:
7676
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
7777
}
7878

79-
// Compute partial output
79+
// Accumulate output
8080
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
8181
c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
8282
}
8383
}
8484

85-
// Store output tensor
85+
// Store to output tensor
8686
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
8787
$if OUT_STORAGE == "buffer":
8888
if (out_row + i < out_sizes.y) {

backends/vulkan/test/op_tests/cases.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -153,25 +153,24 @@ def get_linear_inputs():
153153
def get_weight_int8pack_mm_inputs():
154154
MKN_list = [
155155
[1, 480, 256],
156-
# [1, 1024, 1024],
157-
# [1, 1024, 256],
158-
# [3, 480, 256],
159-
# [6, 480, 256],
160-
# [6, 256, 1024],
161-
# [6, 1024, 256],
162-
# [6, 256, 256],
163-
# [6, 256, 512],
164-
# [4, 768, 4096],
165-
# [1024, 1024, 1024],
156+
[1, 1024, 1024],
157+
[1, 1024, 256],
158+
[3, 480, 256],
159+
[6, 480, 256],
160+
[6, 256, 1024],
161+
[6, 1024, 256],
162+
[6, 256, 256],
163+
[6, 256, 512],
164+
[4, 768, 4096],
165+
[1024, 1024, 1024],
166166
]
167167

168168
inputs_list = [((M, K), (N, K), (N)) for M, K, N in MKN_list]
169169

170170
test_suite = VkTestSuite(inputs_list)
171171
test_suite.dtypes = ["at::kFloat"]
172172
test_suite.layouts = ["utils::kWidthPacked"]
173-
# test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
174-
test_suite.storage_types = ["utils::kBuffer"]
173+
test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
175174
test_suite.prepacked_args = ["mat2", "scales"]
176175
test_suite.requires_prepack = True
177176

0 commit comments

Comments
 (0)