Skip to content

Commit 4afc4fb

Browse files
authored
[ET-VK] Add buffer implementation for matrix multiplication
Differential Revision: D61666461 Pull Request resolved: #4845
1 parent 33fbe03 commit 4afc4fb

File tree

9 files changed

+211
-69
lines changed

9 files changed

+211
-69
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ utils::uvec3 ComputeGraph::create_local_wg_size(
368368
}
369369

370370
utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
371-
return create_local_wg_size(image_extents_of(idx));
371+
return create_local_wg_size(create_global_wg_size(idx));
372372
}
373373

374374
void ComputeGraph::copy_into_staging(

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,21 @@ class ComputeGraph final {
186186

187187
std::vector<int64_t> sizes_of(const ValueRef idx) const;
188188

189+
/*
190+
* Returns the size of the tensor at `idx` along the specified dimension.
191+
* Negative indexing is allowed.
192+
*/
193+
template <typename T>
194+
T size_at(const int64_t dim, const ValueRef idx) const {
195+
const Value& val = values_.at(idx);
196+
if (val.isTensor()) {
197+
return static_cast<T>(utils::val_at(dim, val.toConstTensor().sizes()));
198+
} else if (val.isTensorRef()) {
199+
return static_cast<T>(utils::val_at(dim, val.toConstTensorRef().sizes));
200+
}
201+
VK_THROW("Could not get sizes of value with type ", val.type());
202+
}
203+
189204
vkapi::ScalarType dtype_of(const ValueRef idx) const;
190205

191206
inline utils::uvec3 image_extents_of(const ValueRef idx) const {

backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#version 450 core

#define PRECISION ${PRECISION}

// Scalar element type of the buffers, selected by the DTYPE template parameter.
#define T ${buffer_scalar_type(DTYPE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

// Bindings 0-2: the output and the two matrix operands, all declared as buffers.
${layout_declare_tensor(0, "w", "t_out", DTYPE, "buffer")}
${layout_declare_tensor(1, "r", "t_mat1", DTYPE, "buffer")}
${layout_declare_tensor(2, "r", "t_mat2", DTYPE, "buffer")}
// Bindings 3-8: sizes and strides of each tensor, one ivec4 per UBO.
${layout_declare_ubo(3, "ivec4", "out_sizes")}
${layout_declare_ubo(4, "ivec4", "out_strides")}
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
${layout_declare_ubo(6, "ivec4", "mat1_strides")}
${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
${layout_declare_ubo(8, "ivec4", "mat2_strides")}
// Binding 9: declared for the host-side binding list; not read in this file.
${layout_declare_ubo(9, "int", "out_numel")}

#include "indexing_utils.h"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Each invocation produces one element of t_out as the sum of products of
 * mat1_sizes.x pairs of elements, walking t_mat1 with stride mat1_strides.x
 * and t_mat2 with stride mat2_strides.y.
 *
 * NOTE(review): nothing in this shader distinguishes a transposed mat2 from a
 * non-transposed one; confirm that callers never dispatch it with a transposed
 * second operand.
 */
void main() {
  const uvec3 gid = gl_GlobalInvocationID;

  // The global z coordinate carries two tensor dimensions at once; split it
  // into the third and fourth index components using out_sizes.z.
  const ivec4 pos = ivec4(
      gid.x,
      gid.y,
      gid.z % out_sizes.z,
      gid.z / out_sizes.z);

  // Invocations that fall outside the output tensor have nothing to write.
  if (!all(lessThan(pos, out_sizes))) {
    return;
  }

  // Buffer offsets of the first operand pair: mat1 starts at x = 0 and mat2
  // starts at y = 0, with the remaining components taken from the output index.
  const int lhs_base =
      to_buffer_id(ivec4(0, pos.y, pos.z, pos.w), mat1_strides);
  const int rhs_base =
      to_buffer_id(ivec4(pos.x, 0, pos.z, pos.w), mat2_strides);

  T acc = T(0.0);
  for (int k = 0; k < mat1_sizes.x; ++k) {
    const int lhs_id = lhs_base + k * mat1_strides.x;
    const int rhs_id = rhs_base + k * mat2_strides.y;
    acc += t_mat1[lhs_id] * t_mat2[rhs_id];
  }

  t_out[to_buffer_id(pos, out_strides)] = T(acc);
}

backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Variant specification for the shader template named matmul_naive_buffer.
matmul_naive_buffer:
8+
# Default values for the template parameters.
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: buffer
11+
# NOTE(review): presumably one variant is produced per DTYPE value listed
# below (float and half) - confirm against the shader codegen.
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: float
14+
- VALUE: half
15+
# A single named variant; no per-variant parameter overrides are given.
shader_variants:
16+
- NAME: matmul_naive_buffer

backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/matmul_naive_texture3d.glsl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,11 @@ $if MAT2_IS_TRANSPOSED:
1616
#include "indexing_utils.h"
1717
#include "matmul.h"
1818

19-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out;
20-
layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1;
21-
layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2;
22-
23-
layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits {
24-
ivec3 out_limits;
25-
};
26-
27-
layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
28-
ivec4 in_sizes;
29-
};
19+
${layout_declare_tensor(0, "w", "im_out", DTYPE, "texture3d")}
20+
${layout_declare_tensor(1, "r", "im_mat1", DTYPE, "texture3d")}
21+
${layout_declare_tensor(2, "r", "im_mat2", DTYPE, "texture3d")}
22+
${layout_declare_ubo(3, "ivec3", "out_limits")}
23+
${layout_declare_ubo(4, "ivec4", "in_sizes")}
3024

3125
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3226

backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/matmul_naive_texture3d.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
matmul_naive:
7+
matmul_naive_texture3d:
88
parameter_names_with_default_values:
99
DTYPE: float
10-
NDIM: 3
10+
STORAGE: texture3d
1111
MAT1_PACKING: W_packed
1212
MAT2_PACKING: H_packed
1313
MAT2_IS_TRANSPOSED: false
@@ -16,9 +16,9 @@ matmul_naive:
1616
- VALUE: float
1717
- VALUE: half
1818
shader_variants:
19-
- NAME: matmul_naive_W_packed_H_packed
20-
- NAME: matmul_naive_W_packed_W_packed
19+
- NAME: matmul_naive_texture3d_W_packed_H_packed
20+
- NAME: matmul_naive_texture3d_W_packed_W_packed
2121
MAT2_PACKING: W_packed
22-
- NAME: matmul_transposed_naive_W_packed_W_packed
22+
- NAME: matmul_transposed_naive_texture3d_W_packed_W_packed
2323
MAT2_PACKING: W_packed
2424
MAT2_IS_TRANSPOSED: true

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,48 @@ void resize_matmul_node(
6262
out->virtual_resize(new_out_sizes);
6363
}
6464

65-
void add_matmul_naive_node(
65+
void add_matmul_naive_buffer_node(
66+
ComputeGraph& graph,
67+
const ValueRef mat1,
68+
const ValueRef mat2_data,
69+
const ValueRef out,
70+
const ValueRef mat2_is_transposed) {
71+
ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked);
72+
73+
std::string kernel_name = "matmul_naive_buffer";
74+
add_dtype_suffix(kernel_name, graph.dtype_of(out));
75+
76+
utils::uvec3 global_size = {
77+
graph.size_at<uint32_t>(-1, out),
78+
graph.size_at<uint32_t>(-2, out),
79+
graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};
80+
81+
graph.execute_nodes().emplace_back(new ExecuteNode(
82+
graph,
83+
VK_KERNEL_FROM_STR(kernel_name),
84+
global_size,
85+
graph.create_local_wg_size(global_size),
86+
// Inputs and Outputs
87+
{{out, vkapi::MemoryAccessType::WRITE},
88+
{{mat1, mat2}, vkapi::MemoryAccessType::READ}},
89+
// Shader params buffers
90+
{
91+
graph.sizes_ubo(out),
92+
graph.strides_ubo(out),
93+
graph.sizes_ubo(mat1),
94+
graph.strides_ubo(mat1),
95+
graph.sizes_ubo(mat2),
96+
graph.strides_ubo(mat2),
97+
graph.numel_ubo(out),
98+
},
99+
// Specialization Constants
100+
{},
101+
// Resizing Logic
102+
resize_matmul_node,
103+
{mat2_is_transposed}));
104+
}
105+
106+
void add_matmul_naive_texture3d_node(
66107
ComputeGraph& graph,
67108
const ValueRef mat1,
68109
const ValueRef mat2_data,
@@ -74,6 +115,7 @@ void add_matmul_naive_node(
74115
? "matmul_transposed_naive"
75116
: "matmul_naive";
76117
kernel_name.reserve(kShaderNameReserve);
118+
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
77119
add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat1));
78120
add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat2));
79121
add_dtype_suffix(kernel_name, graph.dtype_of(out));
@@ -174,12 +216,16 @@ void add_matmul_node(
174216
const ValueRef mat2_data,
175217
const ValueRef out,
176218
const ValueRef mat2_is_transposed) {
177-
if (graph.memory_layout_of(mat1) == utils::kChannelsPacked) {
219+
if (graph.is_buffer_storage(out)) {
220+
add_matmul_naive_buffer_node(
221+
graph, mat1, mat2_data, out, mat2_is_transposed);
222+
} else if (graph.memory_layout_of(mat1) == utils::kChannelsPacked) {
178223
add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed);
179224
} else if (graph.memory_layout_of(mat1) == utils::kWidthPacked) {
180-
add_matmul_naive_node(graph, mat1, mat2_data, out, mat2_is_transposed);
225+
add_matmul_naive_texture3d_node(
226+
graph, mat1, mat2_data, out, mat2_is_transposed);
181227
} else {
182-
VK_THROW("Input should be channel packed or width packed.");
228+
VK_THROW("Input texture should be channel packed or width packed.");
183229
}
184230
}
185231

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def get_mm_inputs():
7070
test_suite.prepacked_args = ["mat2"]
7171
# ATen matmul doesn't support half
7272
test_suite.dtypes = ["at::kFloat"]
73+
test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
7374
test_suite.layouts = [
7475
"utils::kWidthPacked",
7576
"utils::kChannelsPacked",

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2282,24 +2282,28 @@ void test_binary_op(
22822282
}
22832283
}
22842284

2285-
#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
2286-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, false) \
2287-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_HEIGHT_PACKED, false) \
2288-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, false) \
2289-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, true) \
2290-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_HEIGHT_PACKED, true) \
2291-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, true)
2292-
2293-
#define CALL_TEST_FN_FOR_W_PACKED(_) \
2294-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, false) \
2295-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, true)
2296-
2297-
#define CALL_TEST_FN_FOR_C_PACKED(_) \
2298-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, false) \
2299-
_(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, true)
2285+
#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
2286+
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
2287+
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, false) \
2288+
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, false) \
2289+
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, true) \
2290+
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, true) \
2291+
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, true)
2292+
2293+
#define CALL_TEST_FN_FOR_W_PACKED(_) \
2294+
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
2295+
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, true) \
2296+
_(vkapi::kFloat, utils::kBuffer, utils::kWidthPacked, false) \
2297+
_(vkapi::kFloat, utils::kBuffer, utils::kWidthPacked, true)
2298+
2299+
#define CALL_TEST_FN_FOR_C_PACKED(_) \
2300+
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, false) \
2301+
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, true) \
2302+
_(vkapi::kFloat, utils::kBuffer, utils::kChannelsPacked, false) \
2303+
_(vkapi::kFloat, utils::kBuffer, utils::kChannelsPacked, true)
23002304

23012305
TEST(VulkanComputeGraphOpsTest, add_smoke_test) {
2302-
#define RUN_TESTS(dtype, layout, prepack) \
2306+
#define RUN_TESTS(dtype, storage, layout, prepack) \
23032307
test_binary_op("add", {17, 21}, {17, 21}, dtype, layout, prepack); \
23042308
test_binary_op("add", {17, 21}, {1, 1}, dtype, layout, prepack); \
23052309
test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout, prepack); \
@@ -2320,9 +2324,11 @@ void test_mm(
23202324
int K,
23212325
int N,
23222326
vkapi::ScalarType dtype,
2327+
utils::StorageType storage_type,
23232328
utils::GPUMemoryLayout memory_layout,
23242329
bool prepack = true) {
23252330
GraphConfig config;
2331+
config.set_storage_type_override(storage_type);
23262332
ComputeGraph graph(config);
23272333

23282334
std::vector<int64_t> mat1_size = {M, K};
@@ -2379,38 +2385,42 @@ void test_mm(
23792385
}
23802386

23812387
TEST(VulkanComputeGraphOpsTest, mm_smoke_test) {
2382-
#define RUN_TESTS(dtype, layout, prepack) \
2383-
test_mm( \
2384-
/*B = */ 1, \
2385-
/*M = */ 31, \
2386-
/*K = */ 127, \
2387-
/*N = */ 23, \
2388-
dtype, \
2389-
layout, \
2390-
prepack); \
2391-
test_mm( \
2392-
/*B = */ 5, \
2393-
/*M = */ 31, \
2394-
/*K = */ 127, \
2395-
/*N = */ 23, \
2396-
dtype, \
2397-
layout, \
2398-
prepack); \
2399-
test_mm( \
2400-
/*B = */ 7, \
2401-
/*M = */ 13, \
2402-
/*K = */ 89, \
2403-
/*N = */ 17, \
2404-
dtype, \
2405-
layout, \
2406-
prepack); \
2407-
test_mm( \
2408-
/*B = */ 1, \
2409-
/*M = */ 13, \
2410-
/*K = */ 89, \
2411-
/*N = */ 17, \
2412-
dtype, \
2413-
layout, \
2388+
#define RUN_TESTS(dtype, storage_type, layout, prepack) \
2389+
test_mm( \
2390+
/*B = */ 1, \
2391+
/*M = */ 31, \
2392+
/*K = */ 127, \
2393+
/*N = */ 23, \
2394+
dtype, \
2395+
storage_type, \
2396+
layout, \
2397+
prepack); \
2398+
test_mm( \
2399+
/*B = */ 5, \
2400+
/*M = */ 31, \
2401+
/*K = */ 127, \
2402+
/*N = */ 23, \
2403+
dtype, \
2404+
storage_type, \
2405+
layout, \
2406+
prepack); \
2407+
test_mm( \
2408+
/*B = */ 7, \
2409+
/*M = */ 13, \
2410+
/*K = */ 89, \
2411+
/*N = */ 17, \
2412+
dtype, \
2413+
storage_type, \
2414+
layout, \
2415+
prepack); \
2416+
test_mm( \
2417+
/*B = */ 1, \
2418+
/*M = */ 13, \
2419+
/*K = */ 89, \
2420+
/*N = */ 17, \
2421+
dtype, \
2422+
storage_type, \
2423+
layout, \
24142424
prepack);
24152425

24162426
CALL_TEST_FN_FOR_W_PACKED(RUN_TESTS);

0 commit comments

Comments
 (0)