Skip to content

Commit 9abc9f4

Browse files
oscarandersson8218freddan80
authored andcommitted
Solve circular import error
Move process_output, process_inputs and process_call_function to a separate file named process_node. This removes the circular dependency between tosa_utils.py and node_visitor.py. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I5bfe42ad6cd36c579390af6d7fa14c8fb7341842
1 parent 97e0417 commit 9abc9f4

File tree

4 files changed

+61
-75
lines changed

4 files changed

+61
-75
lines changed

backends/arm/arm_backend.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,17 @@
1818
import serializer.tosa_serializer as ts
1919
from executorch.backends.arm.arm_vela import vela_compile
2020
from executorch.backends.arm.operators.node_visitor import get_node_visitors
21-
from executorch.backends.arm.operators.op_output import process_output
22-
from executorch.backends.arm.operators.op_placeholder import process_placeholder
2321

2422
from executorch.backends.arm.tosa_specification import TosaSpecification
2523
from executorch.backends.arm._passes.arm_pass_manager import (
2624
ArmPassManager,
2725
) # usort: skip
28-
from executorch.backends.arm.tosa_utils import (
29-
dbg_fail,
30-
dbg_tosa_dump,
26+
from executorch.backends.arm.process_node import (
3127
process_call_function,
28+
process_output,
29+
process_placeholder,
3230
)
31+
from executorch.backends.arm.tosa_utils import dbg_fail, dbg_tosa_dump
3332
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
3433
from executorch.exir.backend.compile_spec_schema import CompileSpec
3534
from torch.export.exported_program import ExportedProgram

backends/arm/operators/op_output.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

backends/arm/operators/op_placeholder.py renamed to backends/arm/process_node.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,68 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
6-
# pyre-unsafe
5+
#
6+
from typing import cast, Dict
77

88
import numpy as np
99
import serializer.tosa_serializer as ts
10+
import torch
1011
import torch.fx
11-
from executorch.backends.arm.tosa_mapping import TosaArg
12+
from executorch.backends.arm.operators.node_visitor import NodeVisitor
13+
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
1214
from executorch.backends.arm.tosa_quant_utils import (
1315
get_quant_arg_upstream,
1416
get_quantized_node_output_dtype,
1517
is_node_quantized,
1618
)
1719
from executorch.backends.arm.tosa_specification import TosaSpecification
1820
from executorch.backends.arm.tosa_utils import (
21+
getNodeArgs,
1922
is_bias_node_for_quantized_conv,
20-
map_dtype,
2123
tosa_shape,
2224
)
2325
from torch.export.exported_program import ExportedProgram
2426

2527

28+
def process_call_function(
29+
node: torch.fx.Node,
30+
tosa_graph: ts.TosaSerializer,
31+
node_visitors: Dict[str, NodeVisitor],
32+
tosa_spec: TosaSpecification,
33+
):
34+
# Unpack arguments and convert
35+
inputs = getNodeArgs(node)
36+
37+
# Convert output (this node itself)
38+
output = TosaArg(node)
39+
40+
is_quant_node = is_node_quantized(node)
41+
if is_quant_node:
42+
output_dtype = map_dtype(get_quantized_node_output_dtype(node))
43+
else:
44+
output_dtype = output.dtype
45+
tosa_graph.currRegion.currBasicBlock.addTensor(
46+
output.name,
47+
tosa_shape(output.shape, output.dim_order),
48+
output_dtype,
49+
)
50+
51+
# Visiting each Node
52+
# pyre-ignore[16]: Undefined attribute.
53+
if node.target.__name__ in node_visitors:
54+
# pyre-ignore[16]: Undefined attribute.
55+
node_visitors[node.target.__name__].define_node(
56+
node,
57+
tosa_graph,
58+
inputs,
59+
output,
60+
is_quant_node,
61+
)
62+
else:
63+
raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")
64+
65+
2666
def process_inputs(
2767
node: torch.fx.Node,
2868
tosa_graph: ts.TosaSerializer,
@@ -176,3 +216,13 @@ def process_placeholder(
176216
)
177217
else:
178218
raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
219+
220+
221+
def process_output(
222+
node: torch.fx.Node,
223+
tosa_graph: ts.TosaSerializer,
224+
):
225+
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
226+
tosa_graph.addOutputTensor(
227+
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
228+
)

backends/arm/tosa_utils.py

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,18 @@
77

88
import logging
99
import os
10-
from typing import Any, cast, Dict
10+
from typing import Any, cast
1111

1212
import numpy as np
1313
import serializer.tosa_serializer as ts
1414
import torch
15-
from executorch.backends.arm.operators.node_visitor import NodeVisitor
16-
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
15+
from executorch.backends.arm.tosa_mapping import TosaArg
1716

1817
from executorch.backends.arm.tosa_quant_utils import (
1918
get_quant_arg_downstream,
2019
get_quant_arg_upstream,
21-
get_quantized_node_output_dtype,
22-
is_node_quantized,
2320
q_op,
2421
)
25-
from executorch.backends.arm.tosa_specification import TosaSpecification
2622
from executorch.exir.dialects._ops import ops as exir_ops
2723
from serializer.tosa_serializer import TosaOp
2824
from torch.fx import Node
@@ -233,44 +229,6 @@ def tosa_shape(shape, dim_order):
233229
return tuple([shape[dim] for dim in dim_order])
234230

235231

236-
def process_call_function(
237-
node: torch.fx.Node,
238-
tosa_graph: ts.TosaSerializer,
239-
node_visitors: Dict[str, NodeVisitor],
240-
tosa_spec: TosaSpecification,
241-
):
242-
# Unpack arguments and convert
243-
inputs = getNodeArgs(node)
244-
245-
# Convert output (this node itself)
246-
output = TosaArg(node)
247-
248-
is_quant_node = is_node_quantized(node)
249-
if is_quant_node:
250-
output_dtype = map_dtype(get_quantized_node_output_dtype(node))
251-
else:
252-
output_dtype = output.dtype
253-
tosa_graph.currRegion.currBasicBlock.addTensor(
254-
output.name,
255-
(tosa_shape(output.shape, output.dim_order)),
256-
output_dtype,
257-
)
258-
259-
# Visiting each Node
260-
# pyre-ignore[16]: Undefined attribute.
261-
if node.target.__name__ in node_visitors:
262-
# pyre-ignore[16]: Undefined attribute.
263-
node_visitors[node.target.__name__].define_node(
264-
node,
265-
tosa_graph,
266-
inputs,
267-
output,
268-
is_quant_node,
269-
)
270-
else:
271-
raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")
272-
273-
274232
def expand_dims(
275233
tosa_graph: ts.TosaSerializer,
276234
input_node: TosaArg,

0 commit comments

Comments
 (0)