
Commit 78227f0

Support unary log in xnnpack delegate (#10952)
### Summary
Support the unary log operator in the XNNPACK backend.

### Test plan
Wrote test cases that verify the appropriate XNNPACK log node is created and that the delegated outputs match the reference.
1 parent de72d65 commit 78227f0

File tree: 10 files changed, +147 -0 lines changed
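With this change, a graph containing aten.log.default can be delegated to XNNPACK end to end. Below is a minimal lowering sketch, assuming the standard ExecuTorch export APIs (torch.export.export, executorch.exir.to_edge_transform_and_lower, XnnpackPartitioner), which are not part of this diff; LogModule and the input shape are illustrative only.

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class LogModule(torch.nn.Module):
    def forward(self, x):
        # abs keeps inputs positive so log stays finite, mirroring the test module below
        return torch.log(torch.abs(x))


example_inputs = (torch.randn(20),)
exported = torch.export.export(LogModule(), example_inputs)

# aten.log.default should now be claimed by the XNNPACK partitioner instead of
# falling back to portable kernels.
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
et_program = edge.to_executorch()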

backends/xnnpack/operators/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@
     op_hardtanh,
     op_leaky_relu,
     op_linear,
+    op_log,
     op_matrix_multiplication,
     op_max_dim,
     op_max_pool2d,

backends/xnnpack/operators/op_log.py
Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNLog,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class LogVisitor(NodeVisitor):
+    target = "aten.log.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNLog(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
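As with the other unary visitors, LogVisitor only records the remapped input and output tensor ids and wraps them in an XNNLog payload. For illustration, the appended node looks roughly like the sketch below; the ids and debug handle are hypothetical values, not taken from this diff.

# Sketch: the serialized node for a graph where the log input tensor was
# assigned id 3 and its output id 4 (hypothetical ids and debug handle).
XNode(
    xnode_union=XNNLog(input_id=3, output_id=4, flags=0),
    debug_handle=7,
)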

backends/xnnpack/partition/config/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -30,6 +30,7 @@
     # EluConfig,
     HardtanhConfig,
     LeakyReLUConfig,
+    LogConfig,
     MaximumConfig,
     MaxPool2dConfig,
     MeanDimConfig,
@@ -82,6 +83,7 @@
     HardswishConfig,
     LeakyReLUConfig,
     LinearConfig,
+    LogConfig,
     MaxDimConfig,
     MaximumConfig,
     MaxPool2dConfig,

backends/xnnpack/partition/config/generic_node_configs.py
Lines changed: 7 additions & 0 deletions

@@ -357,6 +357,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]


+class LogConfig(GenericNodePartitionerConfig):
+    target_name = "log.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"


backends/xnnpack/partition/configs.py
Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@
     exir_ops.edge.aten.leaky_relu.default,
     exir_ops.edge.aten.addmm.default,  # TODO(T163877189) add constraint for addmm
     exir_ops.edge.aten.rsqrt.default,
+    exir_ops.edge.aten.log.default,
 ]

 SUPPORTED_MODULES = [

backends/xnnpack/runtime/XNNCompiler.cpp
Lines changed: 31 additions & 0 deletions

@@ -1418,6 +1418,36 @@ Error defineReciprocalSquareRootNode(
   return Error::Ok;
 }

+/*
+Define serialized log node into the subgraph, using the remapped ids
+to map the serialized ids, to the new ids generated when defining the
+tensor value
+*/
+Error defineLogNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNLog();
+
+  xnn_status status = xnn_define_log(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create log node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Define serialized ceiling node into the subgraph, using the remapped ids
 to map the serialized ids, to the new ids generated when defining the
@@ -1981,6 +2011,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Ceiling)
     _DEFINE(Hardswish)
     _DEFINE(LeakyReLU)
+    _DEFINE(Log)
     _DEFINE(Maximum)
     _DEFINE(Negate)
     _DEFINE(Square)
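At runtime the path mirrors the ahead-of-time side: defineLogNode reads the serialized XNNLog payload, remaps its tensor ids, and calls XNNPACK's xnn_define_log, while the _DEFINE(Log) entry in getDefineNodeFunc routes XNNLog nodes to that helper.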

backends/xnnpack/serialization/runtime_schema.fbs
Lines changed: 1 addition & 0 deletions

@@ -139,6 +139,7 @@ union XNodeUnion {
   XNNConcatenate5: _XNNCat,
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
+  XNNLog: _XNNNode1x1,
 }

 union XValueUnion {

backends/xnnpack/serialization/schema.fbs
Lines changed: 1 addition & 0 deletions

@@ -135,6 +135,7 @@ union XNodeUnion {
   XNNConcatenate5: _XNNCat,
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
+  XNNLog: _XNNNode1x1,
 }

 union XValueUnion {

backends/xnnpack/serialization/xnnpack_graph_schema.py
Lines changed: 6 additions & 0 deletions

@@ -309,6 +309,11 @@ class XNNLeakyReLU:
     flags: int


+@dataclass
+class XNNLog(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNMaximum(XNNNode2x1):
     pass
@@ -379,6 +384,7 @@ class XNNScaledDotProductAttention:
     XNNScaledDotProductAttention,
     XNNBatchMatrixMultiply,
     XNNReciprocalSquareRoot,
+    XNNLog,
 ]


backends/xnnpack/test/ops/test_log.py
Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestLog(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    class Log(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            x = torch.abs(x)
+            z = torch.log(x)
+            return z
+
+    def run_log_test(self, inputs):
+        (
+            Tester(self.Log(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.log.default": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_log_default"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp16_log(self):
+        inputs = (torch.randn(20).to(torch.float16),)
+        self.run_log_test(inputs)
+
+    def test_fp32_log(self):
+        inputs = (torch.randn(20),)
+        self.run_log_test(inputs)
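The Tester pipeline above asserts that aten.log.default appears once after export, that the lowered program contains a single executorch_call_delegate, that no edge-dialect log op is left behind, and that the serialized program's outputs match the reference. One way to exercise just these cases locally is a plain pytest invocation from the repo root, assuming the executorch package is importable in your environment (your setup may use unittest or buck instead):

python -m pytest backends/xnnpack/test/ops/test_log.py -v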
