
native_layer_norm (for width dim) #3001

Closed · wants to merge 1 commit

2 changes: 2 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -50,6 +50,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            exir_ops.edge.aten.sum.dim_IntList,
            # Convolution operators
            exir_ops.edge.aten.convolution.default,
            # Normalization
            exir_ops.edge.aten.native_layer_norm.default,
            # Other
            operator.getitem,
        ]
78 changes: 78 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
@@ -0,0 +1,78 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#include "broadcasting_utils.h"
#include "indexing_utils.h"

#define PRECISION ${PRECISION}

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_mean;
layout(set = 0, binding = 2, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_rstd;

layout(set = 0, binding = 3) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 4) uniform PRECISION sampler3D weight_in;
layout(set = 0, binding = 5) uniform PRECISION sampler3D bias_in;

layout(set = 0, binding = 6) uniform PRECISION restrict OutExtents {
  ivec4 data;
}
out_sizes;

layout(set = 0, binding = 7) uniform PRECISION restrict Epsilon {
  float data;
}
epsilon;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 coord = POS_TO_COORD_${PACKING}(pos, out_sizes.data);

  if (any(greaterThanEqual(coord, out_sizes.data))) {
    return;
  }

  const int width = out_sizes.data.x;

  vec4 mean = vec4(0);
  vec4 delta = vec4(0);
  vec4 delta2 = vec4(0);
  vec4 M2 = vec4(0);

  // Use Welford's online algorithm to compute mean and variance in one pass
  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
  for (int w = 0; w < width; ++w) {
    vec4 v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
    delta = v - mean;
    mean += delta / (w + 1);
    delta2 = v - mean;
    M2 += delta * delta2;
  }

  vec4 var = M2 / width;
  vec4 rstd = pow(var + epsilon.data, vec4(-0.5));
  vec4 offset = -rstd * mean;

  for (int w = 0; w < width; ++w) {
    vec4 v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
    // broadcasting
    vec4 weight = texelFetch(weight_in, ivec3(w, 0, 0), 0).xxxx;
    vec4 bias = texelFetch(bias_in, ivec3(w, 0, 0), 0).xxxx;
    vec4 ot = (v * rstd + offset) * weight + bias;
    imageStore(image_out, ivec3(w, pos.y, pos.z), ot);
  }

  imageStore(image_mean, pos, mean);
  imageStore(image_rstd, pos, rstd);
}
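
Note: as a minimal sketch of the math the shader runs per row, here is the same one-pass Welford mean/variance followed by the offset rewrite of (x - mean) * rstd, written for a plain Python list of scalars rather than packed texels. The function name and plain-list interface are illustrative, not part of this PR.

# Illustrative only: per-row layer norm using Welford's online algorithm,
# mirroring the shader's loops on one scalar lane instead of a vec4 texel.
def layer_norm_row(x, weight, bias, eps=1e-5):
    mean, m2 = 0.0, 0.0
    for i, v in enumerate(x):
        delta = v - mean
        mean += delta / (i + 1)
        m2 += delta * (v - mean)       # delta uses the old mean, (v - mean) the new one
    var = m2 / len(x)                  # population variance, matching M2 / width
    rstd = (var + eps) ** -0.5
    offset = -rstd * mean              # so (v - mean) * rstd == v * rstd + offset
    out = [(v * rstd + offset) * w + b for v, w, b in zip(x, weight, bias)]
    return out, mean, rstd

The single statistics pass is why the shader reads image_in only twice per row: once to accumulate mean and M2, and once to write the normalized output.
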
19 changes: 19 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

native_layer_norm:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
    PACKING: CHANNELS_PACKED
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: native_layer_norm
119 changes: 119 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
@@ -0,0 +1,119 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

std::vector<int64_t> calc_out_mean_sizes(
    vTensor& self,
    int64_t normalized_shape_dim) {
  std::vector<int64_t> output_size = self.sizes();
  int64_t self_dim = self.sizes().size();
  for (int64_t i = 0; i < normalized_shape_dim; ++i) {
    output_size.at(self_dim - i - 1) = 1;
  }
  return output_size;
}

void resize_native_layer_norm_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mean = graph->get_tensor(args[0].refs[1]);
  vTensorPtr rstd = graph->get_tensor(args[0].refs[2]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
  std::vector<int64_t> in_sizes = in->sizes();

  const auto normalized_shape_dim = graph->get_int_list(extra_args[0])->size();

  std::vector<int64_t> mean_size =
      calc_out_mean_sizes(*in, normalized_shape_dim);

  out->virtual_resize(in_sizes);
  mean->virtual_resize(mean_size);
  rstd->virtual_resize(mean_size);
}

void check_args(const vTensor& in, const vTensor& out) {
  VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
}

void add_native_layer_norm_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef normalized_shape,
    const ValueRef weight,
    const ValueRef bias,
    const ValueRef eps,
    const ValueRef out) {
  const auto normalized_shape_dim =
      graph.get_int_list(normalized_shape)->size();
  if (normalized_shape_dim > 1) {
    VK_THROW("native_layer_norm only supports normalized_shape with dim == 1");
  }

  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
  ValueRef arg_weight =
      prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
  ValueRef arg_bias =
      prepack_if_tensor_ref(graph, bias, graph.memory_layout_of(arg_in));

  const auto& out_val = *graph.get_value_list(out);
  vTensorPtr t_out = graph.get_tensor(out_val[0]);
  vTensorPtr t_mean = graph.get_tensor(out_val[1]);
  vTensorPtr t_input = graph.get_tensor(in);
  vTensorPtr t_weight = graph.get_tensor(weight);
  float epsilon = graph.extract_scalar<float>(eps);

  check_args(*t_input, *t_out);

  std::vector<int64_t> in_sizes = t_input->sizes();

  api::utils::uvec3 global_size = t_mean->extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::string kernel_name("native_layer_norm");
  kernel_name.reserve(kShaderNameReserve);

  add_dtype_suffix(kernel_name, *t_out);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{{out_val[0], out_val[1], out_val[2]}, api::MemoryAccessType::WRITE},
       {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
      // Shader params buffers
      {t_out->gpu_sizes_ubo(), graph.create_params_buffer(epsilon)},
      // Resizing
      resize_native_layer_norm_node,
      {normalized_shape}));
}

void native_layer_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_native_layer_norm_node(
      graph, args[0], args[1], args[2], args[3], args[4], args[5]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.native_layer_norm.default, native_layer_norm);
}

} // namespace vkcompute
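
Note: calc_out_mean_sizes produces the input shape with each normalized dimension collapsed to 1, which is what the mean and rstd outputs are resized to, and the node throws for a normalized_shape with more than one dimension, so only normalization over the last (width) dimension is handled. A short eager-mode sketch of the shapes involved, assuming standard torch.native_layer_norm semantics:

import torch

x = torch.randn(3, 4, 5)
out, mean, rstd = torch.native_layer_norm(x, [5], torch.ones(5), torch.zeros(5), 1e-5)
# out:  (3, 4, 5)  -- same shape as the input
# mean: (3, 4, 1)  -- normalized dims collapsed to 1, matching calc_out_mean_sizes
# rstd: (3, 4, 1)

# A multi-dimensional normalized_shape such as [4, 5] is what the VK_THROW in
# add_native_layer_norm_node rejects; only the 1-D case above is supported here.
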
18 changes: 18 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -647,3 +647,21 @@ def forward(self, x):
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_native_layer_norm(self):
        class NativeLayerNormModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.native_layer_norm(
                    x, [5], torch.ones(5), torch.zeros(5), 1e-5
                )

        sample_inputs = (torch.randn(size=(3, 4, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            NativeLayerNormModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )