
Commit 5f133b3

SS-JIA authored and facebook-github-bot committed
Add matmul operator (#2517)
Summary:
Pull Request resolved: #2517

## Context

Add matrix multiplication operator support.

ghstack-source-id: 219281501
Reviewed By: jorgep31415
Differential Revision: D55031043
fbshipit-source-id: d7f5a3ff1e421602e75ec1904043ca07681a3b35
1 parent f7300b2 commit 5f133b3


9 files changed: +371, -3 lines changed


backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 2 deletions
@@ -26,17 +26,19 @@
 class VulkanSupportedOperators(OperatorSupportBase):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         supported = node.op == "call_function" and node.target in [
-            # BinaryOp
+            # Binary arithmetic operators
             exir_ops.edge.aten.add.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten.div.Tensor_mode,
             exir_ops.edge.aten.pow.Tensor_Tensor,
-            # Clamp
+            # Activation operators
             exir_ops.edge.aten.clamp.default,
             exir_ops.edge.aten.hardtanh.default,
             exir_ops.edge.aten.relu.default,
+            # Matrix multiplication operators
+            exir_ops.edge.aten.mm.default,
             # Other
             operator.getitem,
         ]
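Note (not part of the diff): with aten.mm.default in the supported list, a plain 2-D matrix multiply, which for 2-D inputs typically lowers to aten.mm.default in the Edge dialect, can now be claimed by the Vulkan partitioner instead of falling back to CPU. A minimal sketch of the kind of module this enables, assuming the usual export flow:

import torch

x = torch.randn(31, 63)
w = torch.randn(63, 22)
y = torch.matmul(x, w)  # 2-D inputs -> aten.mm.default after export to Edge
assert y.shape == (31, 22)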

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@
  * LICENSE file in the root directory of this source tree.
  */

+#define DIVUP4(x) ((x + 3) / 4)
+
 #define PACKED_DIM_CHANNELS_PACKED(vec) vec.z

 #define PACKED_DIM_WIDTH_PACKED(vec) vec.x
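DIVUP4 is a ceiling division by 4: it converts an element count along the packed dimension into the number of 4-element texels needed to cover it. A minimal Python sketch of the same arithmetic (illustration only, not part of the diff):

def divup4(x: int) -> int:
    # equivalent to the DIVUP4 macro: ceil(x / 4)
    return (x + 3) // 4

assert divup4(63) == 16  # a 63-element row occupies 16 four-element texels
assert divup4(64) == 16
assert divup4(65) == 17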
backends/vulkan/runtime/graph/ops/glsl/matmul.glsl

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#include "indexing_utils.h"

#define PRECISION ${PRECISION}

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out;
layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1;
layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2;

layout(set = 0, binding = 3) uniform PRECISION restrict OutExtents {
  uvec4 data;
}
out_extents;

layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
  ivec4 data;
}
in_sizes;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
    return;
  }

  vec4 texel = vec4(0);

  ivec3 mat1_pos = ivec3(0, pos.y, pos.z);

  $if MAT2_PACKING == "HEIGHT_PACKED":
    ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z);
  $else:
    ivec3 mat2_pos = ivec3(pos.x, 0, pos.z);

  $if MAT1_PACKING == "WIDTH_PACKED":
    int K = DIVUP4(in_sizes.data[0]);
    for (int i = 0; i < K; ++i) {
      $if MAT2_PACKING == "HEIGHT_PACKED":
        vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0);
        vec4 sums = vec4(
            dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)),
            dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(1, 0, 0), 0)),
            dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(2, 0, 0), 0)),
            dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(3, 0, 0), 0)));

        texel += sums;

        mat1_pos.x++;
        mat2_pos.y++;
      $elif MAT2_PACKING == "WIDTH_PACKED":
        vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0);
        texel = fma(mat1_tex.xxxx, texelFetch(im_mat2, mat2_pos, 0), texel);
        mat2_pos.y++;
        texel = fma(mat1_tex.yyyy, texelFetch(im_mat2, mat2_pos, 0), texel);
        mat2_pos.y++;
        texel = fma(mat1_tex.zzzz, texelFetch(im_mat2, mat2_pos, 0), texel);
        mat2_pos.y++;
        texel = fma(mat1_tex.wwww, texelFetch(im_mat2, mat2_pos, 0), texel);
        mat2_pos.y++;

        mat1_pos.x++;
      $else:
        $raise Exception("Unsupported value for MAT2_PACKING")
    }
  $elif MAT1_PACKING == "CHANNELS_PACKED" and MAT2_PACKING == "CHANNELS_PACKED":
    int K = in_sizes.data[0];
    for (int i = 0; i < K; ++i) {
      texel = fma(
          texelFetch(im_mat1, mat1_pos, 0),
          texelFetch(im_mat2, mat2_pos, 0),
          texel);

      mat1_pos.x++;
      mat2_pos.y++;
    }
  $else:
    $raise Exception("Unsupported value combo for MAT1_PACKING and MAT2_PACKING")

  imageStore(im_out, pos, texel);
}
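To make the W_packed x H_packed variant easier to follow, here is a small CPU sketch (assumed, illustration only) of the same accumulation: a width-packed mat1 texel holds 4 consecutive K-elements of a row, a height-packed mat2 texel holds 4 consecutive K-elements of a column, and each output texel accumulates 4 dot products, one per output column.

import numpy as np

def matmul_w_packed_h_packed(mat1: np.ndarray, mat2: np.ndarray) -> np.ndarray:
    # mat1: (M, K), "width packed"  -> a texel holds 4 consecutive K elements of a row
    # mat2: (K, N), "height packed" -> a texel holds 4 consecutive K elements of a column
    M, K = mat1.shape
    K2, N = mat2.shape
    assert K == K2 and K % 4 == 0 and N % 4 == 0  # padding is ignored in this sketch
    out = np.zeros((M, N), dtype=mat1.dtype)
    for y in range(M):               # pos.y: output row
        for x in range(N // 4):      # pos.x: output texel, covering 4 output columns
            texel = np.zeros(4, dtype=mat1.dtype)
            for i in range(K // 4):  # the K loop from the shader, one texel at a time
                mat1_tex = mat1[y, 4 * i : 4 * i + 4]
                for j in range(4):   # the four dot() calls that build `sums`
                    col = 4 * x + j
                    texel[j] += mat1_tex @ mat2[4 * i : 4 * i + 4, col]
            out[y, 4 * x : 4 * x + 4] = texel
    return out

# Sanity check against a reference matmul (illustration only).
a = np.random.rand(8, 12).astype(np.float32)
b = np.random.rand(12, 16).astype(np.float32)
assert np.allclose(matmul_w_packed_h_packed(a, b), a @ b, atol=1e-5)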
backends/vulkan/runtime/graph/ops/glsl/matmul.yaml

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

matmul:
  parameter_names_with_default_values:
    DTYPE: float
    NDIM: 3
    MAT1_PACKING: WIDTH_PACKED
    MAT2_PACKING: HEIGHT_PACKED
  generate_variant_forall:
    DTYPE:
      - VALUE: float
        SUFFIX: float
      - VALUE: half
        SUFFIX: half
  shader_variants:
    - NAME: matmul_W_packed_H_packed
    - NAME: matmul_W_packed_W_packed
      MAT2_PACKING: WIDTH_PACKED
    - NAME: matmul_C_packed_C_packed
      MAT1_PACKING: CHANNELS_PACKED
      MAT2_PACKING: CHANNELS_PACKED
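For reference, generate_variant_forall is assumed here to expand every shader_variants entry once per DTYPE, appending the SUFFIX to the variant name. A rough sketch of the resulting shader names:

variants = [
    "matmul_W_packed_H_packed",
    "matmul_W_packed_W_packed",
    "matmul_C_packed_C_packed",
]
suffixes = ["float", "half"]
generated = [f"{name}_{suffix}" for name in variants for suffix in suffixes]
# e.g. ["matmul_W_packed_H_packed_float", "matmul_W_packed_H_packed_half", ...]

These are the same names that the C++ op below rebuilds at runtime via apply_memory_layout_suffix and apply_dtype_suffix before calling VK_KERNEL_FROM_STR.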
backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace at {
namespace native {
namespace vulkan {

void check_matmul_args(
    const vTensor& mat1,
    const vTensor& mat2,
    const vTensor& out) {
  VK_CHECK_COND(check_ndim_is(mat1, 2) || check_ndim_is(mat1, 3));
  VK_CHECK_COND(check_same_ndim(mat1, mat2));

  VK_CHECK_COND(
      check_memory_layout_is(
          mat1, api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED) ||
      check_memory_layout_is(mat1, api::GPUMemoryLayout::TENSOR_WIDTH_PACKED));
  VK_CHECK_COND(check_same_memory_layout(mat1, out));

  VK_CHECK_COND(check_same_sizes_at(mat1, -1, mat2, -2));
}

void resize_matmul_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;
  vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
  vTensor& mat1 = graph->get_val(args[1].refs[0]).toTensor();
  vTensor& mat2 = graph->get_val(args[1].refs[1]).toTensor();

  std::vector<int64_t> new_out_sizes(3);
  if (mat1.sizes().size() == 2) {
    new_out_sizes.resize(2);
    new_out_sizes.at(0) = mat1.sizes().at(0);
    new_out_sizes.at(1) = mat2.sizes().at(1);
  } else {
    new_out_sizes.at(0) = mat1.sizes().at(0);
    new_out_sizes.at(1) = mat1.sizes().at(1);
    new_out_sizes.at(2) = mat2.sizes().at(2);
  }

  out.virtual_resize(new_out_sizes);
}

void add_matmul_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2,
    const ValueRef out) {
  ValueRef arg1 = prepack_if_tensor_ref(
      graph, mat1, api::GPUMemoryLayout::TENSOR_WIDTH_PACKED);

  api::GPUMemoryLayout mat2_layout = graph.memory_layout_of(arg1) ==
          api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED
      ? api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED
      : api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED;

  ValueRef arg2 = prepack_if_tensor_ref(graph, mat2, mat2_layout);

  vTensor& t_mat1 = graph.get_val(arg1).toTensor();
  vTensor& t_mat2 = graph.get_val(arg2).toTensor();
  vTensor& t_out = graph.get_val(out).toTensor();

  check_matmul_args(t_mat1, t_mat2, t_out);

  api::utils::uvec3 global_size = t_out.virtual_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::stringstream kernel_name;
  kernel_name << "matmul";
  apply_memory_layout_suffix(kernel_name, t_mat1);
  apply_memory_layout_suffix(kernel_name, t_mat2);
  apply_dtype_suffix(kernel_name, t_out);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name.str()),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, api::MemoryAccessType::WRITE},
       {{arg1, arg2}, api::MemoryAccessType::READ}},
      // Shader params buffers
      {t_out.extents_ubo(), t_mat1.cpu_sizes_ubo()},
      // Resizing
      resize_matmul_node));
}

void matmul(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_matmul_node(graph, args[0], args[1], args[2]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.mm.default, matmul);
}

} // namespace vulkan
} // namespace native
} // namespace at
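The resize callback above encodes the usual matmul shape rule. A small sketch (assumed, mirroring resize_matmul_node) of the output sizes it produces for 2-D and batched 3-D inputs:

def matmul_out_sizes(mat1_sizes, mat2_sizes):
    # 2-D: (M, K) @ (K, N) -> (M, N)
    if len(mat1_sizes) == 2:
        return [mat1_sizes[0], mat2_sizes[1]]
    # 3-D (batched): (B, M, K) @ (B, K, N) -> (B, M, N)
    return [mat1_sizes[0], mat1_sizes[1], mat2_sizes[2]]

assert matmul_out_sizes([31, 63], [63, 22]) == [31, 22]
assert matmul_out_sizes([2, 31, 63], [2, 63, 22]) == [2, 31, 22]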

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp

Lines changed: 22 additions & 1 deletion
@@ -37,8 +37,29 @@ std::vector<int64_t> calculate_broadcasted_output_size(
 // Tensor property checking functions
 //

+bool check_ndim_is(const vTensor& t, size_t ndim) {
+  return t.sizes().size() == ndim;
+}
+
+bool check_same_sizes_at(
+    const vTensor& t1,
+    const int64_t d1,
+    const vTensor& t2,
+    const int64_t d2) {
+  return api::utils::val_at(d1, t1.sizes()) ==
+      api::utils::val_at(d2, t2.sizes());
+}
+
+bool check_memory_layout_is(const vTensor& t, api::GPUMemoryLayout layout) {
+  return t.gpu_memory_layout() == layout;
+}
+
+bool check_same_ndim(const vTensor& t1, const vTensor& t2) {
+  return t1.sizes().size() == t2.sizes().size();
+}
+
 bool check_same_memory_layout(const vTensor& t1, const vTensor& t2) {
-  return (t1.gpu_memory_layout() == t2.gpu_memory_layout());
+  return t1.gpu_memory_layout() == t2.gpu_memory_layout();
 }

 bool check_same_memory_layout(
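check_matmul_args calls check_same_sizes_at(mat1, -1, mat2, -2); the negative dims index from the end (Python-style), so this compares mat1's last dimension against mat2's second-to-last dimension, i.e. the shared inner dimension K. A tiny illustration of that check, assuming val_at's negative-index behavior:

def same_sizes_at(sizes1, d1, sizes2, d2):
    # negative dims index from the end, mirroring the assumed val_at semantics
    return sizes1[d1] == sizes2[d2]

assert same_sizes_at([31, 63], -1, [63, 22], -2)       # K matches
assert not same_sizes_at([31, 63], -1, [64, 22], -2)   # K mismatch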

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h

Lines changed: 12 additions & 0 deletions
@@ -28,6 +28,18 @@ std::vector<int64_t> calculate_broadcasted_output_size(
 // Tensor property checking functions
 //

+bool check_ndim_is(const vTensor& t, size_t ndim);
+
+bool check_same_ndim(const vTensor& t1, const vTensor& t2);
+
+bool check_same_sizes_at(
+    const vTensor& t1,
+    int64_t d1,
+    const vTensor& t2,
+    int64_t d2);
+
+bool check_memory_layout_is(const vTensor& t, api::GPUMemoryLayout layout);
+
 bool check_same_memory_layout(const vTensor& t1, const vTensor& t2);

 bool check_same_memory_layout(

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 14 additions & 0 deletions
@@ -351,3 +351,17 @@ def forward(self, x1, x2):
         self.lower_module_and_test_output(
             model, sample_inputs, dynamic_shapes=dynamic_shapes, test_inputs=test_inputs
         )
+
+    def test_vulkan_backend_matmul(self):
+        class MatMulModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = torch.ones(size=(63, 22), dtype=torch.float32)
+
+            def forward(self, x):
+                return torch.matmul(x, self.weight)
+
+        module = MatMulModule()
+        sample_inputs = (torch.ones(size=(31, 63), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(module, sample_inputs)
