
add abs, sigmoid, tanh to ET-VK #2605


Closed · wants to merge 1 commit
3 changes: 3 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -36,9 +36,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.div.Tensor_mode,
exir_ops.edge.aten.pow.Tensor_Tensor,
# Activation operators
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.tanh.default,
# Matrix multiplication operators
exir_ops.edge.aten.mm.default,
# Pooling operators
36 changes: 36 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/activation.glsl
@@ -0,0 +1,36 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define OP(X) ${OPERATOR}

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;

layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents {
uvec4 data;
}
out_extents;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
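// Each invocation writes one output texel; dispatch dimensions are rounded
// up to the local workgroup size, so invocations that fall outside the
// output extents return early below.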
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
return;
}

vec4 in_texel = texelFetch(image_in, pos, 0);
imageStore(image_out, pos, OP(in_texel));
}
18 changes: 18 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/activation.yaml
@@ -0,0 +1,18 @@
activation:
parameter_names_with_default_values:
OPERATOR: op(X)
NDIM: 3
DTYPE: float
generate_variant_forall:
DTYPE:
- VALUE: half
SUFFIX: half
- VALUE: float
SUFFIX: float
shader_variants:
- NAME: abs
OPERATOR: abs(X)
- NAME: sigmoid
OPERATOR: 1 / (1 + exp(-1 * X))
- NAME: tanh
OPERATOR: tanh(clamp(X, -15.0, 15.0))
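
The sigmoid and tanh variants above are plain GLSL expressions substituted into the shader's OP(X) macro. As a quick plausibility check (not part of the PR), the following sketch verifies against PyTorch that the sigmoid expression matches torch.sigmoid, and that clamping tanh's input to [-15, 15] is numerically harmless: tanh saturates to ±1 well before |x| = 15, so the clamp only bounds what reaches the GPU's tanh evaluation (presumably to guard against overflow at lower precisions).

import torch

# Reference check for the shader formulas in activation.yaml.
x = torch.linspace(-20.0, 20.0, steps=1001)

# sigmoid variant: OPERATOR = 1 / (1 + exp(-1 * X))
shader_sigmoid = 1.0 / (1.0 + torch.exp(-1.0 * x))
assert torch.allclose(shader_sigmoid, torch.sigmoid(x), atol=1e-6)

# tanh variant: OPERATOR = tanh(clamp(X, -15.0, 15.0))
# tanh(15.0) differs from 1.0 by ~2e-13, far below float32 resolution,
# so the clamp does not change results at this precision.
shader_tanh = torch.tanh(torch.clamp(x, -15.0, 15.0))
assert torch.allclose(shader_tanh, torch.tanh(x), atol=1e-6)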
@@ -1,4 +1,4 @@
unary_op:
clamp:
parameter_names_with_default_values:
OPERATOR: clamp(X, A, B)
NDIM: 3
@@ -61,6 +61,34 @@ void add_clamp_node(
resize_clamp_node));
}

void add_activation_node(
ComputeGraph& graph,
const ValueRef in,
const ValueRef out,
const std::string& op_name_string) {
ValueRef arg = prepack_if_tensor_ref(graph, in);

vTensor& t_out = graph.get_val(out).toTensor();
api::utils::uvec3 global_size = t_out.virtual_extents();
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

std::stringstream kernel_name;
kernel_name << op_name_string;
apply_dtype_suffix(kernel_name, t_out);
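// The dtype suffix selects among the generated shader variants
// (e.g. "sigmoid_half" / "sigmoid_float", per the SUFFIX values in activation.yaml)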

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
VK_KERNEL_FROM_STR(kernel_name.str()),
global_size,
local_size,
// Inputs and Outputs
{{out, api::MemoryAccessType::WRITE}, {arg, api::MemoryAccessType::READ}},
// Shader params buffers
{t_out.extents_ubo()},
// Resizing
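// (activations are elementwise and preserve the input shape, so the clamp resize function is reused)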
resize_clamp_node));
}

float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
if (!graph.get_val(val).isNone()) {
return extract_scalar<float>(graph.get_val(val));
@@ -69,6 +97,11 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
: -std::numeric_limits<float>::infinity();
}

#define DEFINE_ABS_FN(op_name) \
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
return add_activation_node(graph, args[0], args[1], "abs"); \
}

#define DEFINE_CLAMP_FN(op_name) \
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
return add_clamp_node( \
@@ -85,14 +118,30 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
graph, args[0], 0, std::numeric_limits<float>::infinity(), args[1]); \
}

#define DEFINE_SIGMOID_FN(op_name) \
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
return add_activation_node(graph, args[0], args[1], "sigmoid"); \
}

#define DEFINE_TANH_FN(op_name) \
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
return add_activation_node(graph, args[0], args[1], "tanh"); \
}

DEFINE_ABS_FN(abs);
DEFINE_CLAMP_FN(clamp);
DEFINE_CLAMP_FN(hardtanh);
DEFINE_RELU_FN(relu);
DEFINE_SIGMOID_FN(sigmoid);
DEFINE_TANH_FN(tanh);

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.abs.default, abs);
VK_REGISTER_OP(aten.clamp.default, clamp);
VK_REGISTER_OP(aten.hardtanh.default, hardtanh);
VK_REGISTER_OP(aten.relu.default, relu);
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
VK_REGISTER_OP(aten.tanh.default, tanh);
}

} // namespace vulkan
30 changes: 30 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -379,6 +379,36 @@ def forward(self, x):
first_output_only=True,
)

def test_vulkan_backend_abs(self):
class AbsModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.abs(x)

self.lower_clamp_module_and_test_output(AbsModule())

def test_vulkan_backend_sigmoid(self):
class SigmoidModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sigmoid(x)

self.lower_clamp_module_and_test_output(SigmoidModule())

def test_vulkan_backend_tanh(self):
class TanhModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.tanh(x)

self.lower_clamp_module_and_test_output(TanhModule())
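
For illustration only (not part of this diff): a hypothetical companion test could chain all three newly supported activations so they are exercised together in one partition. The class and test name below are assumptions; lower_clamp_module_and_test_output is the harness used by the tests above.

def test_vulkan_backend_chained_activations(self):
    class ChainedActivations(torch.nn.Module):
        def forward(self, x):
            # abs -> sigmoid -> tanh, all newly supported in ET-VK
            return torch.tanh(torch.sigmoid(torch.abs(x)))

    self.lower_clamp_module_and_test_output(ChainedActivations())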

def test_vulkan_backend_partial(self):
class SimpleModel(torch.nn.Module):
def __init__(self):