Skip to content

Commit 0876706

Browse files
Yujie Huifacebook-github-bot
authored andcommitted
aten.sin, aten.cos, aten.neg
Summary: Implement 3 unary ops: ``` - func: sin(Tensor self) -> Tensor - func: cos(Tensor self) -> Tensor - func: neg(Tensor self) -> Tensor ``` Reviewed By: derekxu, jorgep31415 Differential Revision: D57646317
1 parent 16a892d commit 0876706

File tree

5 files changed

+48
-0
lines changed

5 files changed

+48
-0
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,15 @@ def __contains__(self, op):
5252
UNARY_OPS = [
5353
exir_ops.edge.aten.abs.default,
5454
exir_ops.edge.aten.clamp.default,
55+
exir_ops.edge.aten.cos.default,
5556
exir_ops.edge.aten.exp.default,
5657
exir_ops.edge.aten.gelu.default,
5758
exir_ops.edge.aten.hardshrink.default,
5859
exir_ops.edge.aten.hardtanh.default,
60+
exir_ops.edge.aten.neg.default,
5961
exir_ops.edge.aten.relu.default,
6062
exir_ops.edge.aten.sigmoid.default,
63+
exir_ops.edge.aten.sin.default,
6164
exir_ops.edge.aten.sqrt.default,
6265
exir_ops.edge.aten.tanh.default,
6366
]

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@ unary_op:
1212
OPERATOR: abs(X)
1313
- NAME: clamp
1414
OPERATOR: clamp(X, A, B)
15+
- NAME: cos
16+
OPERATOR: cos(X)
1517
- NAME: exp
1618
OPERATOR: exp(X)
1719
- NAME: gelu
1820
OPERATOR: 0.5 * X * (1 + tanh(sqrt(2 / 3.141593) * (X + 0.044715 * X * X * X)))
21+
- NAME: neg
22+
OPERATOR: -X
1923
- NAME: sigmoid
2024
OPERATOR: 1 / (1 + exp(-1 * X))
25+
- NAME: sin
26+
OPERATOR: sin(X)
2127
- NAME: sqrt
2228
OPERATOR: sqrt(X)
2329
- NAME: tanh

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,11 @@ void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
120120
}
121121

122122
DEFINE_ACTIVATION_FN(abs);
123+
DEFINE_ACTIVATION_FN(cos);
123124
DEFINE_ACTIVATION_FN(exp);
125+
DEFINE_ACTIVATION_FN(neg);
124126
DEFINE_ACTIVATION_FN(sigmoid);
127+
DEFINE_ACTIVATION_FN(sin);
125128
DEFINE_ACTIVATION_FN(sqrt);
126129
DEFINE_ACTIVATION_FN(tanh);
127130
DEFINE_CLAMP_FN(clamp);
@@ -132,11 +135,14 @@ DEFINE_HARDSHRINK_FN(hardshrink);
132135
REGISTER_OPERATORS {
133136
VK_REGISTER_OP(aten.abs.default, abs);
134137
VK_REGISTER_OP(aten.clamp.default, clamp);
138+
VK_REGISTER_OP(aten.cos.default, cos);
135139
VK_REGISTER_OP(aten.exp.default, exp);
136140
VK_REGISTER_OP(aten.gelu.default, gelu);
137141
VK_REGISTER_OP(aten.hardtanh.default, hardtanh);
142+
VK_REGISTER_OP(aten.neg.default, neg);
138143
VK_REGISTER_OP(aten.relu.default, relu);
139144
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
145+
VK_REGISTER_OP(aten.sin.default, sin);
140146
VK_REGISTER_OP(aten.sqrt.default, sqrt);
141147
VK_REGISTER_OP(aten.tanh.default, tanh);
142148
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);

backends/vulkan/test/op_tests/cases.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,4 +795,7 @@ def get_gelu_inputs():
795795
"aten._native_batch_norm_legit_no_training.default": get_native_batch_norm_inputs(),
796796
"aten.gelu.default": get_gelu_inputs(),
797797
"aten.hardshrink.default": get_unary_ops_inputs(),
798+
"aten.sin.default": get_unary_ops_inputs(),
799+
"aten.neg.default": get_unary_ops_inputs(),
800+
"aten.cos.default": get_unary_ops_inputs(),
798801
}

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,16 @@ def forward(self, x):
379379

380380
self.lower_unary_module_and_test_output(ClampModule())
381381

382+
def test_vulkan_backend_cos(self):
383+
class CosModule(torch.nn.Module):
384+
def __init__(self):
385+
super().__init__()
386+
387+
def forward(self, x):
388+
return torch.cos(x)
389+
390+
self.lower_unary_module_and_test_output(CosModule())
391+
382392
def test_vulkan_backend_hardtanh(self):
383393
class HardTanHModule(torch.nn.Module):
384394
def __init__(self):
@@ -400,6 +410,26 @@ def forward(self, x):
400410

401411
self.lower_unary_module_and_test_output(ExpModule())
402412

413+
def test_vulkan_backend_neg(self):
414+
class NegModule(torch.nn.Module):
415+
def __init__(self):
416+
super().__init__()
417+
418+
def forward(self, x):
419+
return torch.neg(x)
420+
421+
self.lower_unary_module_and_test_output(NegModule())
422+
423+
def test_vulkan_backend_sin(self):
424+
class SinModule(torch.nn.Module):
425+
def __init__(self):
426+
super().__init__()
427+
428+
def forward(self, x):
429+
return torch.sin(x)
430+
431+
self.lower_unary_module_and_test_output(SinModule())
432+
403433
def test_vulkan_backend_relu(self):
404434
class ReLUModule(torch.nn.Module):
405435
def __init__(self):

0 commit comments

Comments
 (0)