Skip to content

Commit 5692203

Browse files
pytorchbotSS-JIA
andauthored
[ET-VK][ez] Implement rsqrt (#6472)
Pull Request resolved: #6456 TSIA. This op is used in Llama model architecture. ghstack-source-id: 249709740 @exported-using-ghexport Differential Revision: [D64840505](https://our.internmc.facebook.com/intern/diff/D64840505/) Co-authored-by: Stephen Jia <[email protected]>
1 parent cb25809 commit 5692203

File tree

4 files changed

+6
-0
lines changed

4 files changed

+6
-0
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __contains__(self, op):
7676
exir_ops.edge.aten.sigmoid.default,
7777
exir_ops.edge.aten.sin.default,
7878
exir_ops.edge.aten.sqrt.default,
79+
exir_ops.edge.aten.rsqrt.default,
7980
exir_ops.edge.aten.tanh.default,
8081
exir_ops.edge.aten._to_copy.default,
8182
# Matrix Multiplication

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ unary_op:
3232
OPERATOR: sin(X)
3333
- NAME: sqrt
3434
OPERATOR: sqrt(X)
35+
- NAME: rsqrt
36+
OPERATOR: (1 / sqrt(X))
3537
- NAME: tanh
3638
OPERATOR: tanh(clamp(X, -15.0, 15.0))
3739
- NAME: hardshrink

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ DEFINE_ACTIVATION_FN(neg);
129129
DEFINE_ACTIVATION_FN(sigmoid);
130130
DEFINE_ACTIVATION_FN(sin);
131131
DEFINE_ACTIVATION_FN(sqrt);
132+
DEFINE_ACTIVATION_FN(rsqrt);
132133
DEFINE_ACTIVATION_FN(tanh);
133134
DEFINE_CLAMP_FN(clamp);
134135
DEFINE_CLAMP_FN(hardtanh);
@@ -149,6 +150,7 @@ REGISTER_OPERATORS {
149150
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
150151
VK_REGISTER_OP(aten.sin.default, sin);
151152
VK_REGISTER_OP(aten.sqrt.default, sqrt);
153+
VK_REGISTER_OP(aten.rsqrt.default, rsqrt);
152154
VK_REGISTER_OP(aten.tanh.default, tanh);
153155
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
154156
VK_REGISTER_OP(aten.hardswish.default, hardswish);

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,7 @@ def get_softmax_inputs():
988988
@register_test_suite(
989989
[
990990
"aten.sqrt.default",
991+
"aten.rsqrt.default",
991992
"aten.exp.default",
992993
"aten.hardshrink.default",
993994
"aten.sin.default",

0 commit comments

Comments
 (0)