Skip to content

Commit c3f79e8

Browse files
copyrightlyfacebook-github-bot
authored andcommitted
add abs, sigmoid, tanh to ET-VK (#2605)
Summary: Pull Request resolved: #2605 tsia Differential Revision: D55169458
1 parent 8532e79 commit c3f79e8

File tree

7 files changed

+137
-1
lines changed

7 files changed

+137
-1
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
3636
exir_ops.edge.aten.div.Tensor_mode,
3737
exir_ops.edge.aten.pow.Tensor_Tensor,
3838
# Activation operators
39+
exir_ops.edge.aten.abs.default,
3940
exir_ops.edge.aten.clamp.default,
4041
exir_ops.edge.aten.hardtanh.default,
4142
exir_ops.edge.aten.relu.default,
43+
exir_ops.edge.aten.sigmoid.default,
44+
exir_ops.edge.aten.tanh.default,
4245
# Matrix multiplication operators
4346
exir_ops.edge.aten.mm.default,
4447
# Pooling operators
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define OP(X) ${OPERATOR}
14+
15+
layout(std430) buffer;
16+
17+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
18+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
19+
20+
layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents {
21+
uvec4 data;
22+
}
23+
out_extents;
24+
25+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26+
27+
void main() {
28+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
29+
30+
if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
31+
return;
32+
}
33+
34+
vec4 in_texel = texelFetch(image_in, pos, 0);
35+
imageStore(image_out, pos, OP(in_texel));
36+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
activation:
2+
parameter_names_with_default_values:
3+
OPERATOR: op(X)
4+
NDIM: 3
5+
DTYPE: float
6+
generate_variant_forall:
7+
DTYPE:
8+
- VALUE: half
9+
SUFFIX: half
10+
- VALUE: float
11+
SUFFIX: float
12+
shader_variants:
13+
- NAME: abs
14+
OPERATOR: abs(X)
15+
- NAME: sigmoid
16+
OPERATOR: 1 / (1 + exp(-1 * X))
17+
- NAME: tanh
18+
OPERATOR: tanh(clamp(X, -15.0, 15.0))

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/clamp.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
unary_op:
1+
clamp:
22
parameter_names_with_default_values:
33
OPERATOR: clamp(X, A, B)
44
NDIM: 3

backends/vulkan/runtime/graph/ops/impl/Clamp.cpp renamed to backends/vulkan/runtime/graph/ops/impl/Activation.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,34 @@ void add_clamp_node(
6161
resize_clamp_node));
6262
}
6363

64+
void add_activation_node(
65+
ComputeGraph& graph,
66+
const ValueRef in,
67+
const ValueRef out,
68+
const std::string& op_name_string) {
69+
ValueRef arg = prepack_if_tensor_ref(graph, in);
70+
71+
vTensor& t_out = graph.get_val(out).toTensor();
72+
api::utils::uvec3 global_size = t_out.virtual_extents();
73+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
74+
75+
std::stringstream kernel_name;
76+
kernel_name << op_name_string;
77+
apply_dtype_suffix(kernel_name, t_out);
78+
79+
graph.execute_nodes().emplace_back(new ExecuteNode(
80+
graph,
81+
VK_KERNEL_FROM_STR(kernel_name.str()),
82+
global_size,
83+
local_size,
84+
// Inputs and Outputs
85+
{{out, api::MemoryAccessType::WRITE}, {arg, api::MemoryAccessType::READ}},
86+
// Shader params buffers
87+
{t_out.extents_ubo()},
88+
// Resizing
89+
resize_clamp_node));
90+
}
91+
6492
float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
6593
if (!graph.get_val(val).isNone()) {
6694
return extract_scalar<float>(graph.get_val(val));
@@ -69,6 +97,11 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
6997
: -std::numeric_limits<float>::infinity();
7098
}
7199

100+
#define DEFINE_ABS_FN(op_name) \
101+
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
102+
return add_activation_node(graph, args[0], args[1], "abs"); \
103+
}
104+
72105
#define DEFINE_CLAMP_FN(op_name) \
73106
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
74107
return add_clamp_node( \
@@ -85,14 +118,30 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
85118
graph, args[0], 0, std::numeric_limits<float>::infinity(), args[1]); \
86119
}
87120

121+
#define DEFINE_SIGMOID_FN(op_name) \
122+
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
123+
return add_activation_node(graph, args[0], args[1], "sigmoid"); \
124+
}
125+
126+
#define DEFINE_TANH_FN(op_name) \
127+
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
128+
return add_activation_node(graph, args[0], args[1], "tanh"); \
129+
}
130+
131+
DEFINE_ABS_FN(abs);
88132
DEFINE_CLAMP_FN(clamp);
89133
DEFINE_CLAMP_FN(hardtanh);
90134
DEFINE_RELU_FN(relu);
135+
DEFINE_SIGMOID_FN(sigmoid);
136+
DEFINE_TANH_FN(tanh);
91137

92138
REGISTER_OPERATORS {
139+
VK_REGISTER_OP(aten.abs.default, abs);
93140
VK_REGISTER_OP(aten.clamp.default, clamp);
94141
VK_REGISTER_OP(aten.hardtanh.default, hardtanh);
95142
VK_REGISTER_OP(aten.relu.default, relu);
143+
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
144+
VK_REGISTER_OP(aten.tanh.default, tanh);
96145
}
97146

98147
} // namespace vulkan

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,36 @@ def forward(self, x):
379379
first_output_only=True,
380380
)
381381

382+
def test_vulkan_backend_abs(self):
383+
class AbsModule(torch.nn.Module):
384+
def __init__(self):
385+
super().__init__()
386+
387+
def forward(self, x):
388+
return torch.abs(x)
389+
390+
self.lower_clamp_module_and_test_output(AbsModule())
391+
392+
def test_vulkan_backend_sigmoid(self):
393+
class SigmoidModule(torch.nn.Module):
394+
def __init__(self):
395+
super().__init__()
396+
397+
def forward(self, x):
398+
return torch.sigmoid(x)
399+
400+
self.lower_clamp_module_and_test_output(SigmoidModule())
401+
402+
def test_vulkan_backend_tanh(self):
403+
class TanhModule(torch.nn.Module):
404+
def __init__(self):
405+
super().__init__()
406+
407+
def forward(self, x):
408+
return torch.tanh(x)
409+
410+
self.lower_clamp_module_and_test_output(TanhModule())
411+
382412
def test_vulkan_backend_partial(self):
383413
class SimpleModel(torch.nn.Module):
384414
def __init__(self):

0 commit comments

Comments
 (0)