Skip to content

rename original original module to orginal exported program #2263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion exir/backend/test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def check_delegate_input(
self, delegate: LoweredBackendModule, input_len: int
) -> None:
counter = 0
for node in delegate._original_module.graph.nodes:
for node in delegate.original_module.graph.nodes:
if node.op == "placeholder":
counter += 1
self.assertEqual(counter, input_len)
Expand Down
4 changes: 2 additions & 2 deletions exir/backend/test/test_backends_lifted.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def check_delegate_input(
self, delegate: LoweredBackendModule, input_len: int
) -> None:
counter = 0
for node in delegate._original_module.graph.nodes:
for node in delegate.original_module.graph.nodes:
if node.op == "placeholder":
counter += 1
self.assertEqual(counter, input_len)
Expand Down Expand Up @@ -913,7 +913,7 @@ def forward(self, x, y):
)
self.assertEqual(len(lowered_backends), 2)
for backend in lowered_backends:
original_program = backend._original_module
original_program = backend.original_module
# check that program has the lowered attributes
self.assertEqual(len(original_program.state_dict), 1)
# check backend has one placeholder input one placeholder parameter
Expand Down
34 changes: 21 additions & 13 deletions exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class LoweredBackendModule(torch.nn.Module):
_compile_specs: List[
CompileSpec
] # A list of backend-specific objects with static metadata to configure the "compilation" process.
_original_module: ExportedProgram # The original EXIR module
_original_exported_program: ExportedProgram # The original EXIR module

def __init__(
self,
Expand All @@ -68,7 +68,7 @@ def __init__(
compile_specs: List[CompileSpec],
) -> None:
super().__init__()
self._original_module = edge_program
self._original_exported_program = edge_program
self._backend_id = backend_id
self._processed_bytes = processed_bytes
self._compile_specs = compile_specs
Expand All @@ -77,14 +77,20 @@ def __init__(
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
# Copy exported program
copied_program = ExportedProgram(
root=copy.deepcopy(self._original_module.graph_module),
graph=copy.deepcopy(self._original_module.graph),
graph_signature=copy.deepcopy(self._original_module.graph_signature),
state_dict=self._original_module.state_dict,
range_constraints=copy.deepcopy(self._original_module.range_constraints),
module_call_graph=copy.deepcopy(self._original_module.module_call_graph),
verifier=copy.deepcopy(self._original_module.verifier),
constants=self._original_module.constants,
root=copy.deepcopy(self._original_exported_program.graph_module),
graph=copy.deepcopy(self._original_exported_program.graph),
graph_signature=copy.deepcopy(
self._original_exported_program.graph_signature
),
state_dict=self._original_exported_program.state_dict,
range_constraints=copy.deepcopy(
self._original_exported_program.range_constraints
),
module_call_graph=copy.deepcopy(
self._original_exported_program.module_call_graph
),
verifier=copy.deepcopy(self._original_exported_program.verifier),
constants=self._original_exported_program.constants,
)

res = LoweredBackendModule(
Expand Down Expand Up @@ -122,7 +128,7 @@ def original_module(self) -> ExportedProgram:
"""
Returns the original EXIR module
"""
return self._original_module
return self._original_exported_program

# TODO(chenlai): consolidate the seriailization config with serialize_to_flatbuffer api
def buffer(
Expand Down Expand Up @@ -185,7 +191,7 @@ def program(self, emit_stacktrace: bool = False) -> Program:
# We'll remove all call_function nodes, insert an call_delegate node, inserting getitems nodes to get the result for call_delegate node
# and return the list of getitems as the output

lowered_exported_program = copy.deepcopy(self.original_module)
lowered_exported_program = copy.deepcopy(self._original_exported_program)

# The real input nodes are the ones not buffer or parameter
all_input_nodes = [
Expand Down Expand Up @@ -237,7 +243,9 @@ def program(self, emit_stacktrace: bool = False) -> Program:
# Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
# We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
original_output_nodes = [
node for node in self.original_module.graph.nodes if node.op == "output"
node
for node in self._original_exported_program.graph.nodes
if node.op == "output"
][0].args[0]

delegate_node.meta["spec"] = tuple(
Expand Down