@@ -58,7 +58,7 @@ class LoweredBackendModule(torch.nn.Module):
58
58
_compile_specs : List [
59
59
CompileSpec
60
60
] # A list of backend-specific objects with static metadata to configure the "compilation" process.
61
- _original_module : ExportedProgram # The original EXIR module
61
+ _original_exported_program : ExportedProgram # The original EXIR module
62
62
63
63
def __init__ (
64
64
self ,
@@ -68,7 +68,7 @@ def __init__(
68
68
compile_specs : List [CompileSpec ],
69
69
) -> None :
70
70
super ().__init__ ()
71
- self ._original_module = edge_program
71
+ self ._original_exported_program = edge_program
72
72
self ._backend_id = backend_id
73
73
self ._processed_bytes = processed_bytes
74
74
self ._compile_specs = compile_specs
@@ -77,14 +77,20 @@ def __init__(
77
77
def __deepcopy__ (self , memo : Optional [Dict [int , Any ]]) -> "LoweredBackendModule" :
78
78
# Copy exported program
79
79
copied_program = ExportedProgram (
80
- root = copy .deepcopy (self ._original_module .graph_module ),
81
- graph = copy .deepcopy (self ._original_module .graph ),
82
- graph_signature = copy .deepcopy (self ._original_module .graph_signature ),
83
- state_dict = self ._original_module .state_dict ,
84
- range_constraints = copy .deepcopy (self ._original_module .range_constraints ),
85
- module_call_graph = copy .deepcopy (self ._original_module .module_call_graph ),
86
- verifier = copy .deepcopy (self ._original_module .verifier ),
87
- constants = self ._original_module .constants ,
80
+ root = copy .deepcopy (self ._original_exported_program .graph_module ),
81
+ graph = copy .deepcopy (self ._original_exported_program .graph ),
82
+ graph_signature = copy .deepcopy (
83
+ self ._original_exported_program .graph_signature
84
+ ),
85
+ state_dict = self ._original_exported_program .state_dict ,
86
+ range_constraints = copy .deepcopy (
87
+ self ._original_exported_program .range_constraints
88
+ ),
89
+ module_call_graph = copy .deepcopy (
90
+ self ._original_exported_program .module_call_graph
91
+ ),
92
+ verifier = copy .deepcopy (self ._original_exported_program .verifier ),
93
+ constants = self ._original_exported_program .constants ,
88
94
)
89
95
90
96
res = LoweredBackendModule (
@@ -122,7 +128,7 @@ def original_module(self) -> ExportedProgram:
122
128
"""
123
129
Returns the original EXIR module
124
130
"""
125
- return self ._original_module
131
+ return self ._original_exported_program
126
132
127
133
# TODO(chenlai): consolidate the seriailization config with serialize_to_flatbuffer api
128
134
def buffer (
@@ -185,7 +191,7 @@ def program(self, emit_stacktrace: bool = False) -> Program:
185
191
# We'll remove all call_function nodes, insert an call_delegate node, inserting getitems nodes to get the result for call_delegate node
186
192
# and return the list of getitems as the output
187
193
188
- lowered_exported_program = copy .deepcopy (self .original_module )
194
+ lowered_exported_program = copy .deepcopy (self ._original_exported_program )
189
195
190
196
# The real input nodes are the ones not buffer or parameter
191
197
all_input_nodes = [
@@ -237,7 +243,9 @@ def program(self, emit_stacktrace: bool = False) -> Program:
237
243
# Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
238
244
# We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
239
245
original_output_nodes = [
240
- node for node in self .original_module .graph .nodes if node .op == "output"
246
+ node
247
+ for node in self ._original_exported_program .graph .nodes
248
+ if node .op == "output"
241
249
][0 ].args [0 ]
242
250
243
251
delegate_node .meta ["spec" ] = tuple (
0 commit comments