@@ -19,7 +19,10 @@ namespace at {
19
19
namespace native {
20
20
namespace vulkan {
21
21
22
- void resize_clamp_node (
22
// Placeholder forwarded as both `min` and `max` by DEFINE_ACTIVATION_FN for
// ops that are registered without clamp bounds.
// NOTE(review): presumably the abs/sigmoid/tanh shaders ignore these values;
// confirm against the shader sources.
constexpr float kDummyFloat = -1.0f;

// Base kernel name handed to add_unary_op_node by DEFINE_CLAMP_FN and
// DEFINE_RELU_FN (clamp, hardtanh, relu). It must not contain surrounding
// whitespace: add_unary_op_node streams it into the kernel name verbatim
// before apply_dtype_suffix is called on that name.
const std::string kClampShaderName = "clamp";
24
+
25
+ void resize_unary_op_node (
23
26
ComputeGraph* graph,
24
27
const std::vector<ArgGroup>& args,
25
28
const std::vector<ValueRef>& extra_args) {
@@ -30,20 +33,21 @@ void resize_clamp_node(
30
33
out.virtual_resize (self.sizes ());
31
34
}
32
35
33
- void add_clamp_node (
36
+ void add_unary_op_node (
34
37
ComputeGraph& graph,
35
38
const ValueRef in,
36
39
const float min,
37
40
const float max,
38
- const ValueRef out) {
41
+ const ValueRef out,
42
+ const std::string& op_name) {
39
43
ValueRef arg = prepack_if_tensor_ref (graph, in);
40
44
41
45
vTensor& t_out = graph.get_val (out).toTensor ();
42
46
api::utils::uvec3 global_size = t_out.virtual_extents ();
43
47
api::utils::uvec3 local_size = adaptive_work_group_size (global_size);
44
48
45
49
std::stringstream kernel_name;
46
- kernel_name << " clamp " ;
50
+ kernel_name << op_name ;
47
51
apply_dtype_suffix (kernel_name, t_out);
48
52
49
53
graph.execute_nodes ().emplace_back (new ExecuteNode (
@@ -58,7 +62,7 @@ void add_clamp_node(
58
62
graph.create_params_buffer (min),
59
63
graph.create_params_buffer (max)},
60
64
// Resizing
61
- resize_clamp_node ));
65
+ resize_unary_op_node ));
62
66
}
63
67
64
68
float get_val_or_inf (ComputeGraph& graph, const ValueRef& val, bool max) {
@@ -69,30 +73,48 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
69
73
: -std::numeric_limits<float >::infinity ();
70
74
}
71
75
76
// Defines `op_name(graph, args)` for a unary op that has no clamp bounds.
// args[0] is forwarded as the input value and args[1] as the output value.
// The stringized macro argument is passed as the base shader name, and
// kDummyFloat is passed for both the min and max parameters of
// add_unary_op_node.
#define DEFINE_ACTIVATION_FN(op_name)                                     \
  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
    const ValueRef& input = args[0];                                      \
    const ValueRef& output = args[1];                                     \
    add_unary_op_node(                                                    \
        graph, input, kDummyFloat, kDummyFloat, output, #op_name);        \
  }
81
+
72
82
// Defines `op_name(graph, args)` for a clamp-style op with explicit bounds.
// Argument layout: args[0] input, args[1] lower bound, args[2] upper bound,
// args[3] output. Both bounds are resolved to floats via get_val_or_inf
// (max=false for the lower bound, max=true for the upper bound) and the node
// is built with the shared clamp shader name.
// NOTE(review): get_val_or_inf presumably substitutes -inf/+inf for an absent
// bound; confirm against its full definition.
#define DEFINE_CLAMP_FN(op_name)                                          \
  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
    const float lower = get_val_or_inf(graph, args[1], /*max=*/false);   \
    const float upper = get_val_or_inf(graph, args[2], /*max=*/true);    \
    add_unary_op_node(                                                    \
        graph, args[0], lower, upper, args[3], kClampShaderName);        \
  }
81
92
82
- #define DEFINE_RELU_FN (op_name ) \
83
- void op_name (ComputeGraph& graph, const std::vector<ValueRef>& args) { \
84
- return add_clamp_node ( \
85
- graph, args[0 ], 0 , std::numeric_limits<float >::infinity (), args[1 ]); \
93
// Defines `op_name(graph, args)` for a ReLU-style op, expressed as a clamp
// with a lower bound of 0 and an upper bound of +infinity.
// args[0] is the input value and args[1] the output value; the node is built
// with the shared clamp shader name.
#define DEFINE_RELU_FN(op_name)                                           \
  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
    const float lower = 0;                                                \
    const float upper = std::numeric_limits<float>::infinity();          \
    add_unary_op_node(                                                    \
        graph, args[0], lower, upper, args[1], kClampShaderName);        \
  }
87
103
104
+ DEFINE_ACTIVATION_FN (abs);
105
+ DEFINE_ACTIVATION_FN (sigmoid);
106
+ DEFINE_ACTIVATION_FN (tanh);
88
107
DEFINE_CLAMP_FN (clamp);
89
108
DEFINE_CLAMP_FN (hardtanh);
90
109
DEFINE_RELU_FN (relu);
91
110
92
111
// Operator registration: pairs each ATen operator name with the function of
// the same short name generated by the DEFINE_*_FN macros in this file.
// NOTE(review): the exact registration mechanism lives in the
// REGISTER_OPERATORS / VK_REGISTER_OP macros, which are defined elsewhere.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.abs.default, abs);
  VK_REGISTER_OP(aten.clamp.default, clamp);
  VK_REGISTER_OP(aten.hardtanh.default, hardtanh);
  VK_REGISTER_OP(aten.relu.default, relu);
  VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
  VK_REGISTER_OP(aten.tanh.default, tanh);
}
97
119
98
120
} // namespace vulkan
0 commit comments