
Commit e32dffd

Merge branch 'main' into shoumikhin-patch-6
2 parents: def2298 + 1a27c14

14 files changed: +199 -27 lines

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 36 additions & 14 deletions

@@ -88,10 +88,18 @@ void main() {
     ipos[i] = pos[i] * stride - padding;
   }

-  vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
-  sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
-  for (int i = 1; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
-    sum[i] = sum[0];
+  // Final output array where each element is a tensor value.
+  // Tuple of consecutive 4 elements represents a single output texel.
+  float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];
+
+  const vec4 bias = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
+
+  // Initialize the output array with the bias value
+  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) {
+    sum[i] = bias.x;
+    sum[i + 1] = bias.y;
+    sum[i + 2] = bias.z;
+    sum[i + 3] = bias.w;
   }

   int z4 = 0;
@@ -100,14 +108,26 @@ void main() {
     // During prepacking, the weight tensor has been permuted so that the
     // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
     // the z-axis.
-    const vec4 ktex_0 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(0, 0));
-    const vec4 ktex_1 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(1, 0));
-    const vec4 ktex_2 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(2, 0));
-    const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));
+    float kernel_values[4 * 4]; // 4 channels, 4 elements per channel
+
+    // Load kernel values from texels to array
+    for (int i = 0; i < 4; ++i) {
+      const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, gpos.z), 0);
+      kernel_values[i * 4 + 0] = k_tex.x;
+      kernel_values[i * 4 + 1] = k_tex.y;
+      kernel_values[i * 4 + 2] = k_tex.z;
+      kernel_values[i * 4 + 3] = k_tex.w;
+    }

-#pragma unroll
     for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
       const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
+      // Load the input texel into an array
+      float tex_values[4];
+      tex_values[0] = in_tex.x;
+      tex_values[1] = in_tex.y;
+      tex_values[2] = in_tex.z;
+      tex_values[3] = in_tex.w;
+
       // For 2x2 tile size algorithm works as follows.
       // To explain the calculations below, the contents of one in_tex and the
       // group of 4 texels loaded from t_kernel are shown:
@@ -141,18 +161,20 @@ void main() {
       //
       // which is what is expressed in the following calculations. This is done
       // for each output position.
-      sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
-      sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
-      sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
-      sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
+      for (int j = 0; j < 4; ++j) {
+        sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j];
+        sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j];
+        sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j];
+        sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j];
+      }
     }
   }

   for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
     const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
     const ivec3 pos = pos_shared[offset_pos_index(index)];
     if (all(lessThan(pos, out_limits.xyz))) {
-      imageStore(t_out, pos, op(sum[i], out_min, out_max));
+      imageStore(t_out, pos, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
     }
   }
 }
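The rewrite trades the vectorized `fma` accumulation for explicit scalar indexing: each output texel now occupies four consecutive `float` slots, and slot `j` accumulates the dot product of the four input channels with kernel column `j`. A minimal Python sketch of the per-texel arithmetic, with made-up values (array layouts mirror the shader's):

```python
# Per-texel accumulation from the rewritten shader, in plain Python.
# Values are made up; array layouts mirror the GLSL arrays above.
tex_values = [1.0, 2.0, 3.0, 4.0]              # one input texel: 4 input channels
kernel_values = [float(v) for v in range(16)]  # kernel_values[c * 4 + j]: channel c, output j
bias = [0.5, 0.5, 0.5, 0.5]                    # one bias texel

sum_ = list(bias)                              # sum[i * 4 + j] for a single tile position i
for j in range(4):                             # output component within the texel
    for c in range(4):                         # input channel
        sum_[j] += tex_values[c] * kernel_values[c * 4 + j]

# sum_ is exactly what the old vectorized form computed:
#   sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]); ...; sum[i] = fma(in_tex.wwww, ktex_3, sum[i])
print(sum_)
```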

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
     op_dynamic_quantize_ops,
     op_elu,
     op_floor,
+    op_gelu,
     op_hardswish,
     op_hardtanh,
     op_leaky_relu,

backends/xnnpack/operators/op_gelu.py

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGelu,
+    XNNGraph,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class GeluVisitor(NodeVisitor):
+    target = "aten.gelu.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNGelu(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -26,6 +26,7 @@
     DeQuantizedPerTensorConfig,
     DivConfig,
     FloorConfig,
+    GeluConfig,
     HardswishConfig,
     # EluConfig,
     HardtanhConfig,
@@ -79,6 +80,7 @@
     DivConfig,
     # EluConfig, # Waiting for PyTorch Pin Update
     FloorConfig,
+    GeluConfig,
     HardtanhConfig,
     HardswishConfig,
     LeakyReLUConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions

@@ -343,6 +343,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]


+class GeluConfig(GenericNodePartitionerConfig):
+    target_name = "gelu.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class HardswishConfig(GenericNodePartitionerConfig):
     target_name = "hardswish.default"

backends/xnnpack/partition/configs.py

Lines changed: 1 addition & 0 deletions

@@ -65,6 +65,7 @@
     exir_ops.edge.aten.addmm.default,  # TODO(T163877189) add constraint for addmm
     exir_ops.edge.aten.rsqrt.default,
     exir_ops.edge.aten.log.default,
+    exir_ops.edge.aten.gelu.default,
 ]

 SUPPORTED_MODULES = [
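With `GeluConfig` registered and `aten.gelu.default` listed among the supported ops, gelu nodes get claimed by the XNNPACK partitioner like any other unary op. A hedged end-to-end sketch (API names follow the current executorch export flow; the model and inputs are illustrative, not taken from this diff):

```python
# Sketch: lowering a GELU model to XNNPACK (assumed current executorch APIs).
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class GeluModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)


exported = torch.export.export(GeluModel(), (torch.randn(20),))
# The partitioner consults GeluConfig and hands the gelu node to the delegate.
program = to_edge_transform_and_lower(
    exported, partitioner=[XnnpackPartitioner()]
).to_executorch()
```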

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 31 additions & 0 deletions

@@ -1448,6 +1448,36 @@ Error defineLogNode(
   return Error::Ok;
 }

+/*
+Define serialized gelu node into the subgraph, using the remapped ids
+to map the serialized ids, to the new ids generated when defining the
+tensor value
+*/
+Error defineGeluNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNGelu();
+
+  xnn_status status = xnn_define_gelu(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create gelu node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Define serialized ceiling node into the subgraph, using the remapped ids
 to map the serialized ids, to the new ids generated when defining the
@@ -2009,6 +2039,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(SquareRoot)
     _DEFINE(ReciprocalSquareRoot)
    _DEFINE(Ceiling)
+    _DEFINE(Gelu)
     _DEFINE(Hardswish)
     _DEFINE(LeakyReLU)
     _DEFINE(Log)
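Like its siblings, `defineGeluNode` resolves tensor references through `remapped_ids`: the flatbuffer stores ids assigned at serialization time, while XNNPACK hands out fresh ids as each tensor value is defined at load time. A conceptual Python sketch of that mapping (ids are made up; the real logic is the C++ above):

```python
# Conceptual sketch of the id remapping used by the define*Node helpers.
# Ids are made up; the runtime builds this map while defining tensor values.
remapped_ids = {}  # serialized tensor id -> id assigned by xnn_define_tensor_value

def define_tensor_value(serialized_id: int) -> None:
    remapped_ids[serialized_id] = len(remapped_ids)  # stand-in for the XNNPACK id

define_tensor_value(12)  # gelu input id in the flatbuffer
define_tensor_value(15)  # gelu output id in the flatbuffer

# defineGeluNode then effectively calls:
#   xnn_define_gelu(subgraph, remapped_ids[12], remapped_ids[15], flags)
assert (remapped_ids[12], remapped_ids[15]) == (0, 1)
```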

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions

@@ -140,6 +140,7 @@ union XNodeUnion {
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
+  XNNGelu: _XNNNode1x1,
 }

 union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions

@@ -136,6 +136,7 @@ union XNodeUnion {
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
+  XNNGelu: _XNNNode1x1,
 }

 union XValueUnion {
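Both schemas register `XNNGelu` against `_XNNNode1x1`, the generic one-input, one-output node payload, so no new flatbuffer table is needed. Its assumed Python mirror (field names inferred from how `GeluVisitor` constructs `XNNGelu` above, not copied from this diff):

```python
# Assumed shape of the XNNNode1x1 base dataclass that XNNGelu inherits.
from dataclasses import dataclass


@dataclass
class XNNNode1x1:
    input_id: int
    output_id: int
    flags: int


@dataclass
class XNNGelu(XNNNode1x1):
    pass


node = XNNGelu(input_id=0, output_id=1, flags=0)
```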

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 6 additions & 0 deletions

@@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1):
     pass


+@dataclass
+class XNNGelu(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNHardswish(XNNNode1x1):
     pass
@@ -385,6 +390,7 @@ class XNNScaledDotProductAttention:
     XNNBatchMatrixMultiply,
     XNNReciprocalSquareRoot,
     XNNLog,
+    XNNGelu,
 ]

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestGelu(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    class Gelu(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gelu = torch.nn.GELU()
+
+        def forward(self, x):
+            return self.gelu(x)
+
+    def run_gelu_test(self, inputs):
+        (
+            Tester(self.Gelu(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.gelu.default": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp16_gelu(self):
+        inputs = (torch.randn(20).to(torch.float16),)
+        self.run_gelu_test(inputs)
+
+    def test_fp32_gelu(self):
+        inputs = (torch.randn(20),)
+        self.run_gelu_test(inputs)
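The new tests cover the fp16 and fp32 paths through the full `Tester` pipeline (export, lowering, serialization, and output comparison against eager mode). They run like any other unittest module; the module path below is an assumption based on where the XNNPACK op tests usually live:

```python
# Running the new tests programmatically (module path assumed).
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "executorch.backends.xnnpack.test.ops.test_gelu.TestGelu"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```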

docs/source/using-executorch-ios.md

Lines changed: 1 addition & 1 deletion

@@ -135,7 +135,7 @@ sudo /Applications/CMake.app/Contents/bin/cmake-gui --install
 For example, the following command will build the ExecuTorch Runtime along with all available kernels and backends for the Apple platform in both Release and Debug modes:

 ```bash
-./scripts/build_apple_frameworks.sh --Release --Debug --coreml --mps --xnnpack --custom --optimized --portable --quantized
+./scripts/build_apple_frameworks.sh
 ```

 After the build finishes successfully, the resulting frameworks can be found in the `cmake-out` directory.

examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md

Lines changed: 1 addition & 1 deletion

@@ -138,7 +138,7 @@ curl -LO "https://github.com/facebook/buck2/releases/download/${BUCK2_RELEASE_DA
 zstd -cdq "$BUCK2_ARCHIVE" > "$BUCK2" && chmod +x "$BUCK2"
 rm "$BUCK2_ARCHIVE"

-./scripts/build_apple_frameworks.sh --buck2="$(realpath $BUCK2)" --coreml --custom --mps --optimized --portable --quantized --xnnpack
+./scripts/build_apple_frameworks.sh
 ```

 After the build finishes successfully, the resulting frameworks can be found in the `cmake-out` directory. Copy them to your project and link them against your targets.

scripts/build_apple_frameworks.sh

Lines changed: 15 additions & 11 deletions

@@ -9,6 +9,8 @@ set -euxo pipefail

 MODES=("Release" "Debug")
 PRESETS=("ios" "ios-simulator" "macos")
+# To support backwards compatibility, we want to retain the same output directory.
+PRESETS_RELATIVE_OUT_DIR=("ios" "simulator" "macos")

 SOURCE_ROOT_DIR=$(git rev-parse --show-toplevel)
 OUTPUT_DIR="${SOURCE_ROOT_DIR}/cmake-out"
@@ -146,20 +148,22 @@ done
 echo "Building libraries"

 rm -rf "${OUTPUT_DIR}"
-for preset in "${PRESETS[@]}"; do
+for preset_index in "${!PRESETS[@]}"; do
+  preset="${PRESETS[$preset_index]}"
+  preset_output_dir="${OUTPUT_DIR}/${PRESETS_RELATIVE_OUT_DIR[$preset_index]}"
+
   for mode in "${MODES[@]}"; do
-    output_dir="${OUTPUT_DIR}/${preset}"
-    echo "Building preset ${preset} (${mode}) in ${output_dir}..."
+    echo "Building preset ${preset} (${mode}) in ${preset_output_dir}..."

     # Do NOT add options here. Update the respective presets instead.
     cmake -S "${SOURCE_ROOT_DIR}" \
-      -B "${output_dir}" \
-      -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY="${output_dir}" \
+      -B "${preset_output_dir}" \
+      -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY="${preset_output_dir}" \
       -DCMAKE_BUILD_TYPE="${mode}" \
       ${CMAKE_OPTIONS_OVERRIDE[@]:-} \
       --preset "${preset}"

-    cmake --build "${output_dir}" \
+    cmake --build "${preset_output_dir}" \
       --config "${mode}" \
       -j$(sysctl -n hw.ncpu)
   done
@@ -224,9 +228,9 @@ append_framework_flag() {

 for mode in "${MODES[@]}"; do
   FRAMEWORK_FLAGS=()
-  for preset in "${PRESETS[@]}"; do
-    echo "Framework directory: ${preset}/${mode}"
-    FRAMEWORK_FLAGS+=("--directory=${preset}/${mode}")
+  for preset_out_dir in "${PRESETS_RELATIVE_OUT_DIR[@]}"; do
+    echo "Framework directory: ${preset_out_dir}/${mode}"
+    FRAMEWORK_FLAGS+=("--directory=${preset_out_dir}/${mode}")
   done

   append_framework_flag "" "$FRAMEWORK_EXECUTORCH" "$mode"
@@ -245,8 +249,8 @@ done

 echo "Cleaning up"

-for preset in "${PRESETS[@]}"; do
-  rm -rf "${OUTPUT_DIR}/${preset}/$preset"
+for preset_out_dir in "${PRESETS_RELATIVE_OUT_DIR[@]}"; do
+  rm -rf "${OUTPUT_DIR}/${preset_out_dir}"
 done

 rm -rf "$HEADERS_ABSOLUTE_PATH"
