Skip to content

Commit 217ddba

Browse files
mcr229facebook-github-bot
authored andcommitted
grab parameter from node
Summary: Helper functions to: 1. Determine if a placeholder node is a parameter 2. Graph the parameter from a node if it exists Reviewed By: cccclai Differential Revision: D47924708 fbshipit-source-id: 38671d87d9826956a7995f8dcbb8ea621a4645e8
1 parent bd9e0b1 commit 217ddba

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

exir/program/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
edge_to_executorch_passes,
1212
ExecutorchProgram,
1313
ExirExportedProgram,
14+
get_param,
15+
is_param,
1416
multi_method_program_to_executorch,
1517
MultiMethodExecutorchProgram,
1618
MultiMethodExirExportedProgram,
@@ -21,6 +23,8 @@
2123
"ExecutorchProgram",
2224
"_to_edge",
2325
"edge_to_executorch_passes",
26+
"is_param",
27+
"get_param",
2428
"MultiMethodExirExportedProgram",
2529
"MultiMethodExecutorchProgram",
2630
"multi_method_program_to_executorch",

exir/program/_program.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,28 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]
203203
return passes
204204

205205

206+
def is_param(edge_program: ExportedProgram, node: torch.fx.Node) -> bool:
207+
"""
208+
Checks if the given node is a parameter within the edge_program
209+
"""
210+
return node.name in edge_program.graph_signature.inputs_to_parameters
211+
212+
213+
def get_param(
214+
edge_program: ExportedProgram,
215+
node: torch.fx.Node,
216+
) -> Optional[torch.nn.Parameter]:
217+
"""
218+
Returns the parameter associated with the given node in the edge program.
219+
Returns None if the node is not a parameter within the edge_program
220+
"""
221+
if is_param(edge_program, node):
222+
parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
223+
return edge_program.state_dict[parameter_name]
224+
225+
return None
226+
227+
206228
######## MULTI METHOD STUFF BELOW HERE. TO BE MERGED INTO ExirExportedProgram and ExecutorchProgram AND THEN DELETED ##########
207229

208230

0 commit comments

Comments
 (0)