
Commit 74576e8

copyrightly authored and facebook-github-bot committed
native_layer_norm (for width dim) (#3001)
Summary: Pull Request resolved: #3001

We implement `native_layer_norm`, which has 3 outputs:

- the normalization of the input tensor according to the given `normalized_shape`
- the mean
- 1/sqrt(var + eps)

```
func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
```

Per SS-JIA's suggestion, a model-specific implementation is more performant than, and preferred to, a generic one, so we implemented the op in the following optimized way:

- our current use case has a `normalized_shape` of length 1, i.e. we normalize by computing the mean and variance over the last (width) dim
- we do the computation in a single shader, `native_layer_norm.glsl`, without invoking separate shaders to compute the mean and variance
- we use [Welford's online algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm) to compute the mean and variance in one pass (see the sketch below)

Reviewed By: SS-JIA, jorgep31415

Differential Revision: D56005629

fbshipit-source-id: 096c2e2f04b95f1f5c9205c4827091169771978c
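For readers new to the single-pass formulation, here is a minimal NumPy sketch of the per-row computation the shader performs (the helper name and the NumPy dependency are illustrative, not part of this commit):

```python
import numpy as np

def welford_layer_norm(x, weight, bias, eps=1e-5):
    # Welford accumulation of the running mean and M2 (sum of squared
    # deviations) over the last (width) dim, in a single pass.
    mean = np.zeros(x.shape[:-1])
    m2 = np.zeros(x.shape[:-1])
    for w in range(x.shape[-1]):
        v = x[..., w]
        delta = v - mean
        mean = mean + delta / (w + 1)
        delta2 = v - mean  # deviation from the *updated* mean
        m2 = m2 + delta * delta2
    var = m2 / x.shape[-1]           # population variance, as layer norm uses
    rstd = 1.0 / np.sqrt(var + eps)  # the op's third output
    out = (x - mean[..., None]) * rstd[..., None] * weight + bias
    return out, mean, rstd
```

The three return values line up with the `(Tensor, Tensor, Tensor)` signature above: the normalized output, the mean, and 1/sqrt(var + eps).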
1 parent: 075fe40 · commit: 74576e8

8 files changed, +300 −6 lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -50,6 +50,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.sum.dim_IntList,
             # Convolution operators
             exir_ops.edge.aten.convolution.default,
+            # Normalization
+            exir_ops.edge.aten.native_layer_norm.default,
             # Other
             operator.getitem,
         ]
```
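This one-line registration is all the partitioner needs: once `aten.native_layer_norm.default` is in the supported list, the Vulkan delegate claims those nodes, and the already-supported `operator.getitem` serves the elements of the op's three-tensor output.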
backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

Lines changed: 80 additions & 0 deletions

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#include "broadcasting_utils.h"
#include "indexing_utils.h"

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_type(DTYPE)}
#define to_tensor_idx to_tensor_idx_${PACKING}

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_mean;
layout(set = 0, binding = 2, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_rstd;

layout(set = 0, binding = 3) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 4) uniform PRECISION sampler3D weight_in;
layout(set = 0, binding = 5) uniform PRECISION sampler3D bias_in;

layout(set = 0, binding = 6) uniform PRECISION restrict OutExtents {
  uvec4 data;
}
out_sizes;

layout(set = 0, binding = 7) uniform PRECISION restrict Epsilon {
  float data;
}
epsilon;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 idx = to_tensor_idx(pos, out_sizes.data);

  if (any(greaterThanEqual(idx, out_sizes.data))) {
    return;
  }

  const int width = int(out_sizes.data.x);

  VEC4_T mean = VEC4_T(0);
  VEC4_T delta = VEC4_T(0);
  VEC4_T delta2 = VEC4_T(0);
  VEC4_T M2 = VEC4_T(0);

  // Use Welford's online algorithm to compute mean and variance in one pass
  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
  for (int w = 0; w < width; ++w) {
    VEC4_T v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
    delta = v - mean;
    mean += delta / (w + 1);
    delta2 = v - mean;
    M2 += delta * delta2;
  }

  VEC4_T var = M2 / width;
  VEC4_T rstd = pow(var + epsilon.data, VEC4_T(-0.5));
  VEC4_T offset = -rstd * mean;

  for (int w = 0; w < width; ++w) {
    VEC4_T v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
    // broadcasting
    VEC4_T weight = texelFetch(weight_in, ivec3(w, 0, 0), 0).xxxx;
    VEC4_T bias = texelFetch(bias_in, ivec3(w, 0, 0), 0).xxxx;
    VEC4_T outtex = (v * rstd + offset) * weight + bias;
    imageStore(image_out, ivec3(w, pos.y, pos.z), outtex);
  }

  imageStore(image_mean, pos, mean);
  imageStore(image_rstd, pos, rstd);
}
```
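One algebraic detail in the write-out loop above: since (v - mean) * rstd == v * rstd - rstd * mean, the shader precomputes `offset = -rstd * mean` once per row, so each texel needs only a multiply-add (`v * rstd + offset`) before the weight/bias affine transform.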
backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml

Lines changed: 17 additions & 0 deletions

```yaml
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

native_layer_norm:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
    PACKING: C_packed
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: native_layer_norm
```
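This YAML drives the shader codegen: `generate_variant_forall` stamps out the template for both `half` and `float` DTYPEs, yielding dtype-suffixed kernel names that the C++ code below selects via `add_dtype_suffix`.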
backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp

Lines changed: 127 additions & 0 deletions

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

std::vector<int64_t> calc_out_mean_sizes(
    vTensor& self,
    int64_t normalized_shape_dim) {
  std::vector<int64_t> output_size = self.sizes();
  int64_t self_dim = self.sizes().size();
  for (int64_t i = 0; i < normalized_shape_dim; ++i) {
    output_size.at(self_dim - i - 1) = 1;
  }
  return output_size;
}

void resize_native_layer_norm_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mean = graph->get_tensor(args[0].refs[1]);
  vTensorPtr rstd = graph->get_tensor(args[0].refs[2]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
  std::vector<int64_t> in_sizes = in->sizes();

  const auto normalized_shape_dim = graph->get_int_list(extra_args[0])->size();

  std::vector<int64_t> mean_size =
      calc_out_mean_sizes(*in, normalized_shape_dim);

  out->virtual_resize(in_sizes);
  mean->virtual_resize(mean_size);
  rstd->virtual_resize(mean_size);
}

void check_args(const vTensor& in, const vTensor& out) {
  VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
}

void add_native_layer_norm_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef normalized_shape,
    const ValueRef weight,
    const ValueRef bias,
    const ValueRef eps,
    const ValueRef out) {
  const auto normalized_shape_dim =
      graph.get_int_list(normalized_shape)->size();
  if (normalized_shape_dim > 1) {
    VK_THROW("native_layer_norm only supports normalized_shape with dim == 1");
  }

  if (graph.val_is_none(weight)) {
    VK_THROW("native_layer_norm requires weight to be non-None");
  }

  if (graph.val_is_none(bias)) {
    VK_THROW("native_layer_norm requires bias to be non-None");
  }

  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
  ValueRef arg_weight =
      prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
  ValueRef arg_bias =
      prepack_if_tensor_ref(graph, bias, graph.memory_layout_of(arg_in));

  const auto out_val = graph.get_value_list(out);
  vTensorPtr t_out = graph.get_tensor(out_val->at(0));
  vTensorPtr t_mean = graph.get_tensor(out_val->at(1));
  vTensorPtr t_input = graph.get_tensor(in);
  float epsilon = graph.extract_scalar<float>(eps);

  check_args(*t_input, *t_out);

  std::vector<int64_t> in_sizes = t_input->sizes();

  api::utils::uvec3 global_size = t_mean->extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::string kernel_name("native_layer_norm");
  kernel_name.reserve(kShaderNameReserve);

  add_dtype_suffix(kernel_name, *t_out);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{{out_val->at(0), out_val->at(1), out_val->at(2)},
        api::MemoryAccessType::WRITE},
       {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
      // Shader params buffers
      {t_out->gpu_sizes_ubo(), graph.create_params_buffer(epsilon)},
      // Resizing
      resize_native_layer_norm_node,
      {normalized_shape}));
}

void native_layer_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_native_layer_norm_node(
      graph, args[0], args[1], args[2], args[3], args[4], args[5]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.native_layer_norm.default, native_layer_norm);
}

} // namespace vkcompute
```
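Note how the dispatch is sized: `global_size` is taken from the mean tensor's extents, whose width is collapsed to 1 by `calc_out_mean_sizes`, so each invocation owns one output row and loops over the width twice inside the shader: once for the Welford statistics and once to write the normalized texels.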

backends/vulkan/test/op_tests/cases.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -118,6 +118,18 @@ def get_conv2d_inputs():
     return test_suite


+def get_native_layer_norm_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((S1, S2), [S2], (S2), (S2), 0.001),
+            ((M, M1, M2), [M2], (M2), (M2), 0.001),
+            ((L, XL, M1, M2), [M2], (M2), (M2), 0.001),
+        ]
+    )
+    test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -126,6 +138,7 @@ def get_conv2d_inputs():
     "aten.mm.default": get_mm_inputs(),
     "aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
     "aten.convolution.default": get_conv2d_inputs(),
+    "aten.native_layer_norm.default": get_native_layer_norm_inputs(),
 }

 prepacked_args = {"aten.mm.default": {"mat2"}}
```
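Each test tuple follows the op's argument order, (input sizes, normalized_shape, weight sizes, bias sizes, eps), and every case normalizes only over the last dim, matching the shader's width-dim restriction.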

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 38 additions & 5 deletions
```diff
@@ -15,10 +15,12 @@
     AT_TENSOR_OPT,
     BOOL,
     CppTestFileGen,
+    DOUBLE,
     INT,
-    TENSOR_TUPLE,
     TestSuite,
     TestSuiteGen,
+    THREE_TENSOR_TUPLE,
+    TWO_TENSOR_TUPLE,
 )
 from torchgen.api import cpp
 from torchgen.api.types import CppSignatureGroup
@@ -118,7 +120,7 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
             self.refs["out"] = ValueRef(
                 name="out_ref", src_cpp_name="out", src_cpp_type=ret_type, is_out=True
             )
-        elif ret_type == TENSOR_TUPLE:
+        elif ret_type == TWO_TENSOR_TUPLE:
             self.refs["out"] = [
                 ValueRef(
                     name="out_ref_first",
@@ -139,6 +141,33 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
                     is_out=False,
                 ),
             ]
+        elif ret_type == THREE_TENSOR_TUPLE:
+            self.refs["out"] = [
+                ValueRef(
+                    name="out_ref_first",
+                    src_cpp_name="std::get<0>(out)",
+                    src_cpp_type="at::Tensor",
+                    is_out=True,
+                ),
+                ValueRef(
+                    name="out_ref_second",
+                    src_cpp_name="std::get<1>(out)",
+                    src_cpp_type="at::Tensor",
+                    is_out=True,
+                ),
+                ValueRef(
+                    name="out_ref_third",
+                    src_cpp_name="std::get<2>(out)",
+                    src_cpp_type="at::Tensor",
+                    is_out=True,
+                ),
+                ValueRef(
+                    name="out_ref",
+                    src_cpp_name="out",
+                    src_cpp_type=ret_type,
+                    is_out=False,
+                ),
+            ]

 ## ATen code generation

@@ -210,8 +239,12 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
             ret_str += f"add_scalar<bool>({ref.src_cpp_name}); \n"
         elif ref.src_cpp_type == INT:
             ret_str += f"add_scalar<int64_t>({ref.src_cpp_name}); \n"
-        elif ref.src_cpp_type == TENSOR_TUPLE:
+        elif ref.src_cpp_type == DOUBLE:
+            ret_str += f"add_scalar<double>({ref.src_cpp_name}); \n"
+        elif ref.src_cpp_type == TWO_TENSOR_TUPLE:
             ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n"
+        elif ref.src_cpp_type == THREE_TENSOR_TUPLE:
+            ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second, {ref.name}_third}}); \n"
         else:
             raise RuntimeError(f"Unsupported cpp type {ref.src_cpp_type}")

@@ -441,9 +474,9 @@ def gen_parameterization(self) -> str:
 }

 #ifdef USE_VULKAN_FP16_INFERENCE
-bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-2, float atol=1e-3) {
+bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-2, float atol=1e-2) {
 #else
-bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-5, float atol=1e-8) {
+bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-5, float atol=1e-5) {
 #endif
   // Skip checking index tensors
   if (t1.scalar_type() == at::kLong || t2.scalar_type() == at::kLong) {
```
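Splitting `TENSOR_TUPLE` into `TWO_TENSOR_TUPLE` and `THREE_TENSOR_TUPLE` lets the generated test bind each element of a three-tensor result via `std::get<0..2>(out)` and register them together with `add_value_list`, which is how the (out, mean, rstd) triple reaches the Vulkan graph. The `atol` defaults are also relaxed to match their `rtol` counterparts (1e-2 under fp16, 1e-5 otherwise).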

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -21,7 +21,9 @@
 AT_TENSOR_OPT = "::std::optional<at::Tensor>"
 BOOL = "bool"
 INT = "int64_t"
-TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
+DOUBLE = "double"
+TWO_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
+THREE_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor,at::Tensor>"

 ###########################
 ## Test Suite definition ##
@@ -131,6 +133,8 @@ def create_input_data(self, arg: Argument, data: Any) -> str:
             ret_str += f"{str(data).lower()};"
         elif cpp_type == INT:
             ret_str += f"{str(data).lower()};"
+        elif cpp_type == DOUBLE:
+            ret_str += f"{str(data).lower()};"
         else:
             raise RuntimeError(f"Unsupported cpp type {cpp_type}")
         return ret_str + "\n"
```
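`DOUBLE` reuses the same literal lowering as `BOOL` and `INT`: `str(data).lower()` turns the Python value into its C++ spelling, which is all a float scalar such as `eps` needs.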

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -647,3 +647,21 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_native_layer_norm(self):
+        class NativeLayerNormModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.native_layer_norm(
+                    x, [5], torch.ones(5), torch.zeros(5), 1e-5
+                )
+
+        sample_inputs = (torch.randn(size=(3, 4, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            NativeLayerNormModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
```
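As a quick eager-mode illustration of the contract the delegate must reproduce (a standalone sketch, not part of the test suite):

```python
import torch

x = torch.randn(3, 4, 5)
out, mean, rstd = torch.native_layer_norm(
    x, [5], torch.ones(5), torch.zeros(5), 1e-5
)

# mean and rstd keep the leading dims, with the normalized dim collapsed to 1
assert mean.shape == rstd.shape == (3, 4, 1)
# rstd is 1/sqrt(var + eps), using population (unbiased=False) variance
expected = (x.var(-1, unbiased=False, keepdim=True) + 1e-5).rsqrt()
assert torch.allclose(rstd, expected, atol=1e-6)
```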
