
Commit 4de53bd

copyrightly authored and facebook-github-bot committed
add abs, sigmoid, tanh to ET-VK (#2605)
Summary:
Pull Request resolved: #2605

tsia

bypass-github-pytorch-ci-checks

Reviewed By: jorgep31415

Differential Revision: D55169458

fbshipit-source-id: fa980d73c9544f5be7e44141d99456008f343b73
1 parent f7fbc7a commit 4de53bd

4 files changed: +73 additions, −12 deletions

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 1 deletion
@@ -35,10 +35,13 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten.div.Tensor_mode,
             exir_ops.edge.aten.pow.Tensor_Tensor,
-            # Activation operators
+            # Unary operators
+            exir_ops.edge.aten.abs.default,
             exir_ops.edge.aten.clamp.default,
             exir_ops.edge.aten.hardtanh.default,
             exir_ops.edge.aten.relu.default,
+            exir_ops.edge.aten.sigmoid.default,
+            exir_ops.edge.aten.tanh.default,
             # Matrix multiplication operators
             exir_ops.edge.aten.mm.default,
             # Pooling operators
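For context, a minimal self-contained sketch of the allow-list check that is_node_supported performs against the table above. The string keys and helper name here are purely illustrative; the real code compares node.target against exir_ops.edge.aten.* targets, not strings.

    # Illustrative only: real targets are exir_ops objects, not strings.
    SUPPORTED_UNARY_OPS = {
        "aten.abs.default",
        "aten.clamp.default",
        "aten.hardtanh.default",
        "aten.relu.default",
        "aten.sigmoid.default",
        "aten.tanh.default",
    }

    def is_unary_op_supported(target_name: str) -> bool:
        # Mirrors the membership test inside is_node_supported.
        return target_name in SUPPORTED_UNARY_OPS

    assert is_unary_op_supported("aten.sigmoid.default")    # added in this commit
    assert not is_unary_op_supported("aten.softmax.default")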

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 6 additions & 0 deletions
@@ -10,5 +10,11 @@ unary_op:
       - VALUE: float
         SUFFIX: float
   shader_variants:
+    - NAME: abs
+      OPERATOR: abs(X)
     - NAME: clamp
       OPERATOR: clamp(X, A, B)
+    - NAME: sigmoid
+      OPERATOR: 1 / (1 + exp(-1 * X))
+    - NAME: tanh
+      OPERATOR: tanh(clamp(X, -15.0, 15.0))
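Note the tanh variant clamps its input to [-15.0, 15.0] before applying tanh: tanh saturates to ±1 well inside that range, so the clamp likely guards the shader's intermediate values against overflow without changing any observable result. A quick check in plain Python (not part of the commit) illustrates this, along with the sigmoid formula:

    import math

    # 1 - tanh(15) ~= 2*exp(-30) ~= 1.9e-13, far below float32 epsilon (~1.2e-7),
    # so clamp(X, -15.0, 15.0) is invisible in the final output.
    print(1.0 - math.tanh(15.0))  # ~1.9e-13

    # The sigmoid variant's OPERATOR is the standard logistic function:
    def sigmoid(x: float) -> float:
        return 1.0 / (1.0 + math.exp(-1.0 * x))

    print(sigmoid(0.0))  # 0.5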

backends/vulkan/runtime/graph/ops/impl/Clamp.cpp renamed to backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 33 additions & 11 deletions
@@ -19,7 +19,10 @@ namespace at {
 namespace native {
 namespace vulkan {
 
-void resize_clamp_node(
+constexpr float kDummyFloat = -1.0f;
+const std::string kClampShaderName = "clamp";
+
+void resize_unary_op_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& extra_args) {
@@ -30,20 +33,21 @@ void resize_clamp_node(
   out.virtual_resize(self.sizes());
 }
 
-void add_clamp_node(
+void add_unary_op_node(
     ComputeGraph& graph,
     const ValueRef in,
     const float min,
     const float max,
-    const ValueRef out) {
+    const ValueRef out,
+    const std::string& op_name) {
   ValueRef arg = prepack_if_tensor_ref(graph, in);
 
   vTensor& t_out = graph.get_val(out).toTensor();
   api::utils::uvec3 global_size = t_out.virtual_extents();
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
 
   std::stringstream kernel_name;
-  kernel_name << "clamp";
+  kernel_name << op_name;
   apply_dtype_suffix(kernel_name, t_out);
@@ -58,7 +62,7 @@ void add_clamp_node(
       graph.create_params_buffer(min),
       graph.create_params_buffer(max)},
       // Resizing
-      resize_clamp_node));
+      resize_unary_op_node));
 }
 
 float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
@@ -69,30 +73,48 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
       : -std::numeric_limits<float>::infinity();
 }
 
+#define DEFINE_ACTIVATION_FN(op_name)                                     \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {  \
+    return add_unary_op_node(                                             \
+        graph, args[0], kDummyFloat, kDummyFloat, args[1], #op_name);     \
+  }
+
 #define DEFINE_CLAMP_FN(op_name)                                          \
   void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {  \
-    return add_clamp_node(                                                 \
+    return add_unary_op_node(                                              \
         graph,                                                             \
         args[0],                                                           \
         get_val_or_inf(graph, args[1], /*max =*/false),                    \
         get_val_or_inf(graph, args[2], /*max =*/true),                     \
-        args[3]);                                                          \
+        args[3],                                                           \
+        kClampShaderName);                                                 \
   }
 
-#define DEFINE_RELU_FN(op_name)                                           \
-  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {  \
-    return add_clamp_node(                                                 \
-        graph, args[0], 0, std::numeric_limits<float>::infinity(), args[1]); \
+#define DEFINE_RELU_FN(op_name)                                           \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {  \
+    return add_unary_op_node(                                             \
+        graph,                                                            \
+        args[0],                                                          \
+        0,                                                                \
+        std::numeric_limits<float>::infinity(),                           \
+        args[1],                                                          \
+        kClampShaderName);                                                \
   }
 
+DEFINE_ACTIVATION_FN(abs);
+DEFINE_ACTIVATION_FN(sigmoid);
+DEFINE_ACTIVATION_FN(tanh);
 DEFINE_CLAMP_FN(clamp);
 DEFINE_CLAMP_FN(hardtanh);
 DEFINE_RELU_FN(relu);
 
 REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.abs.default, abs);
   VK_REGISTER_OP(aten.clamp.default, clamp);
   VK_REGISTER_OP(aten.hardtanh.default, hardtanh);
   VK_REGISTER_OP(aten.relu.default, relu);
+  VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
+  VK_REGISTER_OP(aten.tanh.default, tanh);
 }
 
 } // namespace vulkan
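The pattern here: DEFINE_ACTIVATION_FN stringizes its argument (#op_name) so the op name doubles as the shader-variant name from unary_op.yaml, with kDummyFloat filling the unused min/max slots, while clamp-style ops keep real bounds and share the "clamp" shader via kClampShaderName. A rough Python analogue (purely illustrative, not from the codebase) of this macro-plus-registry pattern:

    # Illustrative analogue of DEFINE_ACTIVATION_FN / DEFINE_CLAMP_FN + VK_REGISTER_OP.
    K_DUMMY = -1.0  # mirrors kDummyFloat: ignored by shaders without bounds

    REGISTRY = {}

    def define_activation_fn(op_name):
        # Like #op_name in the macro, the op name itself selects the shader variant.
        REGISTRY[f"aten.{op_name}.default"] = lambda x: (op_name, x, K_DUMMY, K_DUMMY)

    def define_clamp_fn(op_name, lo, hi):
        # Clamp-style ops all dispatch to the shared "clamp" shader.
        REGISTRY[f"aten.{op_name}.default"] = lambda x: ("clamp", x, lo, hi)

    for op in ("abs", "sigmoid", "tanh"):
        define_activation_fn(op)
    define_clamp_fn("relu", 0.0, float("inf"))

    print(REGISTRY["aten.tanh.default"]("x"))  # ('tanh', 'x', -1.0, -1.0)
    print(REGISTRY["aten.relu.default"]("x"))  # ('clamp', 'x', 0.0, inf)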

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 30 additions & 0 deletions
@@ -383,6 +383,36 @@ def forward(self, x):
             first_output_only=True,
         )
 
+    def test_vulkan_backend_abs(self):
+        class AbsModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.abs(x)
+
+        self.lower_clamp_module_and_test_output(AbsModule())
+
+    def test_vulkan_backend_sigmoid(self):
+        class SigmoidModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sigmoid(x)
+
+        self.lower_clamp_module_and_test_output(SigmoidModule())
+
+    def test_vulkan_backend_tanh(self):
+        class TanhModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.tanh(x)
+
+        self.lower_clamp_module_and_test_output(TanhModule())
+
     def test_vulkan_backend_partial(self):
         class SimpleModel(torch.nn.Module):
             def __init__(self):
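Each test follows the same parity pattern: lower a one-op module through the Vulkan delegate and compare its output against eager PyTorch. The same check can be expressed directly for the clamped-tanh shader formula; the snippet below (illustrative, not part of the test suite) confirms it matches torch.tanh to float32 precision:

    import torch

    x = torch.linspace(-50.0, 50.0, steps=1001)
    shader_like = torch.tanh(torch.clamp(x, -15.0, 15.0))  # what the GLSL variant computes
    assert torch.allclose(torch.tanh(x), shader_like, atol=1e-7)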
