Skip to content

Commit 0de3a97

Browse files
cccclaifacebook-github-bot
authored andcommitted
rename original original module to orginal exported program (#2263)
Summary: Pull Request resolved: #2263 This is exported program but not module, rename them to avoid confusion Reviewed By: angelayi Differential Revision: D54527107 fbshipit-source-id: e00eaf7f46ac90acc6ae44cf64ef7c476627b67b
1 parent aef3a7c commit 0de3a97

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

exir/backend/test/test_backends.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def check_delegate_input(
9999
self, delegate: LoweredBackendModule, input_len: int
100100
) -> None:
101101
counter = 0
102-
for node in delegate._original_module.graph.nodes:
102+
for node in delegate.original_module.graph.nodes:
103103
if node.op == "placeholder":
104104
counter += 1
105105
self.assertEqual(counter, input_len)

exir/backend/test/test_backends_lifted.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def check_delegate_input(
9898
self, delegate: LoweredBackendModule, input_len: int
9999
) -> None:
100100
counter = 0
101-
for node in delegate._original_module.graph.nodes:
101+
for node in delegate.original_module.graph.nodes:
102102
if node.op == "placeholder":
103103
counter += 1
104104
self.assertEqual(counter, input_len)
@@ -913,7 +913,7 @@ def forward(self, x, y):
913913
)
914914
self.assertEqual(len(lowered_backends), 2)
915915
for backend in lowered_backends:
916-
original_program = backend._original_module
916+
original_program = backend.original_module
917917
# check that program has the lowered attributes
918918
self.assertEqual(len(original_program.state_dict), 1)
919919
# check backend has one placeholder input one placeholder parameter

exir/lowered_backend_module.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class LoweredBackendModule(torch.nn.Module):
5858
_compile_specs: List[
5959
CompileSpec
6060
] # A list of backend-specific objects with static metadata to configure the "compilation" process.
61-
_original_module: ExportedProgram # The original EXIR module
61+
_original_exported_program: ExportedProgram # The original EXIR module
6262

6363
def __init__(
6464
self,
@@ -68,7 +68,7 @@ def __init__(
6868
compile_specs: List[CompileSpec],
6969
) -> None:
7070
super().__init__()
71-
self._original_module = edge_program
71+
self._original_exported_program = edge_program
7272
self._backend_id = backend_id
7373
self._processed_bytes = processed_bytes
7474
self._compile_specs = compile_specs
@@ -77,14 +77,20 @@ def __init__(
7777
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
7878
# Copy exported program
7979
copied_program = ExportedProgram(
80-
root=copy.deepcopy(self._original_module.graph_module),
81-
graph=copy.deepcopy(self._original_module.graph),
82-
graph_signature=copy.deepcopy(self._original_module.graph_signature),
83-
state_dict=self._original_module.state_dict,
84-
range_constraints=copy.deepcopy(self._original_module.range_constraints),
85-
module_call_graph=copy.deepcopy(self._original_module.module_call_graph),
86-
verifier=copy.deepcopy(self._original_module.verifier),
87-
constants=self._original_module.constants,
80+
root=copy.deepcopy(self._original_exported_program.graph_module),
81+
graph=copy.deepcopy(self._original_exported_program.graph),
82+
graph_signature=copy.deepcopy(
83+
self._original_exported_program.graph_signature
84+
),
85+
state_dict=self._original_exported_program.state_dict,
86+
range_constraints=copy.deepcopy(
87+
self._original_exported_program.range_constraints
88+
),
89+
module_call_graph=copy.deepcopy(
90+
self._original_exported_program.module_call_graph
91+
),
92+
verifier=copy.deepcopy(self._original_exported_program.verifier),
93+
constants=self._original_exported_program.constants,
8894
)
8995

9096
res = LoweredBackendModule(
@@ -122,7 +128,7 @@ def original_module(self) -> ExportedProgram:
122128
"""
123129
Returns the original EXIR module
124130
"""
125-
return self._original_module
131+
return self._original_exported_program
126132

127133
# TODO(chenlai): consolidate the seriailization config with serialize_to_flatbuffer api
128134
def buffer(
@@ -185,7 +191,7 @@ def program(self, emit_stacktrace: bool = False) -> Program:
185191
# We'll remove all call_function nodes, insert an call_delegate node, inserting getitems nodes to get the result for call_delegate node
186192
# and return the list of getitems as the output
187193

188-
lowered_exported_program = copy.deepcopy(self.original_module)
194+
lowered_exported_program = copy.deepcopy(self._original_exported_program)
189195

190196
# The real input nodes are the ones not buffer or parameter
191197
all_input_nodes = [
@@ -237,7 +243,9 @@ def program(self, emit_stacktrace: bool = False) -> Program:
237243
# Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
238244
# We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
239245
original_output_nodes = [
240-
node for node in self.original_module.graph.nodes if node.op == "output"
246+
node
247+
for node in self._original_exported_program.graph.nodes
248+
if node.op == "output"
241249
][0].args[0]
242250

243251
delegate_node.meta["spec"] = tuple(

0 commit comments

Comments
 (0)