Skip to content

Commit 7291f89

Browse files
author
morelos
committed
Update on "[ET-VK][Ops] aten.tan.default from scratch implementation"
Following the instructions here for creating a new operator from scratch for learning purposes: https://www.internalfb.com/wiki/ExecuTorch_Vulkan_Backend/Development_0/Adding_a_New_Operator_Implementation/ Goal is to create a tan operator and its test case Differential Revision: [D75100188](https://our.internmc.facebook.com/intern/diff/D75100188/) [ghstack-poisoned]
2 parents 47c0573 + 79a84c1 commit 7291f89

File tree

1 file changed

+7
-12
lines changed
  • backends/vulkan/runtime/graph/ops/impl

1 file changed

+7
-12
lines changed

backends/vulkan/runtime/graph/ops/impl/Tan.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
*
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
7-
*/
7+
*/
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

@@ -18,20 +18,17 @@ namespace vkcompute {
1818
using namespace utils;
1919

2020
void resize_tan_node(
21-
ComputeGraph* graph,
22-
const std::vector<ArgGroup>& args,
23-
const std::vector<ValueRef>& extra_args) {
21+
ComputeGraph* graph,
22+
const std::vector<ArgGroup>& args,
23+
const std::vector<ValueRef>& extra_args) {
2424
(void)extra_args;
2525
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
2626
vTensorPtr self = graph->get_tensor(args[1].refs[0]);
2727

2828
out->virtual_resize(self->sizes());
2929
}
3030

31-
void add_tan_node(
32-
ComputeGraph& graph,
33-
const ValueRef in,
34-
const ValueRef out) {
31+
void add_tan_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) {
3532
std::string kernel_name = "tan";
3633
add_dtype_suffix(kernel_name, graph.dtype_of(out));
3734
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
@@ -45,8 +42,7 @@ const ValueRef out) {
4542
graph.create_global_wg_size(out),
4643
graph.create_local_wg_size(out),
4744
// Inputs and Outputs
48-
{{out, vkapi::kWrite},
49-
{in, vkapi::kRead}},
45+
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
5046
// Shader params buffers
5147
ubos,
5248
// Push Constants
@@ -60,10 +56,9 @@ const ValueRef out) {
6056
}
6157

6258
void tan(ComputeGraph& graph, const std::vector<ValueRef>& args) {
63-
return add_tan_node(graph, args[0], args[1]);
59+
return add_tan_node(graph, args[0], args[1]);
6460
}
6561

66-
6762
REGISTER_OPERATORS {
6863
VK_REGISTER_OP(aten.tan.default, tan);
6964
}

0 commit comments

Comments
 (0)