10 files changed, +147 −0 lines changed
Operator visitors package, import list — register the new op_log visitor module alongside the existing ops:

     op_hardtanh,
     op_leaky_relu,
     op_linear,
+    op_log,
     op_matrix_multiplication,
     op_max_dim,
     op_max_pool2d,
New file — the node visitor that serializes aten.log.default into an XNNLog node:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNLog,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class LogVisitor(NodeVisitor):
    target = "aten.log.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input
        input_id = vals_to_ids[get_input_node(node, 0)]

        # output
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNLog(
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
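For reference, a minimal sketch of the record this visitor appends — the ids below are placeholders, since the real values come out of vals_to_ids once the tensor values have been defined:

from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNLog, XNode

# Placeholder ids (0 = input tensor value, 1 = output tensor value).
sketch = XNode(
    xnode_union=XNNLog(input_id=0, output_id=1, flags=0),
    debug_handle=0,  # placeholder debug handle
)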
Partitioner config package — expose LogConfig in both the import block and the aggregated config list:

     # EluConfig,
     HardtanhConfig,
     LeakyReLUConfig,
+    LogConfig,
     MaximumConfig,
     MaxPool2dConfig,
     MeanDimConfig,

     HardswishConfig,
     LeakyReLUConfig,
     LinearConfig,
+    LogConfig,
     MaxDimConfig,
     MaximumConfig,
     MaxPool2dConfig,
Generic partitioner configs — LogConfig targets log.default and, like the neighboring unary configs, partitions only in FP32:

@@ -357,6 +357,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
 
+class LogConfig(GenericNodePartitionerConfig):
+    target_name = "log.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
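GenericNodePartitionerConfig subclasses are matched by target name, so LogConfig needs no logic beyond declaring its precision. A rough sketch of the idea — node_matches and its attribute access are illustrative, not the real config API:

# Illustrative only: how a generic config conceptually decides that a node is
# a candidate for partitioning. The real matching lives in
# GenericNodePartitionerConfig.
def node_matches(config, node) -> bool:
    target_name = getattr(node.target, "__name__", "")
    return node.op == "call_function" and target_name == config.target_name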
Partitioner supported-ops list — mark the edge-dialect log op as lowerable:

     exir_ops.edge.aten.leaky_relu.default,
     exir_ops.edge.aten.addmm.default,  # TODO(T163877189) add constraint for addmm
     exir_ops.edge.aten.rsqrt.default,
+    exir_ops.edge.aten.log.default,
 ]
 
 SUPPORTED_MODULES = [
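This list gates which edge-dialect ops the partitioner will consider at all. The new target can be inspected in a Python session, assuming the exir_ops import path commonly used across ExecuTorch:

from executorch.exir.dialects._ops import ops as exir_ops

# The edge-dialect overload that must be listed for log to be delegated.
print(exir_ops.edge.aten.log.default)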
Runtime graph builder (C++) — deserialize XNNLog, define the node on the XNNPACK subgraph via xnn_define_log, and register it in the node-dispatch table:

@@ -1418,6 +1418,36 @@ Error defineReciprocalSquareRootNode(
   return Error::Ok;
 }
 
+/*
+Define serialized log node into the subgraph, using the remapped ids
+to map the serialized ids to the new ids generated when defining the
+tensor value
+*/
+Error defineLogNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNLog();
+
+  xnn_status status = xnn_define_log(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create log node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Define serialized ceiling node into the subgraph, using the remapped ids
 to map the serialized ids, to the new ids generated when defining the
 tensor value
 */

@@ -1981,6 +2011,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
   _DEFINE(Ceiling)
   _DEFINE(Hardswish)
   _DEFINE(LeakyReLU)
+  _DEFINE(Log)
   _DEFINE(Maximum)
   _DEFINE(Negate)
   _DEFINE(Square)
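The only step here beyond boilerplate is the id remapping: ids stored in the serialized flatbuffer are translated to the ids XNNPACK handed back when the runtime redefined each tensor value. A small Python illustration of that mapping, with made-up values:

# serialized value id -> id assigned by XNNPACK when the tensor was redefined
remapped_ids = {0: 7, 1: 8}

serialized_log_node = {"input_id": 0, "output_id": 1, "flags": 0}
runtime_input = remapped_ids[serialized_log_node["input_id"]]    # 7
runtime_output = remapped_ids[serialized_log_node["output_id"]]  # 8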
FlatBuffers schema — add XNNLog to the node union, reusing the generic one-input/one-output table:

@@ -139,6 +139,7 @@ union XNodeUnion {
   XNNConcatenate5: _XNNCat,
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
+  XNNLog: _XNNNode1x1,
 }
 
 union XValueUnion {
The second copy of the schema, kept in sync for the runtime, gets the identical union entry:

@@ -135,6 +135,7 @@ union XNodeUnion {
   XNNConcatenate5: _XNNCat,
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
+  XNNLog: _XNNNode1x1,
 }
 
 union XValueUnion {
Python mirror of the graph schema — an XNNLog dataclass plus its registration in the node-union list:

@@ -309,6 +309,11 @@ class XNNLeakyReLU:
     flags: int
 
 
+@dataclass
+class XNNLog(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNMaximum(XNNNode2x1):
     pass

@@ -379,6 +384,7 @@ class XNNScaledDotProductAttention:
     XNNScaledDotProductAttention,
     XNNBatchMatrixMultiply,
     XNNReciprocalSquareRoot,
+    XNNLog,
 ]
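Since XNNLog subclasses XNNNode1x1 with no fields of its own, it carries exactly the three values the visitor fills in. A throwaway instance, with placeholder ids, just to show the inherited shape:

from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNLog

node = XNNLog(input_id=0, output_id=1, flags=0)  # fields inherited from XNNNode1x1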
New file — operator test covering export, partitioning, delegation, and runtime output comparison for fp16 and fp32:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestLog(unittest.TestCase):
    def setUp(self):
        torch._dynamo.reset()

    class Log(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = torch.abs(x)
            z = torch.log(x)
            return z

    def run_log_test(self, inputs):
        (
            Tester(self.Log(), inputs)
            .export()
            .check_count({"torch.ops.aten.log.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_log_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_log(self):
        inputs = (torch.randn(20).to(torch.float16),)
        self.run_log_test(inputs)

    def test_fp32_log(self):
        inputs = (torch.randn(20),)
        self.run_log_test(inputs)
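Note that the test module wraps its input in torch.abs() so log never sees a non-positive value. The same pattern can be sanity-checked in eager mode, with no delegate involved:

import torch

x = torch.randn(20)
ref = torch.log(torch.abs(x))  # abs() keeps inputs inside log's domain
print(ref.shape)  # torch.Size([20])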