Skip to content

Add GELU support to XNNPACK #11006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/xnnpack/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
op_dynamic_quantize_ops,
op_elu,
op_floor,
op_gelu,
op_hardswish,
op_hardtanh,
op_leaky_relu,
Expand Down
52 changes: 52 additions & 0 deletions backends/xnnpack/operators/op_gelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNGelu,
XNNGraph,
XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class GeluVisitor(NodeVisitor):
target = "aten.gelu.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
xnn_graph: XNNGraph,
vals_to_ids: Dict[torch.fx.Node, int],
debug_handle: int,
) -> None:
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

# input
input_id = vals_to_ids[get_input_node(node, 0)]

# output
output_id = vals_to_ids[node]

ser_node = XNode(
xnode_union=XNNGelu(
input_id=input_id,
output_id=output_id,
flags=0,
),
debug_handle=debug_handle,
)
xnn_graph.xnodes.append(ser_node)
2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DeQuantizedPerTensorConfig,
DivConfig,
FloorConfig,
GeluConfig,
HardswishConfig,
# EluConfig,
HardtanhConfig,
Expand Down Expand Up @@ -79,6 +80,7 @@
DivConfig,
# EluConfig, # Waiting for PyTorch Pin Update
FloorConfig,
GeluConfig,
HardtanhConfig,
HardswishConfig,
LeakyReLUConfig,
Expand Down
7 changes: 7 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class GeluConfig(GenericNodePartitionerConfig):
target_name = "gelu.default"

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class HardswishConfig(GenericNodePartitionerConfig):
target_name = "hardswish.default"

Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/partition/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.gelu.default,
]

SUPPORTED_MODULES = [
Expand Down
31 changes: 31 additions & 0 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,36 @@ Error defineLogNode(
return Error::Ok;
}

/*
Define serialized gelu node into the subgraph, using the remapped ids
to map the serialized ids, to the new ids generated when defining the
tensor value
*/
Error defineGeluNode(
xnn_subgraph_t subgraph_ptr,
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
const NodePtr node,
const fb_xnnpack::XNNGraph* graph) noexcept {
MAYBE_UNUSED(graph);

auto graph_node = node->xnode_union_as_XNNGelu();

xnn_status status = xnn_define_gelu(
subgraph_ptr,
remapped_ids.at(graph_node->input_id()),
remapped_ids.at(graph_node->output_id()),
graph_node->flags());

ET_CHECK_OR_RETURN_ERROR(
status == xnn_status_success,
Internal,
"Failed to create gelu node %i with code: %s",
node->debug_handle(),
xnn_status_to_string(status));

return Error::Ok;
}

/*
Define serialized ceiling node into the subgraph, using the remapped ids
to map the serialized ids, to the new ids generated when defining the
Expand Down Expand Up @@ -2009,6 +2039,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
_DEFINE(SquareRoot)
_DEFINE(ReciprocalSquareRoot)
_DEFINE(Ceiling)
_DEFINE(Gelu)
_DEFINE(Hardswish)
_DEFINE(LeakyReLU)
_DEFINE(Log)
Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/runtime_schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ union XNodeUnion {
XNNConvTranspose2d: _XNNNodeConv,
XNNReciprocalSquareRoot: _XNNNode1x1,
XNNLog: _XNNNode1x1,
XNNGelu: _XNNNode1x1,
}

union XValueUnion {
Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ union XNodeUnion {
XNNConvTranspose2d: _XNNNodeConv,
XNNReciprocalSquareRoot: _XNNNode1x1,
XNNLog: _XNNNode1x1,
XNNGelu: _XNNNode1x1,
}

union XValueUnion {
Expand Down
6 changes: 6 additions & 0 deletions backends/xnnpack/serialization/xnnpack_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1):
pass


@dataclass
class XNNGelu(XNNNode1x1):
pass


@dataclass
class XNNHardswish(XNNNode1x1):
pass
Expand Down Expand Up @@ -385,6 +390,7 @@ class XNNScaledDotProductAttention:
XNNBatchMatrixMultiply,
XNNReciprocalSquareRoot,
XNNLog,
XNNGelu,
]


Expand Down
44 changes: 44 additions & 0 deletions backends/xnnpack/test/ops/test_gelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestGelu(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()

class Gelu(torch.nn.Module):
def __init__(self):
super().__init__()
self.gelu = torch.nn.GELU()

def forward(self, x):
return self.gelu(x)

def run_gelu_test(self, inputs):
(
Tester(self.Gelu(), inputs)
.export()
.check_count({"torch.ops.aten.gelu.default": 1})
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

def test_fp16_gelu(self):
inputs = (torch.randn(20).to(torch.float16),)
self.run_gelu_test(inputs)

def test_fp32_gelu(self):
inputs = (torch.randn(20),)
self.run_gelu_test(inputs)
Loading