Skip to content

Commit 67a7d20

Browse files
angelayifacebook-github-bot
authored andcommitted
Fix handling of nonpersistent buffers (#2604)
Summary: Pull Request resolved: #2604 bypass-github-pytorch-ci-checks Reviewed By: cccclai Differential Revision: D55268323 fbshipit-source-id: 40b7928e204103d8d67482a7f3332b5910b13de8
1 parent a531ca5 commit 67a7d20

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

exir/lowered_backend_module.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,22 @@ def _get_new_signature( # noqa: C901
483483
elif is_tagged:
484484
input_specs.append(orig_input_spec)
485485

486-
if orig_input_spec.kind in (InputKind.PARAMETER, InputKind.BUFFER):
486+
if orig_input_spec.kind == InputKind.PARAMETER:
487487
new_state_dict[orig_input_spec.target] = (
488488
original_program.state_dict[orig_input_spec.target]
489489
)
490+
elif (
491+
orig_input_spec.kind == InputKind.BUFFER
492+
and orig_input_spec.persistent
493+
):
494+
new_state_dict[orig_input_spec.target] = (
495+
original_program.state_dict[orig_input_spec.target]
496+
)
497+
elif orig_input_spec.kind == InputKind.BUFFER:
498+
assert not orig_input_spec.persistent
499+
new_constants[orig_input_spec.target] = original_program.constants[
500+
orig_input_spec.target
501+
]
490502
elif orig_input_spec.kind in (
491503
InputKind.CONSTANT_TENSOR,
492504
InputKind.CUSTOM_OBJ,

0 commit comments

Comments
 (0)