Commit 6f47383

copyrightly authored and facebook-github-bot committed
native_layer_norm (for width dim) (#3001)
Summary: We implement `native_layer_norm`, which has 3 outputs:
- the normalization of the input tensor according to the given `normalized_shape`
- the mean
- 1/sqrt(var + eps)

https://www.internalfb.com/code/fbsource/[8db4b5872791bb88a62ecaa60b667ee4c1b189bf]/fbcode/caffe2/aten/src/ATen/native/native_functions.yaml?lines=3252

According to SS-JIA's suggestion, a model-specific implementation is more performant and preferred to a generic one, so we implement the op in the following optimized way:
- our current use case has a `normalized_shape` of len 1, i.e. we normalize by computing the mean and var over the last (width) dim
- we do the computation in a single shader, `native_layer_norm.glsl`, without invoking separate shaders to compute the mean and var
- we use [Welford's online algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm) to compute the mean and variance in one pass

Differential Revision: D56005629
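For reference, a minimal PyTorch sketch (not part of this change) of the three outputs for a length-1 `normalized_shape`, checking that `rstd` equals 1/sqrt(var + eps); the shapes, tolerances, and the check itself are illustrative assumptions, not code from this commit.

import torch

# Illustrative only: normalize over the last (width) dim, as in the use case above.
x = torch.randn(3, 4, 5)
out, mean, rstd = torch.native_layer_norm(x, [5], torch.ones(5), torch.zeros(5), 1e-5)

# mean and rstd keep the leading dims and reduce the normalized dim to size 1.
ref_mean = x.mean(dim=-1, keepdim=True)
ref_rstd = 1.0 / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + 1e-5)

print(out.shape, mean.shape, rstd.shape)   # (3, 4, 5), (3, 4, 1), (3, 4, 1)
print(torch.allclose(mean, ref_mean, atol=1e-5), torch.allclose(rstd, ref_rstd, atol=1e-4))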
1 parent b1edc3d commit 6f47383

File tree

5 files changed: +236 −0 lines changed


backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 0 deletions

@@ -50,6 +50,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.sum.dim_IntList,
             # Convolution operators
             exir_ops.edge.aten.convolution.default,
+            # Normalization
+            exir_ops.edge.aten.native_layer_norm.default,
             # Other
             operator.getitem,
         ]
native_layer_norm.glsl

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#include "broadcasting_utils.h"
#include "indexing_utils.h"

#define PRECISION ${PRECISION}

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_mean;
layout(set = 0, binding = 2, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_rstd;

layout(set = 0, binding = 3) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 4) uniform PRECISION sampler3D weight_in;
layout(set = 0, binding = 5) uniform PRECISION sampler3D bias_in;

layout(set = 0, binding = 6) uniform PRECISION restrict OutExtents {
  ivec4 data;
}
out_sizes;

layout(set = 0, binding = 7) uniform PRECISION restrict Epsilon {
  float data;
}
epsilon;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 coord = POS_TO_COORD_${PACKING}(pos, out_sizes.data);

  if (any(greaterThanEqual(coord, out_sizes.data))) {
    return;
  }

  const int width = out_sizes.data.x;

  vec4 mean = vec4(0);
  vec4 delta = vec4(0);
  vec4 delta2 = vec4(0);
  vec4 M2 = vec4(0);

  // Use Welford's online algorithm to compute mean and variance in one pass
  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
  for (int w = 0; w < width; ++w) {
    vec4 v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
    delta = v - mean;
    mean += delta / (w + 1);
    delta2 = v - mean;
    M2 += delta * delta2;
  }

  vec4 var = M2 / width;
  vec4 rstd = pow(var + epsilon.data, vec4(-0.5));
  vec4 offset = -rstd * mean;

  for (int w = 0; w < width; ++w) {
    vec4 v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
    // broadcasting
    vec4 weight = texelFetch(weight_in, ivec3(w, 0, 0), 0).xxxx;
    vec4 bias = texelFetch(bias_in, ivec3(w, 0, 0), 0).xxxx;
    vec4 ot = (v * rstd + offset) * weight + bias;
    imageStore(image_out, ivec3(w, pos.y, pos.z), ot);
  }

  imageStore(image_mean, pos, mean);
  imageStore(image_rstd, pos, rstd);
}
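As a scalar reference for the shader's first loop, here is a minimal Python sketch (not part of this commit) of the same Welford recurrence; `welford_mean_rstd` is a hypothetical helper name used only for illustration.

def welford_mean_rstd(values, eps=1e-5):
    # One pass over the data: maintain a running mean and M2 (the sum of squared
    # deviations from the current mean), as the shader does per texel.
    mean, m2 = 0.0, 0.0
    for n, v in enumerate(values, start=1):
        delta = v - mean
        mean += delta / n        # mean_n = mean_{n-1} + (v - mean_{n-1}) / n
        delta2 = v - mean
        m2 += delta * delta2     # M2_n = M2_{n-1} + (v - mean_{n-1}) * (v - mean_n)
    var = m2 / len(values)       # population variance, matching M2 / width
    return mean, (var + eps) ** -0.5   # rstd = 1 / sqrt(var + eps)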
Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

native_layer_norm:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
    PACKING: CHANNELS_PACKED
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: native_layer_norm
Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

std::vector<int64_t> calc_out_mean_sizes(
    vTensor& self,
    int64_t normalized_shape_dim) {
  std::vector<int64_t> output_size = self.sizes();
  int64_t self_dim = self.sizes().size();
  for (int64_t i = 0; i < normalized_shape_dim; ++i) {
    output_size.at(self_dim - i - 1) = 1;
  }
  return output_size;
}

void resize_native_layer_norm_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mean = graph->get_tensor(args[0].refs[1]);
  vTensorPtr rstd = graph->get_tensor(args[0].refs[2]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
  std::vector<int64_t> in_sizes = in->sizes();

  const auto normalized_shape_dim = graph->get_int_list(extra_args[0])->size();

  std::vector<int64_t> mean_size =
      calc_out_mean_sizes(*in, normalized_shape_dim);

  out->virtual_resize(in_sizes);
  mean->virtual_resize(mean_size);
  rstd->virtual_resize(mean_size);
}

void check_args(const vTensor& in, const vTensor& out) {
  VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
}

void add_native_layer_norm_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef normalized_shape,
    const ValueRef weight,
    const ValueRef bias,
    const ValueRef eps,
    const ValueRef out) {
  const auto normalized_shape_dim =
      graph.get_int_list(normalized_shape)->size();
  if (normalized_shape_dim > 1) {
    VK_THROW("native_layer_norm only supports normalized_shape with dim == 1");
  }

  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
  ValueRef arg_weight =
      prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
  ValueRef arg_bias =
      prepack_if_tensor_ref(graph, bias, graph.memory_layout_of(arg_in));

  const auto& out_val = *graph.get_value_list(out);
  vTensorPtr t_out = graph.get_tensor(out_val[0]);
  vTensorPtr t_mean = graph.get_tensor(out_val[1]);
  vTensorPtr t_input = graph.get_tensor(in);
  vTensorPtr t_weight = graph.get_tensor(weight);
  float epsilon = graph.extract_scalar<float>(eps);

  check_args(*t_input, *t_out);

  std::vector<int64_t> in_sizes = t_input->sizes();

  api::utils::uvec3 global_size = t_mean->extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::string kernel_name("native_layer_norm");
  kernel_name.reserve(kShaderNameReserve);

  add_dtype_suffix(kernel_name, *t_out);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{{out_val[0], out_val[1], out_val[2]}, api::MemoryAccessType::WRITE},
       {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
      // Shader params buffers
      {t_out->gpu_sizes_ubo(), graph.create_params_buffer(epsilon)},
      // Resizing
      resize_native_layer_norm_node,
      {normalized_shape}));
}

void native_layer_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_native_layer_norm_node(
      graph, args[0], args[1], args[2], args[3], args[4], args[5]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.native_layer_norm.default, native_layer_norm);
}

} // namespace vkcompute

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 18 additions & 0 deletions

@@ -647,3 +647,21 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_native_layer_norm(self):
+        class NativeLayerNormModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.native_layer_norm(
+                    x, [5], torch.ones(5), torch.zeros(5), 1e-5
+                )
+
+        sample_inputs = (torch.randn(size=(3, 4, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            NativeLayerNormModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
