Skip to content

Commit a1550e3

Browse files
hossein1387YIWENX14
authored andcommitted
aten.leakyrelu.default in unary_ops
Differential Revision: D68688186 Pull Request resolved: #7975
1 parent 023e732 commit a1550e3

File tree

4 files changed

+28
-0
lines changed

4 files changed

+28
-0
lines changed

backends/vulkan/runtime/graph/ops/glsl/activations.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,15 @@ vec4 hardsigmoid(vec4 tex) {
4242
hardsigmoid(tex.z),
4343
hardsigmoid(tex.w));
4444
}
45+
46+
float leaky_relu(float x, float negative_slope) {
47+
return x * (float(x > 0.0) + negative_slope * float(x <= 0.0));
48+
}
49+
50+
vec4 leaky_relu(vec4 tex, float negative_slope) {
51+
return vec4(
52+
leaky_relu(tex.x, negative_slope),
53+
leaky_relu(tex.y, negative_slope),
54+
leaky_relu(tex.z, negative_slope),
55+
leaky_relu(tex.w, negative_slope));
56+
}

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,5 @@ unary_op:
4242
OPERATOR: hardswish(X)
4343
- NAME: hardsigmoid
4444
OPERATOR: hardsigmoid(X)
45+
- NAME: leaky_relu
46+
OPERATOR: leaky_relu(X, A)

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
114114
"hardshrink"); \
115115
}
116116

117+
#define DEFINE_LEAKY_RELU_FN(op_name) \
118+
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
119+
return add_unary_op_node( \
120+
graph, \
121+
args[0], \
122+
get_val_or_inf(graph, args[1], /*neg slope*/ false), \
123+
kDummyFloat, \
124+
args[2], \
125+
"leaky_relu"); \
126+
}
127+
117128
void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
118129
// args[1] is the `approximate` string
119130
// https://fburl.com/code/9omngmyo
@@ -137,6 +148,7 @@ DEFINE_RELU_FN(relu);
137148
DEFINE_HARDSHRINK_FN(hardshrink);
138149
DEFINE_ACTIVATION_FN(hardswish);
139150
DEFINE_ACTIVATION_FN(hardsigmoid);
151+
DEFINE_LEAKY_RELU_FN(leaky_relu);
140152

141153
REGISTER_OPERATORS {
142154
VK_REGISTER_OP(aten.abs.default, abs);
@@ -155,6 +167,7 @@ REGISTER_OPERATORS {
155167
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
156168
VK_REGISTER_OP(aten.hardswish.default, hardswish);
157169
VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
170+
VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu);
158171
}
159172

160173
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,7 @@ def get_reduce_op_inputs():
10721072
"aten.cos.default",
10731073
"aten.hardswish.default",
10741074
"aten.hardsigmoid.default",
1075+
"aten.leaky_relu.default",
10751076
]
10761077
)
10771078
def get_unary_ops_inputs():

0 commit comments

Comments
 (0)