Skip to content

Commit 39e17e4

Browse files
copyrightly authored and facebook-github-bot committed
add 2x4 tile in mm computation (#4031)
Summary: Pull Request resolved: #4031 The existing optimized mm implementation compute output through 4x4 tile. This isn't efficient when the input tensor's height is a multiple of 3 but not a multiple of 4, e.g. 6. ~~We add a 3x4 tile computation and a parameter `HEIGHT6` to help us choose the computation manner.~~ According to nathanaelsee's experimentation, 2x4 is even more efficient than 3x4, we add 2x4 tile computation and add `TILE_ROW` in yaml files to generate shaders for 2x4 and 4x4 respectively. Reviewed By: nathanaelsee, liuk22 Differential Revision: D58769774 fbshipit-source-id: 79d8867c87464402b2c6432599b3effc12965122
1 parent caf3b1b commit 39e17e4

File tree

8 files changed

+110
-34
lines changed

8 files changed

+110
-34
lines changed

backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ $if MAT2_IS_TRANSPOSED:
1616
$if BATCH_MODE:
1717
#define BATCH_MODE
1818

19+
$if TILE_ROW == "tile_row_2":
20+
#define TILE_ROW_2
21+
1922
#include "indexing_utils.h"
2023
#include "matmul.h"
2124

@@ -56,24 +59,24 @@ void main() {
5659
}
5760

5861
$if BATCH_MODE:
59-
FloatMatrix_3d results = matmul_partial_4x4x4(
62+
FloatMatrix_3d results = matmul_partial_3d(
6063
im_mat1,
6164
im_mat2,
6265
pos,
6366
out_sizes[2],
6467
in_limits[0]);
6568
$else:
66-
FloatMatrix_2d results = matmul_partial_4x4(
69+
FloatMatrix_2d results = matmul_partial_2d(
6770
im_mat1,
6871
im_mat2,
6972
pos,
7073
out_sizes[2],
7174
in_limits[0]);
7275

73-
for (int idx_c = 0; idx_c < FOUR; idx_c++) {
76+
for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) {
7477
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
7578
const ivec3 out_pos =
76-
ivec3(idx_r + FOUR * pos.x, idx_c + FOUR * pos.y, pos.z);
79+
ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z);
7780

7881
vec4 self_texel = get_texel_C_packed(
7982
im_self,

backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ addmm_optimized:
1111
PACKING: C_packed
1212
MAT2_IS_TRANSPOSED: false
1313
BATCH_MODE: false
14+
TILE_ROW: tile_row_4
1415
generate_variant_forall:
16+
TILE_ROW:
17+
- VALUE: tile_row_4
18+
- VALUE: tile_row_2
1519
DTYPE:
1620
- VALUE: float
1721
- VALUE: half

backends/vulkan/runtime/graph/ops/glsl/matmul.h

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,20 @@
1010
// macro
1111
#define FOUR 4
1212

13+
#ifdef TILE_ROW_2
14+
#define TILE_ROWS 2
15+
#else
16+
#define TILE_ROWS 4
17+
#endif
18+
1319
// we avoid mat4 and vec4 usage here as they compile to much less efficient
1420
// SPIR-V
1521
struct FloatMatrix_2d {
16-
float data[FOUR][FOUR];
22+
float data[TILE_ROWS][FOUR];
1723
};
1824

1925
struct FloatMatrix_3d {
20-
float data[FOUR][FOUR][FOUR];
26+
float data[TILE_ROWS][FOUR][FOUR];
2127
};
2228

2329
#ifdef MAT2_IS_TRANSPOSED
@@ -150,25 +156,25 @@ vec4 get_texel_C_packed(
150156
return self_texel;
151157
}
152158

153-
FloatMatrix_2d matmul_partial_4x4(
159+
FloatMatrix_2d matmul_partial_2d(
154160
sampler3D im_mat1,
155161
sampler3D im_mat2,
156162
const ivec3 pos,
157163
const int batch_size,
158164
const int K_texel_len) {
159165
FloatMatrix_2d results;
160-
for (int i = 0; i < FOUR; i++) {
166+
for (int i = 0; i < TILE_ROWS; i++) {
161167
for (int j = 0; j < FOUR; j++) {
162168
results.data[i][j] = 0.0f;
163169
}
164170
}
165-
vec4 im_mat1_partial_load[FOUR];
171+
vec4 im_mat1_partial_load[TILE_ROWS];
166172
vec4 im_mat2_partial_load[FOUR];
167173

168174
for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
169-
for (int offset = 0; offset < FOUR; offset++) {
170-
// read and cache 4x4 tile of im_mat1
171-
const int mat1_y = (FOUR * pos.y) + offset;
175+
for (int offset = 0; offset < TILE_ROWS; offset++) {
176+
// read and cache 2x4 (or 4x4) tile of im_mat1
177+
const int mat1_y = (TILE_ROWS * pos.y) + offset;
172178
const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, 0);
173179
im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
174180
// read and cache 4x4 tile of im_mat2
@@ -182,8 +188,24 @@ FloatMatrix_2d matmul_partial_4x4(
182188
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
183189
#endif
184190
}
191+
192+
#ifdef TILE_ROW_2
193+
// column 3 and 4 of im_mat2
194+
#ifdef MAT2_IS_TRANSPOSED
195+
im_mat2_partial_load[2] =
196+
texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 2, 0), 0);
197+
im_mat2_partial_load[3] =
198+
texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 3, 0), 0);
199+
#else
200+
im_mat2_partial_load[2] =
201+
texelFetch(im_mat2, ivec3((FOUR * pos.x) + 2, mat1_x, 0), 0);
202+
im_mat2_partial_load[3] =
203+
texelFetch(im_mat2, ivec3((FOUR * pos.x) + 3, mat1_x, 0), 0);
204+
#endif
205+
#endif
206+
185207
// perform partial dot products and add partial result to results
186-
for (int out_row = 0; out_row < FOUR; out_row++) {
208+
for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
187209
for (int out_col = 0; out_col < FOUR; out_col++) {
188210
results.data[out_row][out_col] +=
189211
dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
@@ -193,21 +215,21 @@ FloatMatrix_2d matmul_partial_4x4(
193215
return results;
194216
}
195217

196-
FloatMatrix_3d matmul_partial_4x4x4(
218+
FloatMatrix_3d matmul_partial_3d(
197219
sampler3D im_mat1,
198220
sampler3D im_mat2,
199221
const ivec3 pos,
200222
const int batch_size,
201223
const int K_texel_len) {
202224
FloatMatrix_3d results;
203-
for (int i = 0; i < FOUR; i++) {
225+
for (int i = 0; i < TILE_ROWS; i++) {
204226
for (int j = 0; j < FOUR; j++) {
205227
for (int k = 0; k < FOUR; k++) {
206228
results.data[i][j][k] = 0.0f;
207229
}
208230
}
209231
}
210-
vec4 im_mat1_partial_load[FOUR];
232+
vec4 im_mat1_partial_load[TILE_ROWS];
211233
vec4 im_mat2_partial_load[FOUR];
212234

213235
for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
@@ -216,9 +238,9 @@ FloatMatrix_3d matmul_partial_4x4x4(
216238
}
217239
int mat_z = FOUR * pos.z + batch_idx;
218240
for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
219-
for (int offset = 0; offset < FOUR; offset++) {
220-
// read and cache 4x4 tile of im_mat1
221-
const int mat1_y = (FOUR * pos.y) + offset;
241+
for (int offset = 0; offset < TILE_ROWS; offset++) {
242+
// read and cache 2x4 (or 4x4) tile of im_mat1
243+
const int mat1_y = (TILE_ROWS * pos.y) + offset;
222244
const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, mat_z);
223245
im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
224246
// read and cache 4x4 tile of im_mat2
@@ -232,8 +254,24 @@ FloatMatrix_3d matmul_partial_4x4x4(
232254
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
233255
#endif
234256
}
257+
258+
#ifdef TILE_ROW_2
259+
// column 3, and 4 of im_mat2
260+
#ifdef MAT2_IS_TRANSPOSED
261+
im_mat2_partial_load[2] =
262+
texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 2, 0), 0);
263+
im_mat2_partial_load[3] =
264+
texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 3, 0), 0);
265+
#else
266+
im_mat2_partial_load[2] =
267+
texelFetch(im_mat2, ivec3((FOUR * pos.x) + 2, mat1_x, mat_z), 0);
268+
im_mat2_partial_load[3] =
269+
texelFetch(im_mat2, ivec3((FOUR * pos.x) + 3, mat1_x, mat_z), 0);
270+
#endif
271+
#endif
272+
235273
// perform partial dot products and add partial result to results
236-
for (int out_row = 0; out_row < FOUR; out_row++) {
274+
for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
237275
for (int out_col = 0; out_col < FOUR; out_col++) {
238276
results.data[out_row][out_col][batch_idx] +=
239277
dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);

backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ $if MAT2_IS_TRANSPOSED:
1616
$if BATCH_MODE:
1717
#define BATCH_MODE
1818

19+
$if TILE_ROW == "tile_row_2":
20+
#define TILE_ROW_2
21+
1922
#include "indexing_utils.h"
2023
#include "matmul.h"
2124

@@ -45,24 +48,24 @@ void main() {
4548
}
4649

4750
$if BATCH_MODE:
48-
FloatMatrix_3d results = matmul_partial_4x4x4(
51+
FloatMatrix_3d results = matmul_partial_3d(
4952
im_mat1,
5053
im_mat2,
5154
pos,
5255
out_sizes[2],
5356
in_limits[0]);
5457
$else:
55-
FloatMatrix_2d results = matmul_partial_4x4(
58+
FloatMatrix_2d results = matmul_partial_2d(
5659
im_mat1,
5760
im_mat2,
5861
pos,
5962
out_sizes[2],
6063
in_limits[0]);
6164

62-
for (int idx_c = 0; idx_c < FOUR; idx_c++) {
65+
for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) {
6366
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
6467
const ivec3 out_pos =
65-
ivec3(idx_r + FOUR * pos.x, idx_c + FOUR * pos.y, pos.z);
68+
ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z);
6669

6770
// results is in transposed order w.r.t. the desired output
6871
$if BATCH_MODE:

backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ matmul_optimized:
1111
PACKING: C_packed
1212
MAT2_IS_TRANSPOSED: false
1313
BATCH_MODE: false
14+
TILE_ROW: tile_row_4
1415
generate_variant_forall:
16+
TILE_ROW:
17+
- VALUE: tile_row_4
18+
- VALUE: tile_row_2
1519
DTYPE:
1620
- VALUE: float
1721
- VALUE: half

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,21 +162,31 @@ void add_addmm_optimized_node(
162162
viewFn(graph, {mat2, graph.add_none(), mat2_packed});
163163
}
164164

165-
api::utils::uvec3 global_size =
166-
api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1});
167-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
168-
169165
std::string kernel_name = graph.get_bool(mat2_is_transposed)
170166
? "linear_optimized"
171167
: "addmm_optimized";
172168

173-
int mat1_dims = graph.sizes_of(mat1_W_packed).size();
169+
std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
170+
int mat1_dims = mat1_sizes.size();
174171
if (mat1_dims == 3) {
175172
kernel_name = "batch_" + kernel_name;
176173
}
174+
if (mat1_sizes.at(mat1_dims - 2) < 8) {
175+
kernel_name += "_tile_row_2";
176+
} else {
177+
kernel_name += "_tile_row_4";
178+
}
177179

178180
add_dtype_suffix(kernel_name, graph.dtype_of(out));
179181

182+
api::utils::uvec3 global_size;
183+
if (mat1_sizes.at(mat1_dims - 2) < 8) {
184+
global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 2, 1});
185+
} else {
186+
global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1});
187+
}
188+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
189+
180190
graph.execute_nodes().emplace_back(new ExecuteNode(
181191
graph,
182192
VK_KERNEL_FROM_STR(kernel_name),

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,21 +127,31 @@ void add_matmul_optimized_node(
127127
viewFn(graph, {mat2, graph.add_none(), mat2_packed});
128128
}
129129

130-
api::utils::uvec3 global_size =
131-
api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1});
132-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
133-
134130
std::string kernel_name = mat2_is_transposed_val
135131
? "matmul_transposed_optimized"
136132
: "matmul_optimized";
137133

138-
int mat1_dims = graph.sizes_of(mat1_W_packed).size();
134+
std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
135+
int mat1_dims = mat1_sizes.size();
139136
if (mat1_dims == 3) {
140137
kernel_name = "batch_" + kernel_name;
141138
}
139+
if (mat1_sizes.at(mat1_dims - 2) < 8) {
140+
kernel_name += "_tile_row_2";
141+
} else {
142+
kernel_name += "_tile_row_4";
143+
}
142144

143145
add_dtype_suffix(kernel_name, graph.dtype_of(out));
144146

147+
api::utils::uvec3 global_size;
148+
if (mat1_sizes.at(mat1_dims - 2) < 8) {
149+
global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 2, 1});
150+
} else {
151+
global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1});
152+
}
153+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
154+
145155
graph.execute_nodes().emplace_back(new ExecuteNode(
146156
graph,
147157
VK_KERNEL_FROM_STR(kernel_name),

backends/vulkan/test/op_tests/cases.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def get_mm_inputs():
6464
[
6565
((M1, L), (L, M2)),
6666
((S1, S2), (S2, M)),
67+
((6, 32), (32, 64)),
6768
],
6869
)
6970
test_suite.prepacked_args = ["mat2"]
@@ -82,6 +83,7 @@ def get_bmm_inputs():
8283
[
8384
((S, M1, L), (S, L, M2)),
8485
((M, S1, S2), (M, S2, M)),
86+
((4, 6, 32), (4, 32, 16)),
8587
],
8688
)
8789
test_suite.prepacked_args = ["mat2"]
@@ -104,6 +106,7 @@ def get_addmm_inputs():
104106
((M1, M2), (M1, M2), (M2, M2), 4.2, 2.3),
105107
((M1, 1), (M1, L), (L, L), 2.0, 3.0),
106108
((M2), (M1, M2), (M2, M2)),
109+
((6, M2), (6, M2), (M2, M2)),
107110
]
108111
)
109112
# ATen matmul doesn't support half
@@ -129,6 +132,7 @@ def get_linear_inputs():
129132
inputs_list += [((M, K), (N, K), (N)) for M, K, N in MKN_list]
130133
inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list]
131134
inputs_list += [((3, M, K), (N, K), (N)) for M, K, N in MKN_list]
135+
inputs_list += [((3, 6, K), (N, K), (N)) for M, K, N in MKN_list]
132136

133137
test_suite = VkTestSuite(inputs_list)
134138
test_suite.dtypes = ["at::kFloat"]

0 commit comments

Comments (0)